Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Synced 2025-10-15 08:08:09 +08:00
9 lines · 267 B · Python
from typing import Dict, Sequence, Tuple, Union
|
|
|
|
import numpy as np
|
|
|
|
|
|
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
    """Compute pairwise accuracy: how often the first score beats the second.

    Args:
        eval_preds: A ``(predictions, label_ids)`` pair; ``label_ids`` is ignored.
            ``predictions`` must hold at least two score collections of equal
            length: ``predictions[0]`` and ``predictions[1]``. Presumably these
            are the scores of the "chosen" and "rejected" responses of each
            pair (TODO confirm against the trainer that produces them).
            Both numpy arrays and plain Python sequences are accepted.

    Returns:
        ``{"accuracy": value}`` where ``value`` is the number of positions with
        ``predictions[0] > predictions[1]`` (strictly greater; ties count as
        misses) divided by ``len(predictions[0])``, as a plain Python ``float``.
        ``value`` is ``nan`` when there are no pairs.
    """
    preds, _ = eval_preds
    # np.asarray makes the comparison element-wise for plain Python sequences;
    # on lists, ``>`` would perform a single lexicographic comparison instead.
    first_scores = np.asarray(preds[0])
    second_scores = np.asarray(preds[1])
    num_pairs = len(first_scores)
    if num_pairs == 0:
        # No pairs to score: report NaN explicitly instead of triggering a
        # numpy divide-by-zero RuntimeWarning.
        return {"accuracy": float("nan")}
    wins = first_scores > second_scores
    # float() turns the numpy scalar into a builtin float so the metric
    # serializes cleanly in logs.
    return {"accuracy": float(wins.sum() / num_pairs)}
|