Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-03 04:02:49 +08:00)

parent 42ccb9b073
commit 88d9f47a0b
@@ -120,7 +120,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if (
         finetuning_args.stage == "ppo"
-        and training_args.report_to is not None
+        and training_args.report_to
         and training_args.report_to[0] not in ["wandb", "tensorboard"]
     ):
         raise ValueError("PPO only accepts wandb or tensorboard logger.")
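Why a truthiness check rather than `is not None`: transformers' TrainingArguments normalizes `report_to` into a list, and that list can be empty (for example when reporting is disabled), so the old guard still reached `report_to[0]` and could raise IndexError. A minimal standalone sketch of the pitfall, with an illustrative `report_to` value that is not taken from the commit:

    report_to = []  # e.g. the normalized value when no logger is configured

    # Old guard: the None check passes, then indexing the empty list raises.
    try:
        if report_to is not None and report_to[0] not in ["wandb", "tensorboard"]:
            raise ValueError("PPO only accepts wandb or tensorboard logger.")
    except IndexError as err:
        print(f"old guard crashed: {err}")  # list index out of range

    # New guard: an empty list is falsy, so short-circuiting skips the index.
    if report_to and report_to[0] not in ["wandb", "tensorboard"]:
        raise ValueError("PPO only accepts wandb or tensorboard logger.")
    print("new guard: no logger configured, nothing to validate")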
@@ -66,7 +66,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             use_score_norm=finetuning_args.ppo_score_norm,
             whiten_rewards=finetuning_args.ppo_whiten_rewards,
             accelerator_kwargs={"step_scheduler_with_optimizer": False},
-            log_with=training_args.report_to[0] if training_args.report_to is not None else None,
+            log_with=training_args.report_to[0] if training_args.report_to else None,
             project_kwargs={"logging_dir": training_args.logging_dir},
         )
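The second hunk applies the same fix where the first entry of `report_to` is forwarded as the `log_with` tracker. A short sketch of the safe selection, assuming only that `report_to` is either None or a list of strings; the helper name `pick_logger` is hypothetical and used here purely for illustration:

    from typing import List, Optional

    def pick_logger(report_to: Optional[List[str]]) -> Optional[str]:
        # Hypothetical helper: return the first configured logger, or None
        # when reporting is disabled (None) or normalized to an empty list.
        return report_to[0] if report_to else None

    assert pick_logger(None) is None
    assert pick_logger([]) is None  # the old conditional raised IndexError here
    assert pick_logger(["wandb"]) == "wandb"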