[trainer] fix KeyError at end of pretrain (#8099)

This commit is contained in:
Ma, Xiaochen 2025-05-19 18:01:26 +08:00 committed by GitHub
parent 89a0d10c18
commit 16e26236eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -77,12 +77,22 @@ def run_pt(
# Evaluation
if training_args.do_eval:
metrics = trainer.evaluate(metric_key_prefix="eval")
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["perplexity"] = perplexity
if isinstance(dataset_module.get("eval_dataset"), dict):
for key in dataset_module["eval_dataset"].keys():
try:
perplexity = math.exp(metrics[f"eval_{key}_loss"])
except OverflowError:
perplexity = float("inf")
metrics[f"eval_{key}_perplexity"] = perplexity
else:
try:
perplexity = math.exp(metrics["eval_loss"])
except OverflowError:
perplexity = float("inf")
metrics["eval_perplexity"] = perplexity
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)