Files
LLaMA-Factory/src/llmtuner/train/rm/metric.py
2023-11-15 16:29:09 +08:00

8 lines
266 B
Python

import numpy as np
from typing import Dict, Sequence, Tuple, Union
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
    """Return the fraction of positions where the first score beats the second.

    Args:
        eval_preds: A two-item sequence ``(predictions, labels)``. Only
            ``predictions`` is used; it must be indexable so that
            ``predictions[0]`` and ``predictions[1]`` are equally shaped
            array-likes of scores. The second item is ignored.
            NOTE(review): presumably these are the reward-model scores of
            the preferred and the dispreferred response — confirm against
            the caller.

    Returns:
        ``{"accuracy": value}`` where ``value`` is the number of positions
        with ``predictions[0] > predictions[1]`` divided by
        ``len(predictions[0])``. It is NaN when there are no samples.
    """
    preds, _ = eval_preds
    # Coerce to arrays so list/tuple inputs compare element-wise instead of
    # lexicographically (and so ``.sum()`` is always available).
    first_scores = np.asarray(preds[0])
    second_scores = np.asarray(preds[1])

    num_samples = len(first_scores)
    if num_samples == 0:
        # Accuracy is undefined for an empty batch; report NaN explicitly
        # rather than relying on a warning-raising division by zero.
        return {"accuracy": float("nan")}

    num_correct = (first_scores > second_scores).sum()
    # Convert the NumPy scalar to a builtin float to match the annotation.
    return {"accuracy": float(num_correct / num_samples)}