diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 70d01919..c5f6e175 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -99,10 +99,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         )
 
         # Add deepspeed config
-        ppo_config.accelerator_kwargs["kwargs_handlers"] = [
-            DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
-        ]
         if training_args.deepspeed_plugin is not None:
+            ppo_config.accelerator_kwargs["kwargs_handlers"] = [
+                DistributedDataParallelKwargs(find_unused_parameters=training_args.ddp_find_unused_parameters)
+            ]
             ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
 
         # Create optimizer and scheduler
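
The patch scopes the `kwargs_handlers` assignment to the DeepSpeed branch, so the `DistributedDataParallelKwargs` handler is only registered when a DeepSpeed plugin is configured. Below is a minimal sketch of the resulting control flow, assuming (as TRL's `PPOConfig.accelerator_kwargs` suggests) that these kwargs are forwarded to `accelerate.Accelerator`; the `deepspeed_plugin = None` placeholder stands in for `training_args.deepspeed_plugin`:

```python
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# Sketch only: mirrors how ppo_config.accelerator_kwargs is assumed to be
# forwarded into accelerate.Accelerator by TRL's PPOTrainer.
accelerator_kwargs = {}

# Placeholder for training_args.deepspeed_plugin in the actual trainer.
deepspeed_plugin = None

if deepspeed_plugin is not None:
    # After the patch, the DDP handler is registered only alongside DeepSpeed.
    accelerator_kwargs["kwargs_handlers"] = [
        DistributedDataParallelKwargs(find_unused_parameters=False)
    ]
    accelerator_kwargs["deepspeed_plugin"] = deepspeed_plugin

accelerator = Accelerator(**accelerator_kwargs)
```

With `deepspeed_plugin` unset, the `Accelerator` is constructed with no extra handlers, whereas before the patch the DDP kwargs handler was passed unconditionally.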