diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py index d06c4c66..97262ad5 100644 --- a/src/llamafactory/train/dpo/workflow.py +++ b/src/llamafactory/train/dpo/workflow.py @@ -92,7 +92,7 @@ def run_dpo( trainer.save_state() if trainer.is_world_process_zero() and finetuning_args.plot_loss: keys = ["loss", "rewards/accuracies"] - if isinstance(dataset_module["eval_dataset"], dict): + if isinstance(dataset_module.get("eval_dataset"), dict): keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()] else: keys += ["eval_loss"] diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py index 3a9a6a76..7b16d1d0 100644 --- a/src/llamafactory/train/kto/workflow.py +++ b/src/llamafactory/train/kto/workflow.py @@ -83,7 +83,7 @@ def run_kto( trainer.save_state() if trainer.is_world_process_zero() and finetuning_args.plot_loss: keys = ["loss", "rewards/chosen"] - if isinstance(dataset_module["eval_dataset"], dict): + if isinstance(dataset_module.get("eval_dataset"), dict): keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()] else: keys += ["eval_loss"] diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py index da5583e2..c6f48298 100644 --- a/src/llamafactory/train/pt/workflow.py +++ b/src/llamafactory/train/pt/workflow.py @@ -67,7 +67,7 @@ def run_pt( trainer.save_state() if trainer.is_world_process_zero() and finetuning_args.plot_loss: keys = ["loss"] - if isinstance(dataset_module["eval_dataset"], dict): + if isinstance(dataset_module.get("eval_dataset"), dict): keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()] else: keys += ["eval_loss"] diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py index 81f555dd..18d562e8 100644 --- a/src/llamafactory/train/rm/workflow.py +++ b/src/llamafactory/train/rm/workflow.py @@ -75,7 +75,7 @@ def run_rm( trainer.save_state() if trainer.is_world_process_zero() and finetuning_args.plot_loss: keys = ["loss"] - if isinstance(dataset_module["eval_dataset"], dict): + if isinstance(dataset_module.get("eval_dataset"), dict): keys += sum( [[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], [] ) diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py index b80cdfc4..cf6f80a9 100644 --- a/src/llamafactory/train/sft/workflow.py +++ b/src/llamafactory/train/sft/workflow.py @@ -111,7 +111,7 @@ def run_sft( trainer.save_state() if trainer.is_world_process_zero() and finetuning_args.plot_loss: keys = ["loss"] - if isinstance(dataset_module["eval_dataset"], dict): + if isinstance(dataset_module.get("eval_dataset"), dict): keys += sum( [[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], [] )