[trainer] fix key error (#7635)

This commit is contained in:
Shawn Tao 2025-04-08 18:39:50 +08:00 committed by GitHub
parent 8ee26642f3
commit 8f5f4cc559
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 5 additions and 5 deletions

View File

@ -92,7 +92,7 @@ def run_dpo(
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss", "rewards/accuracies"]
if isinstance(dataset_module["eval_dataset"], dict):
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else:
keys += ["eval_loss"]

View File

@ -83,7 +83,7 @@ def run_kto(
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss", "rewards/chosen"]
if isinstance(dataset_module["eval_dataset"], dict):
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else:
keys += ["eval_loss"]

View File

@ -67,7 +67,7 @@ def run_pt(
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module["eval_dataset"], dict):
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else:
keys += ["eval_loss"]

View File

@ -75,7 +75,7 @@ def run_rm(
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module["eval_dataset"], dict):
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += sum(
[[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], []
)

View File

@ -111,7 +111,7 @@ def run_sft(
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module["eval_dataset"], dict):
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += sum(
[[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()], []
)