mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 20:52:59 +08:00
support eval remote dataset
Former-commit-id: 2d42be32c1b32b26548ea5af5fc3c810f4d668c1
This commit is contained in:
parent
42bb8b6400
commit
5c4ddebde5
@ -6,9 +6,11 @@ import torch
|
|||||||
import tiktoken
|
import tiktoken
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
from datasets import load_dataset
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
from transformers.utils import cached_file
|
||||||
|
|
||||||
from llmtuner.eval.constants import CHOICES, SUBJECTS
|
from llmtuner.eval.constants import CHOICES, SUBJECTS
|
||||||
from llmtuner.eval.parser import get_eval_args
|
from llmtuner.eval.parser import get_eval_args
|
||||||
from llmtuner.eval.template import get_eval_template
|
from llmtuner.eval.template import get_eval_template
|
||||||
@ -20,8 +22,8 @@ from llmtuner.tuner.core import load_model_and_tokenizer
|
|||||||
class Evaluator:
|
class Evaluator:
|
||||||
|
|
||||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||||
model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
||||||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
|
||||||
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
||||||
self.model = dispatch_model(self.model)
|
self.model = dispatch_model(self.model)
|
||||||
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
|
self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
|
||||||
@ -45,7 +47,13 @@ class Evaluator:
|
|||||||
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
|
return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
|
||||||
|
|
||||||
def eval(self) -> None:
|
def eval(self) -> None:
|
||||||
mapping = os.path.join(self.eval_args.task_dir, self.eval_args.task, "mapping.json")
|
mapping = cached_file(
|
||||||
|
path_or_repo_id = os.path.join(self.eval_args.task_dir, self.eval_args.task),
|
||||||
|
filename="mapping.json",
|
||||||
|
cache_dir=self.model_args.cache_dir,
|
||||||
|
token=self.model_args.hf_hub_token,
|
||||||
|
revision=self.model_args.model_revision
|
||||||
|
)
|
||||||
with open(mapping, "r", encoding="utf-8") as f:
|
with open(mapping, "r", encoding="utf-8") as f:
|
||||||
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user