fix ChatGLM lm_head #494

This commit is contained in:
hiyouga
2023-08-14 14:14:48 +08:00
parent 20a29297b1
commit d019956808
3 changed files with 12 additions and 8 deletions

View File

@@ -32,11 +32,11 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, **kwargs)
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")
if ref_model is not None:
if hasattr(self, "accelerator"):
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
else:
raise AttributeError("Please update `transformers`.")
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def concatenated_forward(
self,