[v1] add data converter (#9263)

This commit is contained in:
Yaowei Zheng
2025-10-13 15:54:47 +08:00
committed by GitHub
parent 48974783da
commit 52e46e162e
7 changed files with 266 additions and 62 deletions

View File

@@ -12,8 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from ..config.training_args import TrainingArguments
from ..extras.types import DataCollator, Model, Processor, TorchDataset
from ..extras.types import Model, Processor, Tensor, TorchDataset
class DataCollator:
"""Default Data collator."""
def __init__(self, processor: Processor) -> None:
self.processor = processor
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
"""Collate features into a batch."""
for feature in features:
pass
# sft: messages
# dpo: chosen_messages, rejected_messages
class BaseTrainer:

View File

@@ -22,19 +22,7 @@ from omegaconf import OmegaConf
from torch.utils.data import Dataset
from ..config.data_args import DataArguments
from ..extras.types import DatasetInfo, HFDataset, Processor, Tensor
from ..plugins.data_plugins.loader import DataGetItemPlugin, DataIndexPlugin, DataLoaderPlugin
class DataCollator:
"""Default Data collator."""
def __init__(self, processor: Processor, dataset_info: dict[str, DatasetInfo]) -> None:
self.processor = processor
self.dataset_info = dataset_info
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
pass
from ..extras.types import DatasetInfo, HFDataset, Sample
class DataEngine(Dataset):
@@ -45,73 +33,78 @@ class DataEngine(Dataset):
"""Data arguments."""
self.datasets: dict[str, HFDataset] = {}
"""Dict of (dataset_name, dataset)"""
self.dataset_info: dict[str, DatasetInfo] = {}
self.dataset_infos: dict[str, DatasetInfo] = {}
"""Dict of (dataset_name, dataset_info)"""
self.streaming: bool = False
"""Whether dataset is streaming."""
self.data_index: list[tuple[str, int]] = []
"""List of (dataset_name, sample_index)"""
self.data_loader_plugin = DataLoaderPlugin(args=self.args)
"""Data loader plugin."""
self.data_index_plugin = DataIndexPlugin()
"""Data index plugin."""
self.data_getitem_plugin = DataGetItemPlugin(datasets=self.datasets, data_index=self.data_index)
"""Data getitem plugin."""
self.streaming: bool = False
"""Whether dataset is streaming."""
self.get_dataset_info()
self.load_dataset()
self.build_data_index()
def get_dataset_info(self) -> None:
"""Get dataset info."""
"""Get dataset info from data arguments."""
if self.args.dataset.endswith(".yaml") and os.path.isfile(
os.path.join(self.args.dataset_dir, self.args.dataset)
): # local file
self.dataset_info = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
self.dataset_infos = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
elif self.args.dataset.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
repo_id, filename = os.path.split(self.args.dataset)
filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
self.dataset_info = OmegaConf.load(filepath)
self.dataset_infos = OmegaConf.load(filepath)
elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)): # local file(s)
self.dataset_info = {"default": {"file_name": self.args.dataset}}
self.dataset_infos = {"default": {"file_name": self.args.dataset}}
else: # hf hub dataset, e.g. llamafactory/v1-sft-demo
self.dataset_info = {"default": {"hf_hub_url": self.args.dataset}}
self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}
def load_dataset(self) -> None:
"""Load dataset from dataset info."""
for key, value in self.dataset_info.items():
"""Load datasets according to dataset info."""
for key, value in self.dataset_infos.items():
split = value.get("split", "train")
streaming = value.get("streaming", False)
self.streaming |= streaming
if "hf_hub_url" in value:
self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
else: # data loader plugin
self.datasets[key] = self.data_loader_plugin.auto_load_data(value)
from ..plugins.data_plugins.loader import DataLoaderPlugin
self.datasets[key] = DataLoaderPlugin(args=self.args).auto_load_data(value)
def build_data_index(self) -> None:
"""Build dataset index."""
for dataset_name, dataset in self.datasets.items():
size = self.dataset_info[dataset_name].get("size")
weight = self.dataset_info[dataset_name].get("weight")
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
if self.streaming:
data_index = [(dataset_name, -1) for _ in range(1000)]
else:
data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
if size or weight: # data index plugin
data_index = self.data_index_plugin.adjust_data_index(data_index, size, weight)
from ..plugins.data_plugins.loader import DataIndexPlugin
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
self.data_index.extend(data_index)
def get_data_collator(self, processor: Processor) -> DataCollator:
"""Get data collator.
def _convert_data_sample(self, raw_sample: dict[str, Any], dataset_name: str) -> Sample:
"""Convert dataset sample.
Args:
processor (Processor): Processor.
raw_sample (dict[str, Any]): Raw dataset sample.
dataset_name (str): Dataset name.
Returns:
DataCollator: Data collator.
Sample: Dataset sample.
"""
return DataCollator(processor=processor, dataset_info=self.dataset_info)
converter = self.dataset_infos[dataset_name].get("converter")
if converter is not None:
from ..plugins.data_plugins.converter import get_converter
return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
else:
return {"_dataset_name": dataset_name, **raw_sample}
def __len__(self) -> int:
"""Get dataset length.
@@ -124,23 +117,33 @@ class DataEngine(Dataset):
else:
return len(self.data_index)
def __getitem__(self, index: Union[int, slice, list[int]]) -> Union[dict, list[dict]]:
def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
"""Get dataset item.
Args:
index (int): Dataset index.
Returns:
dict: Dataset item.
Sample: Dataset item.
"""
if self.streaming:
raise ValueError("Streaming dataset does not support index access.")
if isinstance(index, int):
dataset_name, sample_index = self.data_index[index]
return {"_dataset_name": dataset_name, **self.datasets[dataset_name][sample_index]}
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
else:
return self.data_getitem_plugin.get_data(index)
from ..plugins.data_plugins.loader import DataSelectorPlugin
selected_index = DataSelectorPlugin(data_index=self.data_index).select(index)
if isinstance(selected_index, list):
return [
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
for dataset_name, sample_index in selected_index
]
else:
dataset_name, sample_index = selected_index
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
def __iter__(self) -> Iterable:
"""Get dataset iterator.
@@ -156,7 +159,7 @@ class DataEngine(Dataset):
raise NotImplementedError()
def __aiter__(self) -> AsyncIterable:
async def __aiter__(self) -> AsyncIterable:
"""Get dataset async iterator.
Returns:
@@ -169,3 +172,11 @@ class DataEngine(Dataset):
pass
raise NotImplementedError()
if __name__ == "__main__":
from ..config.parser import get_args
data_args, *_ = get_args()
data_engine = DataEngine(data_args=data_args)
print(data_engine[0])