Former-commit-id: 12043aab9c9ea2acaa3ff80232ce0451ccaa557d
This commit is contained in:
hiyouga 2024-01-19 21:44:32 +08:00
parent e199967391
commit 397eb8fbb9
4 changed files with 18 additions and 11 deletions

View File

@@ -1,4 +1,5 @@
 import os
+import inspect
 from typing import TYPE_CHECKING, List, Literal, Union
 from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
@@ -82,6 +83,11 @@ def load_single_dataset(
         except ImportError:
             raise ImportError("Please install modelscope via `pip install modelscope -U`")
     else:
+        if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
+            kwargs = {"trust_remote_code": True}
+        else:
+            kwargs = {}
         dataset = load_dataset(
             path=data_path,
             name=data_name,
@@ -90,7 +96,8 @@ def load_single_dataset(
             split=data_args.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
+            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            **kwargs
         )
     if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True

View File

@@ -32,8 +32,8 @@ class Template:
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
-        system: str,
-        tools: str,
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
         cutoff_len: Optional[int] = 1_000_000
     ) -> Tuple[List[int], List[int]]:
         r"""

View File

@@ -4,6 +4,7 @@ import os
 import json
 import torch
 import numpy as np
+import inspect
 from tqdm import tqdm, trange
 from typing import Any, Dict, List, Optional
@@ -53,13 +54,18 @@ class Evaluator:
         pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
         results = {}
         for subject in pbar:
+            if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
+                kwargs = {"trust_remote_code": True}
+            else:
+                kwargs = {}
             dataset = load_dataset(
                 path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
                 name=subject,
                 cache_dir=self.model_args.cache_dir,
                 download_mode=self.eval_args.download_mode,
                 token=self.model_args.hf_hub_token,
-                trust_remote_code=True
+                **kwargs
             )
             pbar.set_postfix_str(categorys[subject]["name"])
             inputs, outputs, labels = [], [], []

View File

@@ -27,8 +27,7 @@ class EvalTemplate:
         self,
         target_data: Dict[str, str],
         support_set: "Dataset",
-        subject_name: str,
-        use_history: bool
+        subject_name: str
     ) -> List[Dict[str, str]]:
         messages = []
         for k in range(len(support_set)):
@@ -39,12 +38,7 @@ class EvalTemplate:
         prompt, response = self.parse_example(target_data)
         messages.append({"role": Role.USER, "content": prompt})
         messages.append({"role": Role.ASSISTANT, "content": response})
         messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
-        if not use_history:
-            messages = [{"role": Role.USER, "content": "\n\n".join([message["content"] for message in messages[:-1]])}]
         return messages