Fix PPO trainer: disable external loggers when DeepSpeed is enabled (tensorboard raises an error about accelerator_kwargs)

Former-commit-id: fb0c40011689b3ae84cc3b258bf3c66af3e1e430
This commit is contained in:
hiyouga 2024-07-10 11:05:45 +08:00
parent aa15ca1719
commit d7130ec635

View File

@ -106,7 +106,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
]
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
if ppo_config.log_with == "tensorboard": # tensorboard raises error about accelerator_kwargs
if ppo_config.log_with is not None:
logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
ppo_config.log_with = None
# Create optimizer and scheduler