From 0e6f4f981e30e68284eb27e00b08eea3d82c1af0 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 28 Nov 2023 20:57:24 +0800 Subject: [PATCH] fix #1658 Former-commit-id: 3126687c4820c34daa6a2e9e3bf9065ad59e92dc --- src/llmtuner/train/ppo/trainer.py | 4 ++-- src/llmtuner/train/rm/trainer.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py index acd78b0e..e6c3d0e3 100644 --- a/src/llmtuner/train/ppo/trainer.py +++ b/src/llmtuner/train/ppo/trainer.py @@ -250,7 +250,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): reward_model = self.reward_model if self.reward_model is not None else self.model _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True) - if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2 + if getattr(unwrapped_model.config, "model_type", None) == "chatglm": values = torch.transpose(values, 0, 1) rewards = [] @@ -298,7 +298,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16 logits, _, values = model(**input_kwargs) - if values.size(0) != input_ids.size(0): # adapt to chatglm2 + if getattr(model.config, "model_type", None) == "chatglm": values = torch.transpose(values, 0, 1) logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) diff --git a/src/llmtuner/train/rm/trainer.py b/src/llmtuner/train/rm/trainer.py index 45703bc5..9be64264 100644 --- a/src/llmtuner/train/rm/trainer.py +++ b/src/llmtuner/train/rm/trainer.py @@ -39,7 +39,8 @@ class PairwiseTrainer(Trainer): """ # Compute rewards _, _, values = model(**inputs, output_hidden_states=True, return_dict=True) - if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2 + + if getattr(model.config, "model_type", None) == "chatglm": values = torch.transpose(values, 0, 1) # Split the inputs and rewards into two parts, chosen and rejected