fix ChatGLM RLHF

This commit is contained in:
hiyouga
2023-08-15 11:19:20 +08:00
parent a7dd9611db
commit af6c011fcb
2 changed files with 4 additions and 0 deletions

View File

@@ -182,6 +182,8 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
replace_model(unwrapped_model, target="reward")
batch = self.prepare_model_inputs(queries, responses)
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
if values.size(0) != batch["input_ids"].size(0):
values = torch.transpose(values, 0, 1)
rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
replace_model(unwrapped_model, target="default")
return rewards