Support multiple source prefixes (one per dataset, separated by `|`) and loading multiple local files

This commit is contained in:
hiyouga
2023-06-26 15:32:40 +08:00
parent f030b09924
commit cec9760eb8
3 changed files with 84 additions and 39 deletions

View File

@@ -10,14 +10,11 @@ class DatasetAttr:
load_from: str
dataset_name: Optional[str] = None
file_name: Optional[str] = None
file_sha1: Optional[str] = None
dataset_sha1: Optional[str] = None
source_prefix: Optional[str] = None
def __repr__(self) -> str:
if self.dataset_name is not None:
return self.dataset_name
else:
return self.file_name
return self.dataset_name
def __post_init__(self):
self.prompt_column = "instruction"
@@ -137,7 +134,7 @@ class DataTrainingArguments:
)
source_prefix: Optional[str] = field(
default=None,
metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes."}
)
dev_ratio: Optional[float] = field(
default=0,
@@ -153,8 +150,15 @@ class DataTrainingArguments:
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
dataset_info = json.load(f)
if self.source_prefix is not None:
prefix_list = self.source_prefix.split("|")
prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
else:
prefix_list = [None] * len(dataset_names)
self.dataset_list: List[DatasetAttr] = []
for name in dataset_names:
for i, name in enumerate(dataset_names):
if name not in dataset_info:
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
@@ -165,10 +169,12 @@ class DataTrainingArguments:
else:
dataset_attr = DatasetAttr(
"file",
file_name=dataset_info[name]["file_name"],
file_sha1=dataset_info[name].get("file_sha1", None)
dataset_name=dataset_info[name]["file_name"],
dataset_sha1=dataset_info[name].get("file_sha1", None)
)
dataset_attr.source_prefix = prefix_list[i]
if "columns" in dataset_info[name]:
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)