From 6ef2ec71edd95504ea23324bcf7d048c3a08cc0b Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Fri, 26 Apr 2024 03:35:39 +0800 Subject: [PATCH] Update parser.py Former-commit-id: f62cadb258f85ddf6b595c8f6fe3385c2e0e3376 --- src/llmtuner/data/parser.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/llmtuner/data/parser.py b/src/llmtuner/data/parser.py index 4d3d7741..01a417a9 100644 --- a/src/llmtuner/data/parser.py +++ b/src/llmtuner/data/parser.py @@ -25,9 +25,10 @@ class DatasetAttr: subset: Optional[str] = None folder: Optional[str] = None ranking: bool = False - formatting: Literal["alpaca", "sharegpt", "llava"] = "alpaca" + formatting: Literal["alpaca", "sharegpt"] = "alpaca" """ columns """ system: Optional[str] = None + images: Optional[str] = None """ columns for the alpaca format """ prompt: Optional[str] = "instruction" query: Optional[str] = "input" @@ -44,8 +45,6 @@ class DatasetAttr: observation_tag: Optional[str] = "observation" function_tag: Optional[str] = "function_call" system_tag: Optional[str] = "system" - """ columns for the mllm format """ - images: Optional[str] = None def __repr__(self) -> str: return self.dataset_name @@ -105,21 +104,18 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]: dataset_attr.set_attr("folder", dataset_info[name]) dataset_attr.set_attr("ranking", dataset_info[name], default=False) dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca") - dataset_attr.set_attr("images", dataset_info[name], default="") if "columns" in dataset_info[name]: - column_names = ["system"] + column_names = ["system", "images"] if dataset_attr.formatting == "alpaca": column_names.extend(["prompt", "query", "response", "history"]) - elif dataset_attr.formatting == "llava": - column_names.extend(["messages", "images"]) else: column_names.extend(["messages", "tools"]) for column_name in column_names: dataset_attr.set_attr(column_name, dataset_info[name]["columns"]) - if dataset_attr.formatting != "alpaca" and "tags" in dataset_info[name]: + if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]: tag_names = ( "role_tag", "content_tag",