commit 72ba29d81a
parent cf2dc4c444
Author: hiyouga
Date:   2024-06-26 19:52:35 +08:00
Former-commit-id: aab14b15268dbe74ded22549dbd3677474868cbb


@@ -99,10 +99,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         )
 
         # Add deepspeed config
+        ppo_config.accelerator_kwargs["kwargs_handlers"] = [
+            DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
+        ]
         if training_args.deepspeed_plugin is not None:
-            ppo_config.accelerator_kwargs["kwargs_handlers"] = [
-                DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
-            ]
             ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
 
         # Create optimizer and scheduler
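
For context on how these settings take effect: trl's PPOTrainer forwards PPOConfig.accelerator_kwargs to accelerate.Accelerator when it builds its accelerator, so both keys set in this hunk end up as Accelerator(...) arguments. Below is a minimal, self-contained sketch of that flow, not the project's code; the literal values stand in for the fields normally read from training_args.

from accelerate import Accelerator, DistributedDataParallelKwargs

# Stand-ins for training_args.ddp_find_unused_parameters and
# training_args.deepspeed_plugin (hypothetical values for illustration).
ddp_find_unused_parameters = True
deepspeed_plugin = None  # a DeepSpeedPlugin instance when DeepSpeed is configured

accelerator_kwargs = {
    # DDP-specific handler; it only matters when Accelerate wraps the model
    # with torch DistributedDataParallel rather than a DeepSpeed engine.
    "kwargs_handlers": [
        DistributedDataParallelKwargs(find_unused_parameters=ddp_find_unused_parameters)
    ],
}
if deepspeed_plugin is not None:
    accelerator_kwargs["deepspeed_plugin"] = deepspeed_plugin

# trl's PPOTrainer effectively does Accelerator(**ppo_config.accelerator_kwargs).
accelerator = Accelerator(**accelerator_kwargs)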