From a9ab8f71d7a0d8ab9c41cfad7afa83450f6d41dd Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 15 Aug 2023 11:19:20 +0800
Subject: [PATCH] fix ChatGLM RLHF

Former-commit-id: af6c011fcb8ea9e5cf2eb4699da33d8668df04b4
---
 src/llmtuner/tuner/ppo/trainer.py | 2 ++
 src/llmtuner/tuner/rm/trainer.py  | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py
index 26a8db99..cc73854f 100644
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -182,6 +182,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         replace_model(unwrapped_model, target="reward")
         batch = self.prepare_model_inputs(queries, responses)
         _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
+        if values.size(0) != batch["input_ids"].size(0):
+            values = torch.transpose(values, 0, 1)
         rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
         replace_model(unwrapped_model, target="default")
         return rewards
diff --git a/src/llmtuner/tuner/rm/trainer.py b/src/llmtuner/tuner/rm/trainer.py
index e69d48a8..55790c07 100644
--- a/src/llmtuner/tuner/rm/trainer.py
+++ b/src/llmtuner/tuner/rm/trainer.py
@@ -42,6 +42,8 @@ class PairwisePeftTrainer(PeftTrainer):
         """
         batch_size = inputs["input_ids"].size(0) // 2
         _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
+        if values.size(0) != inputs["input_ids"].size(0):
+            values = torch.transpose(values, 0, 1)
         r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
         loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
         return (loss, [loss, r_accept, r_reject]) if return_outputs else loss