Former-commit-id: 7ec64588c541422875adfdaf5692a27d05b96cb9
This commit is contained in:
hiyouga 2024-01-19 21:44:32 +08:00
parent 384f0e7678
commit 0868d5c550
4 changed files with 18 additions and 11 deletions

View File

@@ -1,4 +1,5 @@
 import os
+import inspect
 from typing import TYPE_CHECKING, List, Literal, Union
 from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
@@ -82,6 +83,11 @@ def load_single_dataset(
         except ImportError:
             raise ImportError("Please install modelscope via `pip install modelscope -U`")
     else:
+        if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
+            kwargs = {"trust_remote_code": True}
+        else:
+            kwargs = {}
         dataset = load_dataset(
             path=data_path,
             name=data_name,
@@ -90,7 +96,8 @@ def load_single_dataset(
             split=data_args.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
+            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            **kwargs
         )
     if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True

View File

@@ -32,8 +32,8 @@ class Template:
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
-        system: str,
-        tools: str,
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
         cutoff_len: Optional[int] = 1_000_000
     ) -> Tuple[List[int], List[int]]:
         r"""
r""" r"""

View File

@@ -4,6 +4,7 @@ import os
 import json
 import torch
 import numpy as np
+import inspect
 from tqdm import tqdm, trange
 from typing import Any, Dict, List, Optional
@@ -53,13 +54,18 @@ class Evaluator:
         pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
         results = {}
         for subject in pbar:
+            if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
+                kwargs = {"trust_remote_code": True}
+            else:
+                kwargs = {}
             dataset = load_dataset(
                 path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
                 name=subject,
                 cache_dir=self.model_args.cache_dir,
                 download_mode=self.eval_args.download_mode,
                 token=self.model_args.hf_hub_token,
-                trust_remote_code=True
+                **kwargs
             )
             pbar.set_postfix_str(categorys[subject]["name"])
             inputs, outputs, labels = [], [], []

View File

@@ -27,8 +27,7 @@ class EvalTemplate:
         self,
         target_data: Dict[str, str],
         support_set: "Dataset",
-        subject_name: str,
-        use_history: bool
+        subject_name: str
     ) -> List[Dict[str, str]]:
         messages = []
         for k in range(len(support_set)):
@@ -39,12 +38,7 @@ class EvalTemplate:
         prompt, response = self.parse_example(target_data)
         messages.append({"role": Role.USER, "content": prompt})
         messages.append({"role": Role.ASSISTANT, "content": response})
         messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
-        if not use_history:
-            messages = [{"role": Role.USER, "content": "\n\n".join([message["content"] for message in messages[:-1]])}]
         return messages