[trainer] fix key error (#7635)

This commit is contained in:
Shawn Tao 2025-04-08 18:39:50 +08:00 committed by GitHub
parent f75b91077b
commit acb09fa3a3
5 changed files with 5 additions and 5 deletions

View File

@ -92,7 +92,7 @@ def run_dpo(
trainer.save_state() trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss: if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss", "rewards/accuracies"] keys = ["loss", "rewards/accuracies"]
if isinstance(dataset_module["eval_dataset"], dict): if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()] keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else: else:
keys += ["eval_loss"] keys += ["eval_loss"]

View File

@ -83,7 +83,7 @@ def run_kto(
trainer.save_state() trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss: if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss", "rewards/chosen"] keys = ["loss", "rewards/chosen"]
if isinstance(dataset_module["eval_dataset"], dict): if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()] keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else: else:
keys += ["eval_loss"] keys += ["eval_loss"]

View File

@ -67,7 +67,7 @@ def run_pt(
trainer.save_state() trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss: if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"] keys = ["loss"]
if isinstance(dataset_module["eval_dataset"], dict): if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()] keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else: else:
keys += ["eval_loss"] keys += ["eval_loss"]

View File

@ -75,7 +75,7 @@ def run_rm(
trainer.save_state() trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss: if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"] keys = ["loss"]
if isinstance(dataset_module["eval_dataset"], dict): if isinstance(dataset_module.get("eval_dataset"), dict):
keys += sum( keys += sum(
[[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], [] [[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], []
) )

View File

@ -111,7 +111,7 @@ def run_sft(
trainer.save_state() trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss: if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"] keys = ["loss"]
if isinstance(dataset_module["eval_dataset"], dict): if isinstance(dataset_module.get("eval_dataset"), dict):
keys += sum( keys += sum(
[[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], [] [[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], []
) )