fix generating args

This commit is contained in:
hiyouga
2023-06-13 01:33:56 +08:00
parent cec6524d6b
commit 531a3764d9
5 changed files with 20 additions and 16 deletions

View File

@@ -30,8 +30,8 @@ def main():
# Override the decoding parameters of Seq2SeqTrainer
training_args.generation_max_length = training_args.generation_max_length if \
training_args.generation_max_length is not None else data_args.max_target_length
-        training_args.generation_num_beams = data_args.num_beams if \
-            data_args.num_beams is not None else training_args.generation_num_beams
+        training_args.generation_num_beams = data_args.eval_num_beams if \
+            data_args.eval_num_beams is not None else training_args.generation_num_beams
# Split the dataset
if training_args.do_train: