Fix ChatGLM2 PPO (#527, #528)

Former-commit-id: 9f4c2adc9a
This commit is contained in:
hiyouga
2023-08-18 00:34:59 +08:00
parent 623a34b16f
commit caf4a61e21
6 changed files with 72 additions and 11 deletions

View File

@@ -42,7 +42,7 @@ class PairwisePeftTrainer(PeftTrainer):
"""
batch_size = inputs["input_ids"].size(0) // 2
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
- if values.size(0) != inputs["input_ids"].size(0):
+ if values.size(0) != inputs["input_ids"].size(0): # adapt chatglm2
values = torch.transpose(values, 0, 1)
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()