Former-commit-id: ce77d98872fa377fd4bc961701b07982f4b51491
This commit is contained in:
hiyouga 2024-04-03 14:47:59 +08:00
parent 42ccb9b073
commit 88d9f47a0b
2 changed files with 2 additions and 2 deletions

View File

@ -120,7 +120,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
if (
finetuning_args.stage == "ppo"
and training_args.report_to is not None
and training_args.report_to
and training_args.report_to[0] not in ["wandb", "tensorboard"]
):
raise ValueError("PPO only accepts wandb or tensorboard logger.")

View File

@ -66,7 +66,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
use_score_norm=finetuning_args.ppo_score_norm,
whiten_rewards=finetuning_args.ppo_whiten_rewards,
accelerator_kwargs={"step_scheduler_with_optimizer": False},
log_with=training_args.report_to[0] if training_args.report_to is not None else None,
log_with=training_args.report_to[0] if training_args.report_to else None,
project_kwargs={"logging_dir": training_args.logging_dir},
)