From 5030f05126e6667b46875877819a74cdc7109864 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 7 Sep 2023 19:12:40 +0800
Subject: [PATCH] add deepspeed check in PPO training

Former-commit-id: ed1c2c5557bb2714c3341294f0ea86f6496d4b0c
---
 src/llmtuner/tuner/core/parser.py | 3 +++
 src/llmtuner/tuner/ppo/trainer.py | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index 9d4b68bd..1a9d673c 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -119,6 +119,9 @@ def get_train_args(
     if general_args.stage == "ppo" and model_args.reward_model is None:
         raise ValueError("Reward model is necessary for PPO training.")
 
+    if general_args.stage == "ppo" and training_args.deepspeed is not None:
+        raise ValueError("PPO training is incompatible with DeepSpeed, use Accelerate instead.")
+
     if general_args.stage == "ppo" and data_args.streaming:
         raise ValueError("Streaming mode does not suppport PPO training currently.")
 
diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py
index 8e7204c3..21c8350d 100644
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -39,6 +39,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
+        if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
+            raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
+
         self.args = training_args
         self.finetuning_args = finetuning_args
         self.generating_args = generating_args