Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Commit 397eb8fbb9 (parent e199967391)

This commit probes the installed `datasets` version before passing `trust_remote_code` to `load_dataset` (the parameter only exists from datasets==2.16.0 on), makes the `system` and `tools` template arguments optional with `None` defaults, and drops the `use_history` flag and its history-collapsing branch from the evaluation template.
Data loader (load_single_dataset):

@@ -1,4 +1,5 @@
 import os
+import inspect
 from typing import TYPE_CHECKING, List, Literal, Union
 
 from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
@@ -82,6 +83,11 @@ def load_single_dataset(
         except ImportError:
             raise ImportError("Please install modelscope via `pip install modelscope -U`")
     else:
+        if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
+            kwargs = {"trust_remote_code": True}
+        else:
+            kwargs = {}
+
         dataset = load_dataset(
             path=data_path,
             name=data_name,
@@ -90,7 +96,8 @@ def load_single_dataset(
             split=data_args.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
+            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            **kwargs
         )
 
         if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
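The guard added above exists because `load_dataset` only accepts `trust_remote_code` from datasets==2.16.0 onward; passing it to an older release raises a TypeError. A minimal, self-contained sketch of the same feature-detection pattern, with a hypothetical stub standing in for an older `load_dataset`:

    import inspect

    def load_dataset(path, name=None):  # stub simulating a pre-2.16.0 signature
        return "loaded " + path

    # Only pass the keyword if the installed function actually declares it,
    # so the same call site works on either side of the datasets==2.16.0 boundary.
    if "trust_remote_code" in inspect.signature(load_dataset).parameters:
        kwargs = {"trust_remote_code": True}
    else:
        kwargs = {}

    print(load_dataset("my_dataset", **kwargs))  # prints "loaded my_dataset"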
Chat template (class Template):

@@ -32,8 +32,8 @@ class Template:
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
-        system: str,
-        tools: str,
+        system: Optional[str] = None,
+        tools: Optional[str] = None,
         cutoff_len: Optional[int] = 1_000_000
     ) -> Tuple[List[int], List[int]]:
         r"""
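Giving `system` and `tools` `None` defaults keeps existing call sites that pass only `tokenizer` and `messages` working unchanged. A sketch of the pattern; the class and method below are illustrative, not the repository's actual API:

    from typing import Dict, List, Optional

    class TemplateSketch:
        def encode(
            self,
            messages: List[Dict[str, str]],
            system: Optional[str] = None,
            tools: Optional[str] = None,
        ) -> str:
            prefix = system or ""  # fall back to no system prompt
            return prefix + " ".join(m["content"] for m in messages)

    t = TemplateSketch()
    t.encode([{"role": "user", "content": "hi"}])                      # legacy call still works
    t.encode([{"role": "user", "content": "hi"}], system="Be brief. ")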
Evaluator (class Evaluator):

@@ -4,6 +4,7 @@ import os
 import json
 import torch
 import numpy as np
+import inspect
 from tqdm import tqdm, trange
 from typing import Any, Dict, List, Optional
 
@@ -53,13 +54,18 @@ class Evaluator:
         pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
         results = {}
         for subject in pbar:
+            if "trust_remote_code" in inspect.signature(load_dataset).parameters:  # for datasets==2.16.0
+                kwargs = {"trust_remote_code": True}
+            else:
+                kwargs = {}
+
             dataset = load_dataset(
                 path=os.path.join(self.eval_args.task_dir, self.eval_args.task),
                 name=subject,
                 cache_dir=self.model_args.cache_dir,
                 download_mode=self.eval_args.download_mode,
                 token=self.model_args.hf_hub_token,
-                trust_remote_code=True
+                **kwargs
             )
             pbar.set_postfix_str(categorys[subject]["name"])
             inputs, outputs, labels = [], [], []
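The same four-line version probe now appears in both `load_single_dataset` and the evaluator loop. If a third call site shows up, it could be factored into a small helper; a hedged sketch (the helper name is ours, not the repository's):

    import inspect
    from typing import Any, Callable, Dict

    def remote_code_kwargs(loader: Callable) -> Dict[str, Any]:
        """Return {"trust_remote_code": True} only when the installed
        datasets release (>= 2.16.0) declares the parameter."""
        if "trust_remote_code" in inspect.signature(loader).parameters:
            return {"trust_remote_code": True}
        return {}

    # usage: dataset = load_dataset(path=..., **remote_code_kwargs(load_dataset))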
Evaluation template (class EvalTemplate):

@@ -27,8 +27,7 @@ class EvalTemplate:
         self,
         target_data: Dict[str, str],
         support_set: "Dataset",
-        subject_name: str,
-        use_history: bool
+        subject_name: str
     ) -> List[Dict[str, str]]:
         messages = []
         for k in range(len(support_set)):
@@ -39,12 +38,7 @@ class EvalTemplate:
         prompt, response = self.parse_example(target_data)
         messages.append({"role": Role.USER, "content": prompt})
         messages.append({"role": Role.ASSISTANT, "content": response})
 
         messages[0]["content"] = self.system.format(subject=subject_name) + messages[0]["content"]
-
-        if not use_history:
-            messages = [{"role": Role.USER, "content": "\n\n".join([message["content"] for message in messages[:-1]])}]
-
-
         return messages
 
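For context on the hunk above: each support-set example contributes a user/assistant pair, the target question is appended last, and the subject-specific system text is prepended to the first user turn (the enclosing method name is not visible in this hunk). A simplified, runnable sketch of that flow, using plain string roles and a toy `parse_example` and system template in place of the real formatting:

    from typing import Dict, List, Tuple

    def parse_example(example: Dict[str, str]) -> Tuple[str, str]:
        return example["question"], example["answer"]  # real code also formats the answer choices

    def format_messages(
        target: Dict[str, str],
        support_set: List[Dict[str, str]],
        subject_name: str,
        system: str = "The following are questions about {subject}.\n\n",
    ) -> List[Dict[str, str]]:
        messages = []
        for example in support_set:               # few-shot demonstrations
            prompt, response = parse_example(example)
            messages.append({"role": "user", "content": prompt})
            messages.append({"role": "assistant", "content": response})
        prompt, response = parse_example(target)  # the question being scored
        messages.append({"role": "user", "content": prompt})
        messages.append({"role": "assistant", "content": response})
        # Prepend the subject-specific system text to the first user turn.
        messages[0]["content"] = system.format(subject=subject_name) + messages[0]["content"]
        return messages

    msgs = format_messages(
        {"question": "2+2=?", "answer": "4"},
        [{"question": "1+1=?", "answer": "2"}],
        subject_name="arithmetic",
    )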