From 0b1c20eada21d8c58b51bae8d55c3f2eefd8147c Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 12 Feb 2024 21:07:46 +0800
Subject: [PATCH] fix #2471

Former-commit-id: 12b2066e342e68c241dd98015d59148c122cffa8
---
 src/llmtuner/data/aligner.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/src/llmtuner/data/aligner.py b/src/llmtuner/data/aligner.py
index a982ec32..fbf3a32d 100644
--- a/src/llmtuner/data/aligner.py
+++ b/src/llmtuner/data/aligner.py
@@ -1,6 +1,8 @@
 from functools import partial
 from typing import TYPE_CHECKING, Any, Dict, List, Union
 
+from datasets import Features
+
 from .utils import Role
 
 
@@ -100,6 +102,18 @@ def align_dataset(
         convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr)
 
     column_names = list(next(iter(dataset)).keys())
+    features = Features.from_dict(
+        {
+            "prompt": [
+                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
+            ],
+            "response": [
+                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
+            ],
+            "system": {"dtype": "string", "_type": "Value"},
+            "tools": {"dtype": "string", "_type": "Value"},
+        }
+    )
     kwargs = {}
     if not data_args.streaming:
         kwargs = dict(
@@ -108,4 +122,10 @@ def align_dataset(
             desc="Converting format of dataset",
         )
 
-    return dataset.map(convert_func, batched=True, remove_columns=column_names, **kwargs)
+    return dataset.map(
+        convert_func,
+        batched=True,
+        remove_columns=column_names,
+        features=features,
+        **kwargs,
+    )