Former-commit-id: 5dd2e5c3323f56420b5845a5ed28bcd9d4da5e41
This commit is contained in:
hiyouga 2024-07-01 05:43:17 +08:00
parent 4357e42391
commit 973cf8e980

View File

@ -48,7 +48,7 @@ def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]:
preds, labels = eval_preds.predictions, eval_preds.label_ids
accuracies = []
for i in range(len(preds)):
pred, label = preds[i, 1:], labels[i, :-1]
pred, label = preds[i, :-1], labels[i, 1:]
label_mask = label != IGNORE_INDEX
accuracies.append(np.mean(pred[label_mask] == label[label_mask]))