mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-18 12:50:38 +08:00
improve aligner
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||
|
||||
from ..extras.constants import DATA_CONFIG
|
||||
from ..extras.misc import use_modelscope
|
||||
@@ -13,38 +13,44 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclass
|
||||
class DatasetAttr:
|
||||
r"""
|
||||
Dataset attributes.
|
||||
"""
|
||||
|
||||
""" basic configs """
|
||||
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
||||
dataset_name: Optional[str] = None
|
||||
dataset_sha1: Optional[str] = None
|
||||
""" extra configs """
|
||||
file_sha1: Optional[str] = None
|
||||
subset: Optional[str] = None
|
||||
folder: Optional[str] = None
|
||||
ranking: Optional[bool] = False
|
||||
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
|
||||
|
||||
""" columns """
|
||||
system: Optional[str] = None
|
||||
|
||||
""" columns for the alpaca format """
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
history: Optional[str] = None
|
||||
|
||||
""" columns for the sharegpt format """
|
||||
messages: Optional[str] = "conversations"
|
||||
tools: Optional[str] = None
|
||||
|
||||
""" tags for the sharegpt format """
|
||||
role_tag: Optional[str] = "from"
|
||||
content_tag: Optional[str] = "value"
|
||||
user_tag: Optional[str] = "human"
|
||||
assistant_tag: Optional[str] = "gpt"
|
||||
observation_tag: Optional[str] = "observation"
|
||||
function_tag: Optional[str] = "function_call"
|
||||
system_tag: Optional[str] = None
|
||||
|
||||
assert system_tag is None or system is None, f"Can not provide both system message (system_tag={system_tag}) and system column(system={system})"
|
||||
|
||||
system_tag: Optional[str] = "system"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
|
||||
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
|
||||
setattr(self, key, obj.get(key, default))
|
||||
|
||||
|
||||
def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
dataset_names = [ds.strip() for ds in data_args.dataset.split(",")] if data_args.dataset is not None else []
|
||||
@@ -77,30 +83,36 @@ def get_dataset_list(data_args: "DataArguments") -> List["DatasetAttr"]:
|
||||
elif "script_url" in dataset_info[name]:
|
||||
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
||||
else:
|
||||
dataset_attr = DatasetAttr(
|
||||
"file",
|
||||
dataset_name=dataset_info[name]["file_name"],
|
||||
dataset_sha1=dataset_info[name].get("file_sha1", None),
|
||||
)
|
||||
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])
|
||||
|
||||
dataset_attr.subset = dataset_info[name].get("subset", None)
|
||||
dataset_attr.folder = dataset_info[name].get("folder", None)
|
||||
dataset_attr.ranking = dataset_info[name].get("ranking", False)
|
||||
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
|
||||
dataset_attr.set_attr("file_sha1", dataset_info[name])
|
||||
dataset_attr.set_attr("subset", dataset_info[name])
|
||||
dataset_attr.set_attr("folder", dataset_info[name])
|
||||
dataset_attr.set_attr("ranking", dataset_info[name], default=False)
|
||||
dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
column_names = ["system"]
|
||||
if dataset_attr.formatting == "alpaca":
|
||||
column_names = ["prompt", "query", "response", "history"]
|
||||
column_names.extend(["prompt", "query", "response", "history"])
|
||||
else:
|
||||
column_names = ["messages", "tools"]
|
||||
column_names.extend(["messages", "tools"])
|
||||
|
||||
column_names += ["system"]
|
||||
for column_name in column_names:
|
||||
setattr(dataset_attr, column_name, dataset_info[name]["columns"].get(column_name, None))
|
||||
dataset_attr.set_attr(column_name, dataset_info[name]["columns"])
|
||||
|
||||
if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
|
||||
for tag in ["role_tag", "content_tag", "user_tag", "assistant_tag", "observation_tag", "function_tag", "system_tag"]:
|
||||
setattr(dataset_attr, tag, dataset_info[name]["tags"].get(tag, None))
|
||||
tag_names = (
|
||||
"role_tag",
|
||||
"content_tag",
|
||||
"user_tag",
|
||||
"assistant_tag",
|
||||
"observation_tag",
|
||||
"function_tag",
|
||||
"system_tag",
|
||||
)
|
||||
for tag in tag_names:
|
||||
dataset_attr.set_attr(tag, dataset_info[name]["tags"])
|
||||
|
||||
dataset_list.append(dataset_attr)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user