diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py index 05942780..d45571d2 100644 --- a/src/llmtuner/tuner/sft/workflow.py +++ b/src/llmtuner/tuner/sft/workflow.py @@ -39,7 +39,7 @@ def run_sft( # Override the decoding parameters of Seq2SeqTrainer training_args_dict = training_args.to_dict() training_args_dict.update(dict( - generation_max_length=training_args.generation_max_length or data_args.max_target_length, + generation_max_length=training_args.generation_max_length or data_args.cutoff_len, generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams )) training_args = Seq2SeqTrainingArguments(**training_args_dict)