Update parser.py

Former-commit-id: 51dd454337941801d0a66eaadb0da2e007e9573d
This commit is contained in:
hoshi-hiyouga 2024-05-30 00:05:20 +08:00 committed by GitHub
parent 21e7979837
commit 69a51cacb1

View File

@ -20,11 +20,12 @@ class DatasetAttr:
""" basic configs """ """ basic configs """
load_from: Literal["hf_hub", "ms_hub", "script", "file"] load_from: Literal["hf_hub", "ms_hub", "script", "file"]
dataset_name: str dataset_name: str
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
ranking: bool = False
""" extra configs """ """ extra configs """
subset: Optional[str] = None subset: Optional[str] = None
folder: Optional[str] = None folder: Optional[str] = None
ranking: bool = False num_samples: Optional[int] = None
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
""" common columns """ """ common columns """
system: Optional[str] = None system: Optional[str] = None
tools: Optional[str] = None tools: Optional[str] = None
@ -48,7 +49,6 @@ class DatasetAttr:
observation_tag: Optional[str] = "observation" observation_tag: Optional[str] = "observation"
function_tag: Optional[str] = "function_call" function_tag: Optional[str] = "function_call"
system_tag: Optional[str] = "system" system_tag: Optional[str] = "system"
sample_num: Optional[int] = None
def __repr__(self) -> str: def __repr__(self) -> str:
return self.dataset_name return self.dataset_name
@ -103,12 +103,12 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
else: else:
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"]) dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
dataset_attr.set_attr("subset", dataset_info[name]) dataset_attr.set_attr("subset", dataset_info[name])
dataset_attr.set_attr("folder", dataset_info[name]) dataset_attr.set_attr("folder", dataset_info[name])
dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("num_samples", dataset_info[name])
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
dataset_attr.set_attr("sample_num", dataset_info[name])
if "columns" in dataset_info[name]: if "columns" in dataset_info[name]:
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"] column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
if dataset_attr.formatting == "alpaca": if dataset_attr.formatting == "alpaca":