diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index 9d4b68bd..1a9d673c 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -119,6 +119,9 @@ def get_train_args(
     if general_args.stage == "ppo" and model_args.reward_model is None:
         raise ValueError("Reward model is necessary for PPO training.")
 
+    if general_args.stage == "ppo" and training_args.deepspeed is not None:
+        raise ValueError("PPO training is incompatible with DeepSpeed, use Accelerate instead.")
+
     if general_args.stage == "ppo" and data_args.streaming:
         raise ValueError("Streaming mode does not suppport PPO training currently.")
 
diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py
index 8e7204c3..21c8350d 100644
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -39,6 +39,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
        **kwargs
    ):
        PPOTrainer.__init__(self, **kwargs)
+        if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
+            raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
+
        self.args = training_args
        self.finetuning_args = finetuning_args
        self.generating_args = generating_args
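For context, here is a minimal standalone sketch of the runtime guard the second hunk adds, assuming only the `accelerate` package. The helper name `assert_no_deepspeed` is a hypothetical stand-in, not part of the patched code; the `getattr` probe mirrors the one in the diff, since older `accelerate` versions only attach `deepspeed_plugin` to the state when DeepSpeed is actually in use.

```python
from accelerate import Accelerator


def assert_no_deepspeed(accelerator: Accelerator) -> None:
    """Fail fast if the current Accelerate run was launched with DeepSpeed.

    Mirrors the check added in PPOPeftTrainer.__init__: the defensive
    getattr handles accelerate versions where `deepspeed_plugin` is only
    set on the state under a DeepSpeed launch.
    """
    if getattr(accelerator.state, "deepspeed_plugin", None) is not None:
        raise ValueError("PPOTrainer is incompatible with DeepSpeed.")


if __name__ == "__main__":
    accelerator = Accelerator()
    assert_no_deepspeed(accelerator)  # raises only under a DeepSpeed launch
    print("No DeepSpeed plugin detected; PPO training can proceed.")
```

Per the error message in the first hunk, multi-GPU PPO runs would then be started via `accelerate launch` (without a DeepSpeed config) rather than through DeepSpeed.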