diff --git a/src/llmtuner/hparams/parser.py b/src/llmtuner/hparams/parser.py
index 4fbc3db9..9264d1ee 100644
--- a/src/llmtuner/hparams/parser.py
+++ b/src/llmtuner/hparams/parser.py
@@ -120,7 +120,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
 
     if (
         finetuning_args.stage == "ppo"
-        and training_args.report_to is not None
+        and training_args.report_to
         and training_args.report_to[0] not in ["wandb", "tensorboard"]
     ):
         raise ValueError("PPO only accepts wandb or tensorboard logger.")
diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py
index 6be45958..020d54cf 100644
--- a/src/llmtuner/train/ppo/trainer.py
+++ b/src/llmtuner/train/ppo/trainer.py
@@ -66,7 +66,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             use_score_norm=finetuning_args.ppo_score_norm,
             whiten_rewards=finetuning_args.ppo_whiten_rewards,
             accelerator_kwargs={"step_scheduler_with_optimizer": False},
-            log_with=training_args.report_to[0] if training_args.report_to is not None else None,
+            log_with=training_args.report_to[0] if training_args.report_to else None,
             project_kwargs={"logging_dir": training_args.logging_dir},
         )
 
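Both hunks make the same fix: the `is not None` test on `training_args.report_to` is replaced with a plain truthiness check. After `TrainingArguments.__post_init__`, Hugging Face normalizes `report_to` to a list (for example, `--report_to none` becomes `[]`), so the old guard passed for an empty list and the subsequent `report_to[0]` could raise `IndexError`. A minimal sketch of the difference, assuming that list normalization:

```python
report_to = []  # what HF TrainingArguments yields for --report_to none

# Old guard: `is not None` is True for an empty list, so evaluating
# report_to[0] in the next clause would raise IndexError.
assert report_to is not None

# New guard: truthiness is False for both None and [], so the index
# is only taken when at least one logger is configured.
if report_to and report_to[0] not in ["wandb", "tensorboard"]:
    raise ValueError("PPO only accepts wandb or tensorboard logger.")
```

The same reasoning applies to the `log_with` ternary in the PPO trainer: under the old condition an empty `report_to` list would crash when indexed, while the truthiness check falls through to `None`, which TRL's `PPOConfig` accepts as "no tracker".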