mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-19 13:20:36 +08:00
@@ -26,8 +26,16 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclass
|
||||
class ComputeAccuracy:
|
||||
def __post_init__(self):
|
||||
def _dump(self) -> Optional[Dict[str, float]]:
|
||||
result = None
|
||||
if hasattr(self, "score_dict"):
|
||||
result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
|
||||
|
||||
self.score_dict = {"accuracy": []}
|
||||
return result
|
||||
|
||||
def __post_init__(self):
|
||||
self._dump()
|
||||
|
||||
def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
|
||||
chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
|
||||
@@ -38,4 +46,4 @@ class ComputeAccuracy:
|
||||
self.score_dict["accuracy"].append(chosen_scores[i] > rejected_scores[i])
|
||||
|
||||
if compute_result:
|
||||
return {"accuracy": float(np.mean(self.score_dict["accuracy"]))}
|
||||
return self._dump()
|
||||
|
||||
Reference in New Issue
Block a user