Update parser.py

Former-commit-id: 3d39d74003c4ca36f9c9b77f622d366383b0af7e
This commit is contained in:
hoshi-hiyouga 2024-07-14 23:04:34 +08:00 committed by GitHub
parent ddbd848e49
commit 140b512426

View File

@ -38,9 +38,9 @@ class DatasetAttr:
ranking: bool = False ranking: bool = False
# extra configs # extra configs
subset: Optional[str] = None subset: Optional[str] = None
split: str = "train"
folder: Optional[str] = None folder: Optional[str] = None
num_samples: Optional[int] = None num_samples: Optional[int] = None
split: Optional[str] = "train"
# common columns # common columns
system: Optional[str] = None system: Optional[str] = None
tools: Optional[str] = None tools: Optional[str] = None
@ -72,7 +72,7 @@ class DatasetAttr:
setattr(self, key, obj.get(key, default)) setattr(self, key, obj.get(key, default))
def get_dataset_list(data_args: "DataArguments", dataset: "str" = None) -> List["DatasetAttr"]: def get_dataset_list(data_args: "DataArguments", dataset: Optional[str]) -> List["DatasetAttr"]:
if dataset is not None: if dataset is not None:
dataset_names = [ds.strip() for ds in dataset.split(",")] dataset_names = [ds.strip() for ds in dataset.split(",")]
else: else:
@ -121,10 +121,9 @@ def get_dataset_list(data_args: "DataArguments", dataset: "str" = None) -> List[
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("subset", dataset_info[name]) dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.set_attr("split", dataset_info[name], default="train")
dataset_attr.set_attr("folder", dataset_info[name]) dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("num_samples", dataset_info[name]) dataset_attr.set_attr("num_samples", dataset_info[name])
if "split" in dataset_info[name]:
dataset_attr.set_attr("split", dataset_info[name])
if "columns" in dataset_info[name]: if "columns" in dataset_info[name]:
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"] column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]