Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-17 12:20:37 +08:00)
fix generation in seq2seq.py
@@ -25,7 +25,10 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
-    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)
+    data_collator = DynamicDataCollatorWithPadding(
+        tokenizer=tokenizer,
+        ignore_pad_token_for_loss=(data_args.ignore_pad_token_for_loss and not training_args.predict_with_generate)
+    )

     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length if \
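The added guard `and not training_args.predict_with_generate` controls whether padded label positions are replaced with -100, the index that CrossEntropyLoss ignores. Below is a minimal sketch of such a collator for illustration only; `PaddingCollatorSketch` and `IGNORE_INDEX` are hypothetical names, not LLaMA-Factory's actual `DynamicDataCollatorWithPadding`. It assumes each feature already carries aligned `input_ids` and `labels` lists, and isolates the single decision the commit changes: when `predict_with_generate` is set, labels presumably need to remain real token ids so they can be decoded for generation metrics, so the -100 masking is switched off.

# Minimal sketch of a dynamic-padding collator (hypothetical, for illustration).
# When ignore_pad_token_for_loss is True, padded label positions become -100 so
# the loss skips them; when it is False (e.g. during predict_with_generate),
# labels keep the real pad token id and stay decodable.

from dataclasses import dataclass
from typing import Dict, List

import torch
from transformers import PreTrainedTokenizer

IGNORE_INDEX = -100  # label id ignored by torch.nn.CrossEntropyLoss


@dataclass
class PaddingCollatorSketch:
    tokenizer: PreTrainedTokenizer
    ignore_pad_token_for_loss: bool = True

    def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        max_len = max(len(f["input_ids"]) for f in features)
        pad_id = self.tokenizer.pad_token_id
        label_pad = IGNORE_INDEX if self.ignore_pad_token_for_loss else pad_id

        input_ids, attention_mask, labels = [], [], []
        for f in features:
            pad_len = max_len - len(f["input_ids"])
            input_ids.append(f["input_ids"] + [pad_id] * pad_len)
            attention_mask.append([1] * len(f["input_ids"]) + [0] * pad_len)
            labels.append(f["labels"] + [label_pad] * pad_len)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

A call mirroring the patched line would then pass ignore_pad_token_for_loss=(data_args.ignore_pad_token_for_loss and not training_args.predict_with_generate), exactly as the diff above does.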