From 16e26236eb1ee2743e2a7afafcd938dd3af06a3f Mon Sep 17 00:00:00 2001 From: "Ma, Xiaochen" Date: Mon, 19 May 2025 18:01:26 +0800 Subject: [PATCH] [trainer] fix KeyError at end of pretrain (#8099) --- src/llamafactory/train/pt/workflow.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py index c6f48298..85158c2c 100644 --- a/src/llamafactory/train/pt/workflow.py +++ b/src/llamafactory/train/pt/workflow.py @@ -77,12 +77,22 @@ def run_pt( # Evaluation if training_args.do_eval: metrics = trainer.evaluate(metric_key_prefix="eval") - try: - perplexity = math.exp(metrics["eval_loss"]) - except OverflowError: - perplexity = float("inf") - metrics["perplexity"] = perplexity + if isinstance(dataset_module.get("eval_dataset"), dict): + for key in dataset_module["eval_dataset"].keys(): + try: + perplexity = math.exp(metrics[f"eval_{key}_loss"]) + except OverflowError: + perplexity = float("inf") + metrics[f"eval_{key}_perplexity"] = perplexity + else: + try: + perplexity = math.exp(metrics["eval_loss"]) + except OverflowError: + perplexity = float("inf") + metrics["eval_perplexity"] = perplexity + + trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics)