Fix PPO trainer

Former-commit-id: a03b2e5ef0d5d6b1b27753438745385d290cb211
This commit is contained in:
hiyouga 2024-07-10 11:05:45 +08:00
parent 834c4e8ad9
commit 446129ca7a

View File

@ -106,7 +106,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
]
ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
if ppo_config.log_with == "tensorboard": # tensorboard raises error about accelerator_kwargs
if ppo_config.log_with is not None:
logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
ppo_config.log_with = None
# Create optimizer and scheduler