This commit is contained in:
hiyouga
2023-08-21 18:16:11 +08:00
parent 02d69b6fde
commit 5235b15c91
2 changed files with 2 additions and 6 deletions

View File

@@ -13,7 +13,7 @@ from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
def run_dpo(
@@ -21,7 +21,6 @@ def run_dpo(
data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments",
finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments",
callbacks: Optional[List["TrainerCallback"]] = None
):
dataset = get_dataset(model_args, data_args)
@@ -38,7 +37,6 @@ def run_dpo(
# Initialize our Trainer
trainer = DPOPeftTrainer(
finetuning_args=finetuning_args,
generating_args=generating_args,
ref_model=ref_model,
model=model,
args=training_args,