diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/tuner/dpo/trainer.py
index 458e99db..3f7f7af5 100644
--- a/src/llmtuner/tuner/dpo/trainer.py
+++ b/src/llmtuner/tuner/dpo/trainer.py
@@ -10,7 +10,7 @@ from llmtuner.tuner.core.trainer import PeftModelMixin
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
-    from llmtuner.hparams import FinetuningArguments, GeneratingArguments
+    from llmtuner.hparams import FinetuningArguments
 
 
 class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
@@ -18,12 +18,10 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
     def __init__(
         self,
         finetuning_args: "FinetuningArguments",
-        generating_args: "GeneratingArguments",
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         **kwargs
     ):
         self.finetuning_args = finetuning_args
-        self.generating_args = generating_args
         self.ref_model = ref_model
         self.use_dpo_data_collator = True # hack to avoid warning
         self.label_pad_token_id = IGNORE_INDEX
diff --git a/src/llmtuner/tuner/dpo/workflow.py b/src/llmtuner/tuner/dpo/workflow.py
index af184ce5..4f41d4c6 100644
--- a/src/llmtuner/tuner/dpo/workflow.py
+++ b/src/llmtuner/tuner/dpo/workflow.py
@@ -13,7 +13,7 @@ from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 
 
 def run_dpo(
@@ -21,7 +21,6 @@
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    generating_args: "GeneratingArguments",
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
@@ -38,7 +37,6 @@ def run_dpo(
     # Initialize our Trainer
     trainer = DPOPeftTrainer(
         finetuning_args=finetuning_args,
-        generating_args=generating_args,
         ref_model=ref_model,
         model=model,
         args=training_args,
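
The net effect of this patch is that the DPO path no longer accepts `generating_args`. Below is a minimal usage sketch of how a caller would invoke `run_dpo` after the change; the `launch_dpo` wrapper and its argument objects are assumptions for illustration only and are not part of this diff.

```python
# Hypothetical caller sketch: how run_dpo is invoked after this patch.
# The wrapper below and the way the argument objects are obtained are
# assumptions; only the run_dpo signature (without generating_args) comes
# from the diff above.
from llmtuner.tuner.dpo.workflow import run_dpo


def launch_dpo(model_args, data_args, training_args, finetuning_args, callbacks=None):
    # generating_args is intentionally absent: DPO training does not generate
    # text, so neither run_dpo nor DPOPeftTrainer takes it anymore.
    run_dpo(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args,
        finetuning_args=finetuning_args,
        callbacks=callbacks,
    )
```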