fix RM accuracy

Former-commit-id: 532a385ea60693fdf835e6bc8e240ff8d55ff3a7
hiyouga 2023-06-28 01:40:13 +08:00
parent eca15bf252
commit 4ae8a20e1d


@@ -13,8 +13,7 @@ logger = get_logger(__name__)
 def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
     preds, _ = eval_preds
     preds = np.array(preds)
-    return {"accuracy": (preds[:, 0] > preds[:, 1]).sum() / len(preds)}
+    return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}

 class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
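
To see why the indexing changes, here is a minimal sketch (not part of the commit; the reward values are invented) of what compute_accuracy now receives. After the companion fix to compute_loss below, the trainer returns [loss, r_accept, r_reject]; the HF Trainer strips the leading loss, so the first axis of preds indexes (r_accept, r_reject) rather than examples.

# Sketch of the new accuracy semantics; the numbers are made up for illustration.
import numpy as np

r_accept = np.array([1.2, 0.3, 2.0, -0.5])  # rewards of the chosen responses
r_reject = np.array([0.8, 0.9, 1.1, -1.0])  # rewards of the rejected responses

# After Trainer.prediction_step drops the leading loss, the remaining
# predictions stack along axis 0, giving preds of shape (2, batch_size).
preds = np.array([r_accept, r_reject])

# New formula: fraction of pairs where the chosen response outscores the rejected one.
accuracy = (preds[0] > preds[1]).sum() / len(preds[0])
print(accuracy)  # 0.75 -- three of the four pairs are ranked correctly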
@@ -49,9 +48,13 @@ class PairwisePeftTrainer(PeftTrainer):
         We use score on the EOS token to represent reward of the whole sentence.

         Subclass and override to inject custom behavior. It should not be directly used by external scripts.
+
+        Note that the first element will be removed from the output tuple.
+
+        See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
         """
         batch_size = inputs["input_ids"].size(0) // 2
         _, _, values = model(**inputs)
         r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
         loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
-        return (loss, torch.stack((r_accept, r_reject), dim=-1)) if return_outputs else loss
+        return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
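
For reference, the loss computed here is the standard pairwise ranking objective -log(sigmoid(r_accept - r_reject)), and the fixed return value prepends the loss because Trainer.prediction_step removes the first element of the output tuple (see the link above). A self-contained sketch with toy tensors (the values are invented, not taken from the repository):

import torch

# Toy rewards for three chosen/rejected pairs (invented values).
r_accept = torch.tensor([1.2, 0.3, 2.0])
r_reject = torch.tensor([0.8, 0.9, 1.1])

# Pairwise ranking loss: -log(sigmoid(r_accept - r_reject)), averaged over
# the batch; it is minimized when every chosen reward exceeds its rejected one.
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()

# Mirror of the fixed return value: the loss is placed first because
# Trainer.prediction_step drops the first element, leaving (r_accept, r_reject)
# for compute_accuracy to consume.
outputs = (loss, [loss, r_accept, r_reject])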