Former-commit-id: 8d6cd69ac43afd4bd7c14bd02b0061455827ac9e
This commit is contained in:
hiyouga 2024-06-26 19:52:35 +08:00
parent 654116c0b1
commit 28e613efd0

View File

@@ -99,10 +99,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         )
         # Add deepspeed config
+        ppo_config.accelerator_kwargs["kwargs_handlers"] = [
+            DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
+        ]
         if training_args.deepspeed_plugin is not None:
-            ppo_config.accelerator_kwargs["kwargs_handlers"] = [
-                DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
-            ]
             ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
         # Create optimizer and scheduler