diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py
index acd78b0e..e6c3d0e3 100644
--- a/src/llmtuner/train/ppo/trainer.py
+++ b/src/llmtuner/train/ppo/trainer.py
@@ -250,7 +250,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             reward_model = self.reward_model if self.reward_model is not None else self.model
             _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
 
-        if values.size(0) != batch["input_ids"].size(0):  # adapt to chatglm2
+        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
             values = torch.transpose(values, 0, 1)
 
         rewards = []
@@ -298,7 +298,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype):  # support bf16
                 logits, _, values = model(**input_kwargs)
 
-            if values.size(0) != input_ids.size(0):  # adapt to chatglm2
+            if getattr(model.config, "model_type", None) == "chatglm":
                 values = torch.transpose(values, 0, 1)
 
             logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
diff --git a/src/llmtuner/train/rm/trainer.py b/src/llmtuner/train/rm/trainer.py
index 45703bc5..9be64264 100644
--- a/src/llmtuner/train/rm/trainer.py
+++ b/src/llmtuner/train/rm/trainer.py
@@ -39,7 +39,8 @@ class PairwiseTrainer(Trainer):
         """
         # Compute rewards
         _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
-        if values.size(0) != inputs["input_ids"].size(0):  # adapt to chatglm2
+
+        if getattr(model.config, "model_type", None) == "chatglm":
             values = torch.transpose(values, 0, 1)
 
         # Split the inputs and rewards into two parts, chosen and rejected
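
Note on the change: the old code inferred that the value tensor needed a transpose whenever its first dimension did not match the batch size ("adapt to chatglm2"), a heuristic that can presumably misfire when the batch size happens to equal the sequence length. The diff replaces it with an explicit check of the model config's `model_type`. The sketch below illustrates that logic in isolation; `normalize_values` is a hypothetical helper, not a function in the repository, and it assumes ChatGLM value heads emit values shaped (seq_len, batch_size) while other models emit (batch_size, seq_len).

    # Hedged sketch of the detection logic, under the shape assumptions stated above.
    from typing import Any
    import torch

    def normalize_values(values: torch.Tensor, config: Any) -> torch.Tensor:
        """Return value estimates as (batch_size, seq_len) regardless of model type."""
        if getattr(config, "model_type", None) == "chatglm":
            # Assumed ChatGLM layout: (seq_len, batch_size) -> (batch_size, seq_len)
            values = torch.transpose(values, 0, 1)
        return values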