From ba6d290d0bcdfa9a9d37828a911804dc4b77e308 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 30 Nov 2023 21:02:00 +0800 Subject: [PATCH] fix #1668 Former-commit-id: 1585962eb7ed042890d4c56422aae749c669dda8 --- src/llmtuner/train/ppo/trainer.py | 3 ++- src/llmtuner/train/rm/trainer.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py index e6c3d0e3..1bba733b 100644 --- a/src/llmtuner/train/ppo/trainer.py +++ b/src/llmtuner/train/ppo/trainer.py @@ -298,7 +298,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer): with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16 logits, _, values = model(**input_kwargs) - if getattr(model.config, "model_type", None) == "chatglm": + unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) + if getattr(unwrapped_model.config, "model_type", None) == "chatglm": values = torch.transpose(values, 0, 1) logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) diff --git a/src/llmtuner/train/rm/trainer.py b/src/llmtuner/train/rm/trainer.py index 9be64264..b018a8c4 100644 --- a/src/llmtuner/train/rm/trainer.py +++ b/src/llmtuner/train/rm/trainer.py @@ -40,7 +40,8 @@ class PairwiseTrainer(Trainer): # Compute rewards _, _, values = model(**inputs, output_hidden_states=True, return_dict=True) - if getattr(model.config, "model_type", None) == "chatglm": + unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model) + if getattr(unwrapped_model.config, "model_type", None) == "chatglm": values = torch.transpose(values, 0, 1) # Split the inputs and rewards into two parts, chosen and rejected