This commit is contained in:
hiyouga
2024-12-19 12:16:30 +00:00
parent ffbb4dbdb0
commit d4c1fda1ad
6 changed files with 22 additions and 16 deletions

View File

@@ -91,7 +91,7 @@ def run_sft(
)
# Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict()
gen_kwargs = generating_args.to_dict(obey_generation_config=True)
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor()
@@ -130,7 +130,7 @@ def run_sft(
predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, gen_kwargs)
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
# Create model card
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)