From 08296f40927a4432f55ec92b26189b95e70f9bb7 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Tue, 2 Jul 2024 17:34:56 +0800
Subject: [PATCH] fix ppo callbacks

Former-commit-id: 4c296001c4b77b814e4bd6cb4049a279718cb775
---
 src/llamafactory/train/ppo/trainer.py  | 6 +++---
 src/llamafactory/train/ppo/workflow.py | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 57f0b848..8b89e38a 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -70,7 +70,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         training_args: "Seq2SeqTrainingArguments",
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
-        callbacks: List["TrainerCallback"],
+        callbacks: Optional[List["TrainerCallback"]],
         model: "AutoModelForCausalLMWithValueHead",
         reward_model: Optional["AutoModelForCausalLMWithValueHead"],
         ref_model: Optional["AutoModelForCausalLMWithValueHead"],
@@ -78,7 +78,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         processor: Optional["ProcessorMixin"],
         dataset: "Dataset",
         data_collator: "DataCollatorWithPadding",
-    ):
+    ) -> None:
         backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
         ppo_config = PPOConfig(
             model_name=model_args.model_name_or_path,
@@ -144,7 +144,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
         self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
         self.callback_handler = CallbackHandler(
-            [callbacks], self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
+            callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
         )

         if self.args.max_steps > 0:
diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py
index 651296f3..df22dae5 100644
--- a/src/llamafactory/train/ppo/workflow.py
+++ b/src/llamafactory/train/ppo/workflow.py
@@ -22,7 +22,7 @@ from transformers import DataCollatorWithPadding
 from ...data import get_dataset
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..callbacks import FixValueHeadModelCallback, fix_valuehead_checkpoint
+from ..callbacks import fix_valuehead_checkpoint
 from ..trainer_utils import create_ref_model, create_reward_model
 from .trainer import CustomPPOTrainer

@@ -59,7 +59,7 @@ def run_ppo(
         training_args=training_args,
         finetuning_args=finetuning_args,
         generating_args=generating_args,
-        callbacks=callbacks + [FixValueHeadModelCallback()],
+        callbacks=callbacks,
         model=model,
         reward_model=reward_model,
         ref_model=ref_model,
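
The two files change together: run_ppo() now passes its callback list through unchanged, and CustomPPOTrainer stops wrapping that (now optional) list in an extra list before handing it to transformers' CallbackHandler. A minimal sketch of the handler behaviour this relies on, using only the public transformers API; LogCallback below is a hypothetical stand-in for a caller-supplied callback, not a class from this patch:

# Sketch only: how CallbackHandler registers callbacks, and why the old call
# CallbackHandler([callbacks], ...) registered a list object instead of the
# callbacks inside it.
from transformers.trainer_callback import CallbackHandler, DefaultFlowCallback, TrainerCallback


class LogCallback(TrainerCallback):  # hypothetical caller-supplied callback
    pass


callbacks = [DefaultFlowCallback(), LogCallback()]

# After the fix the flat list is passed through, so each callback is registered
# and will receive on_train_begin / on_step_end / ... events.
handler = CallbackHandler(callbacks, None, None, None, None)
print(handler.callback_list)  # prints the registered callback class names

# Before the fix the first argument was [callbacks]: a one-element list whose
# element is itself a list, so the registered "callback" is a plain Python list
# and event dispatch would raise AttributeError instead of invoking LogCallback.
broken = CallbackHandler([callbacks], None, None, None, None)
print(broken.callback_list)  # prints "list"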