Update aligner.py

Former-commit-id: fcd09112d5929f5a4c502c5886050083e69122f9
This commit is contained in:
hoshi-hiyouga 2024-04-26 03:48:34 +08:00 committed by GitHub
parent 6ef2ec71ed
commit 15b7182418

View File

@ -1,3 +1,4 @@
import os
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union from typing import TYPE_CHECKING, Any, Dict, List, Union
@ -13,8 +14,10 @@ if TYPE_CHECKING:
from .parser import DatasetAttr from .parser import DatasetAttr
def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]: def convert_alpaca(
outputs = {"prompt": [], "response": [], "system": [], "tools": []} examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
for i in range(len(examples[dataset_attr.prompt])): for i in range(len(examples[dataset_attr.prompt])):
prompt = [] prompt = []
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list): if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
@ -44,11 +47,18 @@ def convert_alpaca(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr")
outputs["response"].append(response) outputs["response"].append(response)
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
outputs["tools"].append("") outputs["tools"].append("")
outputs["images"].append([]) outputs["images"].append(
[os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]]
if dataset_attr.images
else []
)
return outputs return outputs
def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr") -> Dict[str, List[Any]]: def convert_sharegpt(
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []} outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
tag_mapping = { tag_mapping = {
dataset_attr.user_tag: Role.USER.value, dataset_attr.user_tag: Role.USER.value,
@ -84,7 +94,11 @@ def convert_sharegpt(examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr"
outputs["response"].append(aligned_messages[-1:]) outputs["response"].append(aligned_messages[-1:])
outputs["system"].append(system) outputs["system"].append(system)
outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "") outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
outputs["images"].append(examples[dataset_attr.images][i] if dataset_attr.images else []) outputs["images"].append(
[os.path.join(data_args.dataset_dir, path) for path in examples[dataset_attr.images][i]]
if dataset_attr.images
else []
)
return outputs return outputs
@ -97,12 +111,13 @@ def align_dataset(
prompt: [{"role": "user", "content": "..."}] * (2T - 1) prompt: [{"role": "user", "content": "..."}] * (2T - 1)
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset) response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
system: "..." system: "..."
tools: "..." tools: "...",
images: [],
""" """
if dataset_attr.formatting == "alpaca": if dataset_attr.formatting == "alpaca":
convert_func = partial(convert_alpaca, dataset_attr=dataset_attr) convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
else: else:
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr) convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)
column_names = list(next(iter(dataset)).keys()) column_names = list(next(iter(dataset)).keys())
features = Features.from_dict( features = Features.from_dict(
@ -115,7 +130,7 @@ def align_dataset(
], ],
"system": {"dtype": "string", "_type": "Value"}, "system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"}, "tools": {"dtype": "string", "_type": "Value"},
"images": {"feature": {"_type": "Image"}, "_type": "Sequence"}, "images": [{"_type": "Image"}],
} }
) )
kwargs = {} kwargs = {}