Fix ChatGLM RLHF: transpose `values` when its first dimension does not match the batch dimension of `input_ids`

This commit is contained in:
hiyouga
2023-08-15 11:19:20 +08:00
parent a7dd9611db
commit af6c011fcb
2 changed files with 4 additions and 0 deletions

View File

@@ -42,6 +42,8 @@ class PairwisePeftTrainer(PeftTrainer):
"""
batch_size = inputs["input_ids"].size(0) // 2
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
if values.size(0) != inputs["input_ids"].size(0):
values = torch.transpose(values, 0, 1)
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss