diff --git a/src/llmtuner/data/parser.py b/src/llmtuner/data/parser.py index 4d3d7741..01a417a9 100644 --- a/src/llmtuner/data/parser.py +++ b/src/llmtuner/data/parser.py @@ -25,9 +25,10 @@ class DatasetAttr: subset: Optional[str] = None folder: Optional[str] = None ranking: bool = False - formatting: Literal["alpaca", "sharegpt", "llava"] = "alpaca" + formatting: Literal["alpaca", "sharegpt"] = "alpaca" """ columns """ system: Optional[str] = None + images: Optional[str] = None """ columns for the alpaca format """ prompt: Optional[str] = "instruction" query: Optional[str] = "input" @@ -44,8 +45,6 @@ class DatasetAttr: observation_tag: Optional[str] = "observation" function_tag: Optional[str] = "function_call" system_tag: Optional[str] = "system" - """ columns for the mllm format """ - images: Optional[str] = None def __repr__(self) -> str: return self.dataset_name @@ -105,21 +104,18 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: dataset_attr.set_attr("folder", dataset_info[name]) dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") - dataset_attr.set_attr("images", dataset_info[name], default="") if "columns" in dataset_info[name]: - column_names = ["system"] + column_names = ["system", "images"] if dataset_attr.formatting == "alpaca": column_names.extend(["prompt", "query", "response", "history"]) - elif dataset_attr.formatting == "llava": - column_names.extend(["messages", "images"]) else: column_names.extend(["messages", "tools"]) for column_name in column_names: dataset_attr.set_attr(column_name, dataset_info[name]["columns"]) - if dataset_attr.formatting != "alpaca" and "tags" in dataset_info[name]: + if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]: tag_names = ( "role_tag", "content_tag",