This commit is contained in:
hiyouga
2024-02-12 21:07:46 +08:00
parent 91d09a01ac
commit 12b2066e34

View File

@@ -1,6 +1,8 @@
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union from typing import TYPE_CHECKING, Any, Dict, List, Union
from datasets import Features
from .utils import Role from .utils import Role
@@ -100,6 +102,18 @@ def align_dataset(
convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr) convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
column_names = list(next(iter(dataset)).keys()) column_names = list(next(iter(dataset)).keys())
features = Features.from_dict(
{
"prompt": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"response": [
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
],
"system": {"dtype": "string", "_type": "Value"},
"tools": {"dtype": "string", "_type": "Value"},
}
)
kwargs = {} kwargs = {}
if not data_args.streaming: if not data_args.streaming:
kwargs = dict( kwargs = dict(
@@ -108,4 +122,10 @@ def align_dataset(
desc="Converting format of dataset", desc="Converting format of dataset",
) )
return dataset.map(convert_func, batched=True, remove_columns=column_names, **kwargs) return dataset.map(
convert_func,
batched=True,
remove_columns=column_names,
features=features,
**kwargs,
)