mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-18 12:50:38 +08:00
support prefixes, loading multiple local files
This commit is contained in:
@@ -10,14 +10,11 @@ class DatasetAttr:
|
||||
|
||||
load_from: str
|
||||
dataset_name: Optional[str] = None
|
||||
file_name: Optional[str] = None
|
||||
file_sha1: Optional[str] = None
|
||||
dataset_sha1: Optional[str] = None
|
||||
source_prefix: Optional[str] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self.dataset_name is not None:
|
||||
return self.dataset_name
|
||||
else:
|
||||
return self.file_name
|
||||
return self.dataset_name
|
||||
|
||||
def __post_init__(self):
|
||||
self.prompt_column = "instruction"
|
||||
@@ -137,7 +134,7 @@ class DataTrainingArguments:
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
||||
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes."}
|
||||
)
|
||||
dev_ratio: Optional[float] = field(
|
||||
default=0,
|
||||
@@ -153,8 +150,15 @@ class DataTrainingArguments:
|
||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
if self.source_prefix is not None:
|
||||
prefix_list = self.source_prefix.split("|")
|
||||
prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
|
||||
assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
|
||||
else:
|
||||
prefix_list = [None] * len(dataset_names)
|
||||
|
||||
self.dataset_list: List[DatasetAttr] = []
|
||||
for name in dataset_names:
|
||||
for i, name in enumerate(dataset_names):
|
||||
if name not in dataset_info:
|
||||
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
|
||||
|
||||
@@ -165,10 +169,12 @@ class DataTrainingArguments:
|
||||
else:
|
||||
dataset_attr = DatasetAttr(
|
||||
"file",
|
||||
file_name=dataset_info[name]["file_name"],
|
||||
file_sha1=dataset_info[name].get("file_sha1", None)
|
||||
dataset_name=dataset_info[name]["file_name"],
|
||||
dataset_sha1=dataset_info[name].get("file_sha1", None)
|
||||
)
|
||||
|
||||
dataset_attr.source_prefix = prefix_list[i]
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
|
||||
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
|
||||
|
||||
Reference in New Issue
Block a user