# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional

import numpy as np

from ...extras.misc import numpify


if TYPE_CHECKING:
    from transformers import EvalPrediction


@dataclass
class ComputeAccuracy:
    """Accumulate pairwise "chosen beats rejected" comparisons and report their mean.

    The instance is callable so it can be used as a metrics callback. Each call
    appends one comparison per (chosen, rejected) score pair to an internal
    buffer; when ``compute_result`` is true the buffered comparisons are averaged
    into an accuracy value and the buffer is cleared so the next evaluation
    starts from a clean state.
    """

    def __post_init__(self):
        # Buffer of boolean comparison outcomes collected since the last result.
        self.score_dict = {"accuracy": []}

    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
        """Record comparisons for ``eval_preds`` and optionally return the accuracy.

        Args:
            eval_preds: Object whose ``predictions[0]`` holds the chosen scores
                and ``predictions[1]`` the rejected scores. Both are passed
                through the project helper ``numpify`` (presumably yielding
                NumPy arrays -- confirm against ``extras.misc``).
            compute_result: When true, return the mean over everything
                accumulated so far and reset the buffer. When false, only
                accumulate (useful when metrics are fed batch by batch).

        Returns:
            ``{"accuracy": <float>}`` when ``compute_result`` is true,
            otherwise ``None``.
        """
        chosen_scores = numpify(eval_preds.predictions[0])
        rejected_scores = numpify(eval_preds.predictions[1])

        # A single vectorized comparison covers both the 0-d (scalar) case and
        # the batched case; atleast_1d makes the scalar result iterable so each
        # comparison becomes one entry in the buffer.
        wins = np.atleast_1d(chosen_scores > rejected_scores)
        self.score_dict["accuracy"].extend(wins)

        if not compute_result:
            return None

        accuracy = float(np.mean(self.score_dict["accuracy"]))
        # Clear the buffer so a later evaluation is not polluted by the
        # comparisons that were just reported.
        self.score_dict = {"accuracy": []}
        return {"accuracy": accuracy}