support full-parameter PPO

This commit is contained in:
hiyouga
2023-11-16 02:08:04 +08:00
parent 8350bcf85d
commit ce78303600
20 changed files with 288 additions and 145 deletions

View File

@@ -43,7 +43,11 @@ class CustomDPOTrainer(DPOTrainer):
         if ref_model is not None:
             if self.is_deepspeed_enabled:
-                self.ref_model = self._prepare_deepspeed(self.ref_model)
+                if not (
+                    getattr(ref_model, "is_loaded_in_8bit", False)
+                    or getattr(ref_model, "is_loaded_in_4bit", False)
+                ): # quantized models are already set on the correct device
+                    self.ref_model = self._prepare_deepspeed(self.ref_model)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)