support eval remote dataset

Former-commit-id: 2d42be32c1b32b26548ea5af5fc3c810f4d668c1
hiyouga 2023-11-14 02:42:30 +08:00
parent 42bb8b6400
commit 5c4ddebde5


@@ -6,9 +6,11 @@ import torch
 import tiktoken
 import numpy as np
 from tqdm import tqdm, trange
-from datasets import load_dataset
 from typing import Any, Dict, List, Optional
+from datasets import load_dataset
+from transformers.utils import cached_file
 from llmtuner.eval.constants import CHOICES, SUBJECTS
 from llmtuner.eval.parser import get_eval_args
 from llmtuner.eval.template import get_eval_template
@@ -20,8 +22,8 @@ from llmtuner.tuner.core import load_model_and_tokenizer
 class Evaluator:
 
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
-        model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
-        self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+        self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
+        self.model, self.tokenizer = load_model_and_tokenizer(self.model_args, finetuning_args)
         self.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
         self.model = dispatch_model(self.model)
         self.template = get_template_and_fix_tokenizer(self.data_args.template, self.tokenizer)
@@ -45,7 +47,13 @@ class Evaluator:
         return [chr(ord("A") + offset.item()) for offset in torch.argmax(choice_probs, dim=-1)]
 
     def eval(self) -> None:
-        mapping = os.path.join(self.eval_args.task_dir, self.eval_args.task, "mapping.json")
+        mapping = cached_file(
+            path_or_repo_id=os.path.join(self.eval_args.task_dir, self.eval_args.task),
+            filename="mapping.json",
+            cache_dir=self.model_args.cache_dir,
+            token=self.model_args.hf_hub_token,
+            revision=self.model_args.model_revision
+        )
         with open(mapping, "r", encoding="utf-8") as f:
             categorys: Dict[str, Dict[str, str]] = json.load(f)
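
For context, transformers' cached_file is what makes a remote task_dir possible here: if task_dir is a folder on disk it resolves to the local file task_dir/task/mapping.json, and if it names a Hugging Face Hub namespace it downloads mapping.json from the task_dir/task repository into the local cache and returns that cached path. Below is a minimal standalone sketch of that lookup; the resolve_mapping helper and the "some-user" namespace are illustrative placeholders, not part of this commit.

import json
import os
from typing import Dict, Optional

from transformers.utils import cached_file


def resolve_mapping(task_dir: str, task: str,
                    token: Optional[str] = None,
                    revision: str = "main") -> Dict[str, Dict[str, str]]:
    # cached_file returns a local filesystem path in both cases:
    # - task_dir is a folder on disk -> resolves <task_dir>/<task>/mapping.json
    # - task_dir is a Hub namespace  -> downloads mapping.json from the
    #   <task_dir>/<task> repository and returns the cached copy
    mapping = cached_file(
        path_or_repo_id=os.path.join(task_dir, task),
        filename="mapping.json",
        token=token,
        revision=revision,
    )
    with open(mapping, "r", encoding="utf-8") as f:
        return json.load(f)


# local task directory shipped with the repository
categories = resolve_mapping("evaluation", "ceval")
# hypothetical remote copy hosted under a Hub namespace
categories = resolve_mapping("some-user", "ceval")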