diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py index c6f48298..85158c2c 100644 --- a/src/llamafactory/train/pt/workflow.py +++ b/src/llamafactory/train/pt/workflow.py @@ -77,12 +77,22 @@ def run_pt( # Evaluation if training_args.do_eval: metrics = trainer.evaluate(metric_key_prefix="eval") - try: - perplexity = math.exp(metrics["eval_loss"]) - except OverflowError: - perplexity = float("inf") - metrics["perplexity"] = perplexity + if isinstance(dataset_module.get("eval_dataset"), dict): + for key in dataset_module["eval_dataset"].keys(): + try: + perplexity = math.exp(metrics[f"eval_{key}_loss"]) + except OverflowError: + perplexity = float("inf") + metrics[f"eval_{key}_perplexity"] = perplexity + else: + try: + perplexity = math.exp(metrics["eval_loss"]) + except OverflowError: + perplexity = float("inf") + metrics["eval_perplexity"] = perplexity + + trainer.log_metrics("eval", metrics) trainer.save_metrics("eval", metrics)