support full-parameter PPO

This commit is contained in:
hiyouga
2023-11-16 02:08:04 +08:00
parent 8350bcf85d
commit ce78303600
20 changed files with 288 additions and 145 deletions

View File

@@ -9,8 +9,9 @@ from transformers.optimization import get_scheduler
from llmtuner.data import get_dataset, preprocess_dataset
from llmtuner.extras.callbacks import SavePeftModelCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.ploting import plot_loss
from llmtuner.model import load_model_and_tokenizer
from llmtuner.model import create_ref_model, create_reward_model, load_model_and_tokenizer
from llmtuner.train.ppo.trainer import CustomPPOTrainer
if TYPE_CHECKING:
@@ -18,6 +19,9 @@ if TYPE_CHECKING:
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
logger = get_logger(__name__)
def run_ppo(
model_args: "ModelArguments",
data_args: "DataArguments",
@@ -33,6 +37,11 @@ def run_ppo(
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Create reference model and reward model
ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
reward_model = create_reward_model(model, model_args, finetuning_args)
# Create ppo config
ppo_config = PPOConfig(
model_name=model_args.model_name_or_path,
learning_rate=training_args.learning_rate,
@@ -47,9 +56,11 @@ def run_ppo(
log_with=finetuning_args.ppo_logger,
use_score_scaling=finetuning_args.ppo_score_norm,
use_score_norm=finetuning_args.ppo_score_norm,
whiten_rewards=finetuning_args.ppo_whiten_rewards,
accelerator_kwargs={"step_scheduler_with_optimizer": False}
)
# Create optimizer and scheduler
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
if training_args.max_steps > 0:
num_training_steps = training_args.max_steps
@@ -73,9 +84,10 @@ def run_ppo(
finetuning_args=finetuning_args,
generating_args=generating_args,
callbacks=callbacks + [SavePeftModelCallback()],
reward_model=reward_model,
config=ppo_config,
model=model,
ref_model=None,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dataset,
data_collator=data_collator,
@@ -88,5 +100,5 @@ def run_ppo(
ppo_trainer.ppo_train()
ppo_trainer.save_model()
ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
plot_loss(training_args.output_dir, keys=["loss", "reward"])