mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 13:42:51 +08:00
Update parser.py
Former-commit-id: f62cadb258f85ddf6b595c8f6fe3385c2e0e3376
This commit is contained in:
parent
4487a4a790
commit
6ef2ec71ed
@ -25,9 +25,10 @@ class DatasetAttr:
|
|||||||
subset: Optional[str] = None
|
subset: Optional[str] = None
|
||||||
folder: Optional[str] = None
|
folder: Optional[str] = None
|
||||||
ranking: bool = False
|
ranking: bool = False
|
||||||
formatting: Literal["alpaca", "sharegpt", "llava"] = "alpaca"
|
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||||
""" columns """
|
""" columns """
|
||||||
system: Optional[str] = None
|
system: Optional[str] = None
|
||||||
|
images: Optional[str] = None
|
||||||
""" columns for the alpaca format """
|
""" columns for the alpaca format """
|
||||||
prompt: Optional[str] = "instruction"
|
prompt: Optional[str] = "instruction"
|
||||||
query: Optional[str] = "input"
|
query: Optional[str] = "input"
|
||||||
@ -44,8 +45,6 @@ class DatasetAttr:
|
|||||||
observation_tag: Optional[str] = "observation"
|
observation_tag: Optional[str] = "observation"
|
||||||
function_tag: Optional[str] = "function_call"
|
function_tag: Optional[str] = "function_call"
|
||||||
system_tag: Optional[str] = "system"
|
system_tag: Optional[str] = "system"
|
||||||
""" columns for the mllm format """
|
|
||||||
images: Optional[str] = None
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return self.dataset_name
|
return self.dataset_name
|
||||||
@ -105,21 +104,18 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
|||||||
dataset_attr.set_attr("folder", dataset_info[name])
|
dataset_attr.set_attr("folder", dataset_info[name])
|
||||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||||
dataset_attr.set_attr("images", dataset_info[name], default="")
|
|
||||||
|
|
||||||
if "columns" in dataset_info[name]:
|
if "columns" in dataset_info[name]:
|
||||||
column_names = ["system"]
|
column_names = ["system", "images"]
|
||||||
if dataset_attr.formatting == "alpaca":
|
if dataset_attr.formatting == "alpaca":
|
||||||
column_names.extend(["prompt", "query", "response", "history"])
|
column_names.extend(["prompt", "query", "response", "history"])
|
||||||
elif dataset_attr.formatting == "llava":
|
|
||||||
column_names.extend(["messages", "images"])
|
|
||||||
else:
|
else:
|
||||||
column_names.extend(["messages", "tools"])
|
column_names.extend(["messages", "tools"])
|
||||||
|
|
||||||
for column_name in column_names:
|
for column_name in column_names:
|
||||||
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
|
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
|
||||||
|
|
||||||
if dataset_attr.formatting != "alpaca" and "tags" in dataset_info[name]:
|
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
|
||||||
tag_names = (
|
tag_names = (
|
||||||
"role_tag",
|
"role_tag",
|
||||||
"content_tag",
|
"content_tag",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user