fix RM accuracy

Former-commit-id: 7826a8ca7722b138e79b13c42b1070771f6d5994
This commit is contained in:
hiyouga 2023-06-28 01:40:13 +08:00
parent 204541b56c
commit c3cd2067b2

View File

@ -13,8 +13,7 @@ logger = get_logger(__name__)
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
preds, _ = eval_preds preds, _ = eval_preds
preds = np.array(preds) return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
return {"accuracy": (preds[:, 0] > preds[:, 1]).sum() / len(preds)}
class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding): class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
@ -49,9 +48,13 @@ class PairwisePeftTrainer(PeftTrainer):
We use score on the EOS token to represent reward of the whole sentence. We use score on the EOS token to represent reward of the whole sentence.
Subclass and override to inject custom behavior. It should not be directly used by external scripts. Subclass and override to inject custom behavior. It should not be directly used by external scripts.
Note that the first element will be removed from the output tuple.
See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
""" """
batch_size = inputs["input_ids"].size(0) // 2 batch_size = inputs["input_ids"].size(0) // 2
_, _, values = model(**inputs) _, _, values = model(**inputs)
r_accept, r_reject = values[:, -1].split(batch_size, dim=0) r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean() loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
return (loss, torch.stack((r_accept, r_reject), dim=-1)) if return_outputs else loss return (loss, [loss, r_accept, r_reject]) if return_outputs else loss