Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-16 20:00:36 +08:00)
add deepspeed check in PPO training
@@ -39,6 +39,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
+        if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
+            raise ValueError("PPOTrainer is incompatible with DeepSpeed.")
+
         self.args = training_args
         self.finetuning_args = finetuning_args
         self.generating_args = generating_args
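For context, the new guard inspects the accelerate state that the underlying PPOTrainer creates during its own initialization: if a DeepSpeed plugin is attached, PPO training is refused up front instead of failing later. A minimal standalone sketch of the same check, assuming a bare `Accelerator` instance in place of the trainer's `self.accelerator`:

    from accelerate import Accelerator

    accelerator = Accelerator()

    # A DeepSpeed plugin is only attached to the accelerator state when the
    # run was launched with a DeepSpeed config; getattr keeps the check safe
    # on accelerate versions where the attribute may be missing entirely.
    if getattr(accelerator.state, "deepspeed_plugin", None) is not None:
        raise ValueError("PPOTrainer is incompatible with DeepSpeed.")

Doing the check immediately after PPOTrainer.__init__ means the error surfaces before any model or dataset work begins, rather than as an obscure failure mid-training.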
||||