mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
support eval remote dataset
Former-commit-id: 2d42be32c1b32b26548ea5af5fc3c810f4d668c1
This commit is contained in:
parent
42bb8b6400
commit
5c4ddebde5
@ -6,9 +6,11 @@ import torch
|
||||
import tiktoken
|
||||
import numpy as np
|
||||
from tqdm import tqdm, trange
|
||||
from datasets import load_dataset
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers.utils import cached_file
|
||||
|
||||
from llmtuner.eval.constants import CHOICES, SUBJECTS
|
||||
from llmtuner.eval.parser import get_eval_args
|
||||
from llmtuner.eval.template import get_eval_template
|
||||
@ -20,8 +22,8 @@ from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
class Evaluator:
|
||||
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
|
||||
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
||||
self.model = dispatch_model(self.model)
|
||||
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
|
||||
@ -45,7 +47,13 @@ class Evaluator:
|
||||
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
|
||||
|
||||
def eval(self) -> None:
|
||||
mapping = os.path.join(self.eval_args.task_dir, self.eval_args.task, "mapping.json")
|
||||
mapping = cached_file(
|
||||
path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
||||
filename="mapping.json",
|
||||
cache_dir=self.model_args.cache_dir,
|
||||
token=self.model_args.hf_hub_token,
|
||||
revision=self.model_args.model_revision
|
||||
)
|
||||
with open(mapping, "r", encoding="utf-8") as f:
|
||||
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user