mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-23 15:20:36 +08:00
[misc] upgrade format to py39 (#7256)
This commit is contained in:
@@ -39,7 +39,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -59,7 +59,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class Evaluator:
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
|
||||
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
||||
self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
|
||||
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
||||
@@ -69,7 +69,7 @@ class Evaluator:
|
||||
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
|
||||
|
||||
@torch.inference_mode()
|
||||
def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
|
||||
def batch_inference(self, batch_input: dict[str, "torch.Tensor"]) -> list[str]:
|
||||
logits = self.model(**batch_input).logits
|
||||
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
|
||||
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
|
||||
@@ -88,7 +88,7 @@ class Evaluator:
|
||||
)
|
||||
|
||||
with open(mapping, encoding="utf-8") as f:
|
||||
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
||||
categorys: dict[str, dict[str, str]] = json.load(f)
|
||||
|
||||
category_corrects = {subj: np.array([], dtype="bool") for subj in SUBJECTS}
|
||||
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
||||
@@ -136,7 +136,7 @@ class Evaluator:
|
||||
pbar.close()
|
||||
self._save_results(category_corrects, results)
|
||||
|
||||
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
|
||||
def _save_results(self, category_corrects: dict[str, "NDArray"], results: dict[str, dict[int, str]]) -> None:
|
||||
score_info = "\n".join(
|
||||
[
|
||||
f"{category_name:>15}: {100 * np.mean(category_correct):.2f}"
|
||||
|
||||
@@ -12,8 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Sequence, Tuple
|
||||
|
||||
from ..data import Role
|
||||
from ..extras.constants import CHOICES
|
||||
@@ -25,20 +25,19 @@ class EvalTemplate:
|
||||
choice: str
|
||||
answer: str
|
||||
|
||||
def _parse_example(self, example: Dict[str, str]) -> Tuple[str, str]:
|
||||
r"""
|
||||
def _parse_example(self, example: dict[str, str]) -> tuple[str, str]:
|
||||
r"""Parse eval example.
|
||||
|
||||
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
|
||||
output: a tuple of (prompt, response)
|
||||
output: a tuple of (prompt, response).
|
||||
"""
|
||||
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in CHOICES if ch in example]
|
||||
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
|
||||
|
||||
def format_example(
|
||||
self, target_data: Dict[str, str], support_set: Sequence[Dict[str, str]], subject_name: str
|
||||
) -> List[Dict[str, str]]:
|
||||
r"""
|
||||
Converts dataset examples to messages.
|
||||
"""
|
||||
self, target_data: dict[str, str], support_set: Sequence[dict[str, str]], subject_name: str
|
||||
) -> list[dict[str, str]]:
|
||||
r"""Convert dataset examples to messages."""
|
||||
messages = []
|
||||
for k in range(len(support_set)):
|
||||
prompt, response = self._parse_example(support_set[k])
|
||||
@@ -52,7 +51,7 @@ class EvalTemplate:
|
||||
return messages
|
||||
|
||||
|
||||
eval_templates: Dict[str, "EvalTemplate"] = {}
|
||||
eval_templates: dict[str, "EvalTemplate"] = {}
|
||||
|
||||
|
||||
def _register_eval_template(name: str, system: str, choice: str, answer: str) -> None:
|
||||
|
||||
Reference in New Issue
Block a user