Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-03 20:22:49 +08:00

parent 42ccb9b073
commit 88d9f47a0b
@@ -120,7 +120,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if (
         finetuning_args.stage == "ppo"
-        and training_args.report_to is not None
+        and training_args.report_to
         and training_args.report_to[0] not in ["wandb", "tensorboard"]
     ):
         raise ValueError("PPO only accepts wandb or tensorboard logger.")
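Why the truthiness check matters: `transformers` may normalize `report_to="none"` into an empty list, in which case the old `is not None` guard passes and `report_to[0]` raises an IndexError. A minimal sketch of both guards, assuming that normalization (plain lists stand in for `TrainingArguments`):

report_to = []  # e.g. what TrainingArguments may hold for report_to="none"

# Old guard: an empty list is not None, so indexing is reached and crashes.
try:
    if report_to is not None and report_to[0] not in ["wandb", "tensorboard"]:
        raise ValueError("PPO only accepts wandb or tensorboard logger.")
except IndexError as err:
    print(f"old guard: {err}")  # list index out of range

# New guard: an empty list is falsy, so report_to[0] is never evaluated.
if report_to and report_to[0] not in ["wandb", "tensorboard"]:
    raise ValueError("PPO only accepts wandb or tensorboard logger.")
print("new guard: empty list skipped safely")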
@@ -66,7 +66,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             use_score_norm=finetuning_args.ppo_score_norm,
             whiten_rewards=finetuning_args.ppo_whiten_rewards,
             accelerator_kwargs={"step_scheduler_with_optimizer": False},
-            log_with=training_args.report_to[0] if training_args.report_to is not None else None,
+            log_with=training_args.report_to[0] if training_args.report_to else None,
             project_kwargs={"logging_dir": training_args.logging_dir},
         )
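The second hunk applies the same fix to the `log_with` argument of the PPO config: a conditional expression over a possibly empty list. A sketch with plain values (the helper name is hypothetical):

def pick_logger(report_to):
    # Truthiness covers both None and [] in one check; the previous
    # `is not None` form raised IndexError on an empty list.
    return report_to[0] if report_to else None

assert pick_logger(None) is None
assert pick_logger([]) is None  # old form: IndexError here
assert pick_logger(["wandb"]) == "wandb"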