Former-commit-id: 12b2066e342e68c241dd98015d59148c122cffa8
This commit is contained in:
hiyouga 2024-02-12 21:07:46 +08:00
parent 75adbfec79
commit 0b1c20eada

View File

@ -1,6 +1,8 @@
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features
from .utils import Role
@ -100,6 +102,18 @@ def align_dataset(
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
{
"prompt": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"response": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
}
)
kwargs = {}
if not data_args.streaming:
kwargs = dict(
@ -108,4 +122,10 @@ def align_dataset(
desc="Converting format of dataset",
)
return dataset.map(convert_func, batched=True, remove_columns=column_names, **kwargs)
return dataset.map(
convert_func,
batched=True,
remove_columns=column_names,
features=features,
**kwargs,
)