fix decoding in seq2seq

This commit is contained in:
hiyouga
2023-06-27 19:33:08 +08:00
parent 33f2141507
commit 1c732e2537
2 changed files with 17 additions and 20 deletions

View File

@@ -80,15 +80,19 @@ def main():
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
metrics.pop("eval_loss", None)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
# Predict
if training_args.do_predict:
predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
predict_results.metrics.pop("predict_loss", None)
trainer.log_metrics("predict", predict_results.metrics)
trainer.save_metrics("predict", predict_results.metrics)
trainer.save_predictions(predict_results, tokenizer)
trainer.save_predictions(predict_results)
def _mp_fn(index):