mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-18 12:50:38 +08:00
merge data part to the text stream
This commit is contained in:
@@ -25,7 +25,7 @@ class DatasetAttr:
|
||||
subset: Optional[str] = None
|
||||
folder: Optional[str] = None
|
||||
ranking: bool = False
|
||||
formatting: Literal["alpaca", "sharegpt"] = "alpaca"
|
||||
formatting: Literal["alpaca", "sharegpt", "llava"] = "alpaca"
|
||||
""" columns """
|
||||
system: Optional[str] = None
|
||||
""" columns for the alpaca format """
|
||||
@@ -44,11 +44,15 @@ class DatasetAttr:
|
||||
observation_tag: Optional[str] = "observation"
|
||||
function_tag: Optional[str] = "function_call"
|
||||
system_tag: Optional[str] = "system"
|
||||
""" columns for the mllm format """
|
||||
images: Optional[str] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
|
||||
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
|
||||
def set_attr(
|
||||
self, key: str, obj: Dict[str, Any], default: Optional[Any] = None
|
||||
) -> None:
|
||||
setattr(self, key, obj.get(key, default))
|
||||
|
||||
|
||||
@@ -67,12 +71,16 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
except Exception as err:
|
||||
if len(dataset_names) != 0:
|
||||
raise ValueError(
|
||||
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
|
||||
"Cannot open {} due to {}.".format(
|
||||
os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)
|
||||
)
|
||||
)
|
||||
dataset_info = None
|
||||
|
||||
if data_args.interleave_probs is not None:
|
||||
data_args.interleave_probs = [float(prob.strip()) for prob in data_args.interleave_probs.split(",")]
|
||||
data_args.interleave_probs = [
|
||||
float(prob.strip()) for prob in data_args.interleave_probs.split(",")
|
||||
]
|
||||
|
||||
dataset_list: List[DatasetAttr] = []
|
||||
for name in dataset_names:
|
||||
@@ -90,31 +98,42 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
|
||||
if has_hf_url or has_ms_url:
|
||||
if (use_modelscope() and has_ms_url) or (not has_hf_url):
|
||||
dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
|
||||
dataset_attr = DatasetAttr(
|
||||
"ms_hub", dataset_name=dataset_info[name]["ms_hub_url"]
|
||||
)
|
||||
else:
|
||||
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
||||
dataset_attr = DatasetAttr(
|
||||
"hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]
|
||||
)
|
||||
elif "script_url" in dataset_info[name]:
|
||||
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
||||
dataset_attr = DatasetAttr(
|
||||
"script", dataset_name=dataset_info[name]["script_url"]
|
||||
)
|
||||
else:
|
||||
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
|
||||
dataset_attr = DatasetAttr(
|
||||
"file", dataset_name=dataset_info[name]["file_name"]
|
||||
)
|
||||
|
||||
dataset_attr.set_attr("file_sha1", dataset_info[name])
|
||||
dataset_attr.set_attr("subset", dataset_info[name])
|
||||
dataset_attr.set_attr("folder", dataset_info[name])
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
dataset_attr.set_attr("images", dataset_info[name], default="")
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system"]
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
column_names.extend(["prompt", "query", "response", "history"])
|
||||
elif dataset_attr.formatting == "llava":
|
||||
column_names.extend(["messages", "images"])
|
||||
else:
|
||||
column_names.extend(["messages", "tools"])
|
||||
|
||||
for column_name in column_names:
|
||||
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
|
||||
|
||||
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
|
||||
if dataset_attr.formatting != "alpaca" and "tags" in dataset_info[name]:
|
||||
tag_names = (
|
||||
"role_tag",
|
||||
"content_tag",
|
||||
|
||||
Reference in New Issue
Block a user