[v1] init data plugins (#9248)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Authored by Yaowei Zheng on 2025-10-09 22:36:48 +08:00; committed by GitHub
parent 1c35db60d6
commit 9687b71d3a
5 changed files with 190 additions and 123 deletions

View File

@@ -13,8 +13,8 @@
 # limitations under the License.
 
 import os
-from collections.abc import AsyncIterator, Iterator
-from typing import Literal, Optional
+from collections.abc import AsyncIterable, Iterable
+from typing import Any, Union
 
 from datasets import load_dataset
 from huggingface_hub import hf_hub_download
@@ -22,84 +22,22 @@ from omegaconf import OmegaConf
 from torch.utils.data import Dataset
 
 from ..config.data_args import DataArguments
-from ..extras.types import DatasetInfo, HFDataset, Processor
+from ..extras.types import DatasetInfo, HFDataset, Processor, Tensor
+from ..plugins.data_plugins.loader import DataGetItemPlugin, DataIndexPlugin, DataLoaderPlugin
 
 
 class DataCollator:
     """Default Data collator."""
 
-    def __init__(self, processor: Processor) -> None:
+    def __init__(self, processor: Processor, dataset_info: dict[str, DatasetInfo]) -> None:
         self.processor = processor
+        self.dataset_info = dataset_info
+
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
+        pass
 
 
-class DatasetPathMixin:
-    """Path utilities."""
-
-    args: DataArguments
-    """Data arguments."""
-
-    def _abspath(self, path: str, dataset_dir: Optional[str] = None) -> str:
-        """Get absolute path of dataset.
-
-        Args:
-            path (str): Dataset path.
-            dataset_dir (Optional[str], optional): Dataset directory. Defaults to None.
-
-        Returns:
-            str: Absolute path of dataset.
-        """
-        dataset_dir = dataset_dir or self.args.dataset_dir
-        return os.path.abspath(os.path.expanduser(os.path.join(dataset_dir, path)))
-
-    def _exists(self, path: str, dataset_dir: Optional[str] = None) -> bool:
-        """Check if dataset exists.
-
-        Args:
-            path (str): Dataset path.
-            dataset_dir (Optional[str], optional): Dataset directory. Defaults to None.
-
-        Returns:
-            bool: Whether dataset exists.
-        """
-        return os.path.exists(self._abspath(path, dataset_dir))
-
-    def _isfile(self, path: str, dataset_dir: Optional[str] = None) -> bool:
-        """Check if dataset is a file.
-
-        Args:
-            path (str): Dataset path.
-            dataset_dir (Optional[str], optional): Dataset directory. Defaults to None.
-
-        Returns:
-            bool: Whether dataset is a file.
-        """
-        return os.path.isfile(self._abspath(path, dataset_dir))
-
-    def _isdir(self, path: str, dataset_dir: Optional[str] = None) -> bool:
-        """Check if dataset is a directory.
-
-        Args:
-            path (str): Dataset path.
-            dataset_dir (Optional[str], optional): Dataset directory. Defaults to None.
-
-        Returns:
-            bool: Whether dataset is a directory.
-        """
-        return os.path.isdir(self._abspath(path, dataset_dir))
-
-    def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
-        """Get dataset builder name.
-
-        Args:
-            path (str): Dataset path.
-
-        Returns:
-            Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
-        """
-        return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
-
-
-class DataEngine(Dataset, DatasetPathMixin):
+class DataEngine(Dataset):
     """Data engine."""
 
     def __init__(self, data_args: DataArguments) -> None:
@@ -113,19 +51,27 @@ class DataEngine(Dataset, DatasetPathMixin):
         """Whether dataset is streaming."""
         self.data_index: list[tuple[str, int]] = []
         """List of (dataset_name, sample_index)"""
+        self.data_loader_plugin = DataLoaderPlugin(args=self.args)
+        """Data loader plugin."""
+        self.data_index_plugin = DataIndexPlugin()
+        """Data index plugin."""
+        self.data_getitem_plugin = DataGetItemPlugin(datasets=self.datasets, data_index=self.data_index)
+        """Data getitem plugin."""
         self.get_dataset_info()
         self.load_dataset()
         self.build_data_index()
 
     def get_dataset_info(self) -> None:
         """Get dataset info."""
-        if self.args.dataset.endswith(".yaml") and self._isfile(self.args.dataset):  # local file
-            self.dataset_info = OmegaConf.load(self._abspath(self.args.dataset))
+        if self.args.dataset.endswith(".yaml") and os.path.isfile(
+            os.path.join(self.args.dataset_dir, self.args.dataset)
+        ):  # local file
+            self.dataset_info = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
         elif self.args.dataset.endswith(".yaml"):  # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
             repo_id, filename = os.path.split(self.args.dataset)
             filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
             self.dataset_info = OmegaConf.load(filepath)
-        elif self._exists(self.args.dataset):  # local file(s)
+        elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)):  # local file(s)
             self.dataset_info = {"default": {"file_name": self.args.dataset}}
         else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
             self.dataset_info = {"default": {"hf_hub_url": self.args.dataset}}
@@ -133,37 +79,39 @@ class DataEngine(Dataset, DatasetPathMixin):
     def load_dataset(self) -> None:
        """Load dataset from dataset info."""
         for key, value in self.dataset_info.items():
-            dataset_dir = value.get("dataset_dir", self.args.dataset_dir)
             split = value.get("split", "train")
             streaming = value.get("streaming", False)
             self.streaming |= streaming
             if "hf_hub_url" in value:
                 self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
-            elif "file_name" in value:
-                filepath = self._abspath(value["file_name"], dataset_dir)
-                if os.path.isdir(filepath):
-                    filetype = self._get_builder_name(os.listdir(filepath)[0])
-                    self.datasets[key] = load_dataset(filetype, data_dir=filepath, split=split)
-                elif os.path.isfile(filepath):
-                    filetype = self._get_builder_name(filepath)
-                    self.datasets[key] = load_dataset(filetype, data_files=filepath, split=split)
-                else:
-                    raise ValueError(f"Can not load dataset {key} from {filepath}.")
-
-                if streaming:
-                    self.datasets[key] = self.datasets[key].to_iterable_dataset()
-            else:
-                # TODO: support dataset loader plugins
-                raise ValueError(f"Dataset {key} is not supported.")
+            else:  # data loader plugin
+                self.datasets[key] = self.data_loader_plugin.auto_load_data(value)
 
     def build_data_index(self) -> None:
         """Build dataset index."""
         for dataset_name, dataset in self.datasets.items():
+            size = self.dataset_info[dataset_name].get("size")
+            weight = self.dataset_info[dataset_name].get("weight")
             if self.streaming:
-                self.data_index.append((dataset_name, -1))
+                data_index = [(dataset_name, -1) for _ in range(1000)]
             else:
-                # TODO: add sample_num, weight
-                self.data_index.extend([(dataset_name, sample_index) for sample_index in range(len(dataset))])
+                data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
+
+            if size or weight:  # data index plugin
+                data_index = self.data_index_plugin.adjust_data_index(data_index, size, weight)
+
+            self.data_index.extend(data_index)
+
+    def get_data_collator(self, processor: Processor) -> DataCollator:
+        """Get data collator.
+
+        Args:
+            processor (Processor): Processor.
+
+        Returns:
+            DataCollator: Data collator.
+        """
+        return DataCollator(processor=processor, dataset_info=self.dataset_info)
 
     def __len__(self) -> int:
         """Get dataset length.
@@ -176,7 +124,7 @@ class DataEngine(Dataset, DatasetPathMixin):
         else:
             return len(self.data_index)
 
-    def __getitem__(self, index: int) -> dict:
+    def __getitem__(self, index: Union[int, slice, list[int]]) -> Union[dict, list[dict]]:
         """Get dataset item.
 
         Args:
@@ -185,21 +133,39 @@ class DataEngine(Dataset, DatasetPathMixin):
         Returns:
             dict: Dataset item.
         """
-        dataset_name, sample_index = self.data_index[index]
-        return self.datasets[dataset_name][sample_index]
+        if self.streaming:
+            raise ValueError("Streaming dataset does not support index access.")
+
+        if isinstance(index, int):
+            dataset_name, sample_index = self.data_index[index]
+            return {"_dataset_name": dataset_name, **self.datasets[dataset_name][sample_index]}
+        else:
+            return self.data_getitem_plugin.get_data(index)
 
-    def __iter__(self) -> Iterator:
+    def __iter__(self) -> Iterable:
         """Get dataset iterator.
 
         Returns:
-            Iterator: Dataset iterator.
+            Iterable: Dataset iterator.
         """
+        if self.streaming:
+            pass
+        else:
+            # TODO: add shuffle here
+            pass
+
         raise NotImplementedError()
 
-    def __aiter__(self) -> AsyncIterator:
+    def __aiter__(self) -> AsyncIterable:
         """Get dataset async iterator.
 
         Returns:
-            AsyncIterator: Dataset async iterator.
+            AsyncIterable: Dataset async iterator.
         """
+        if self.streaming:
+            pass
+        else:
+            # TODO: add shuffle here
+            pass
+
         raise NotImplementedError()
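For orientation, a minimal usage sketch of the reworked DataEngine follows. It is not taken from this commit: the import paths, the DataArguments constructor call, and the file name are assumptions made for illustration.

# Sketch only: module paths and the DataArguments signature are assumed, not from this diff.
from llamafactory.v1.config.data_args import DataArguments  # assumed location of DataArguments
from llamafactory.v1.core.data_engine import DataEngine     # assumed location of DataEngine

data_args = DataArguments(dataset="sft_demo.json", dataset_dir="data")  # hypothetical local dataset
engine = DataEngine(data_args)

print(len(engine))         # size of the merged (dataset_name, sample_index) index
sample = engine[0]         # a dict that now carries an extra "_dataset_name" key
batch = engine[[0, 1, 2]]  # list/slice access is delegated to DataGetItemPlugin.get_data
# engine.get_data_collator(processor) returns a DataCollator bound to the dataset_info mapping.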

View File

@@ -16,21 +16,20 @@ from typing import TYPE_CHECKING, NotRequired, TypedDict, Union
 
 if TYPE_CHECKING:
-    from datasets import Dataset as HFArrowDataset
-    from datasets import IterableDataset as HFIterableDataset
-    from torch.utils.data import DataLoader as TorchDataLoader
-    from torch.utils.data import Dataset as TorchArrowDataset
-    from torch.utils.data import IterableDataset as TorchIterableDataset
-    from transformers import DataCollator as HFDataCollator
-    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+    import datasets
+    import torch
+    import torch.utils.data
+    import transformers
 
-    TorchDataset = Union[TorchArrowDataset, TorchIterableDataset]
-    HFDataset = Union[HFArrowDataset, HFIterableDataset]
-    DataCollator = HFDataCollator
-    DataLoader = TorchDataLoader
-    Model = PreTrainedModel
-    Processor = Union[PreTrainedTokenizer, ProcessorMixin]
+    Tensor = torch.Tensor
+    TorchDataset = Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]
+    HFDataset = Union[datasets.Dataset, datasets.IterableDataset]
+    DataCollator = transformers.DataCollator
+    DataLoader = torch.utils.data.DataLoader
+    Model = transformers.PreTrainedModel
+    Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
 else:
+    Tensor = None
     TorchDataset = None
     HFDataset = None
     DataCollator = None
@@ -45,14 +44,14 @@ class DatasetInfo(TypedDict, total=False):
     file_name: NotRequired[str]
     """Local file path."""
     dataset_dir: NotRequired[str]
-    """Dataset directory."""
+    """Dataset directory, default to args.dataset_dir."""
     split: NotRequired[str]
-    """Dataset split."""
+    """Dataset split, default to "train"."""
     converter: NotRequired[str]
-    """Dataset converter."""
+    """Dataset converter, default to None."""
-    num_samples: NotRequired[int]
-    """Number of samples."""
+    size: NotRequired[int]
+    """Number of samples, default to all samples."""
     weight: NotRequired[float]
-    """Dataset weight."""
+    """Dataset weight, default to 1.0."""
     streaming: NotRequired[bool]
-    """Is streaming dataset."""
+    """Is streaming dataset, default to False."""

View File

@@ -0,0 +1,102 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from dataclasses import dataclass
+from typing import Literal, Optional, Union
+
+from datasets import load_dataset
+
+from ...config.data_args import DataArguments
+from ...extras.types import DatasetInfo, HFDataset
+
+
+@dataclass
+class DataLoaderPlugin:
+    args: DataArguments
+
+    def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
+        """Get dataset builder name.
+
+        Args:
+            path (str): Dataset path.
+
+        Returns:
+            Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
+        """
+        return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
+
+    def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
+        dataset_dir = dataset_info.get("dataset_dir", self.args.dataset_dir)
+        split = dataset_info.get("split", "train")
+        streaming = dataset_info.get("streaming", False)
+        if "file_name" in dataset_info:
+            filepath = os.path.join(dataset_dir, dataset_info["file_name"])
+            return self.load_data_from_file(filepath, split, streaming)
+        else:
+            raise NotImplementedError()
+
+    def load_data_from_file(self, filepath: str, split: str, streaming: bool) -> HFDataset:
+        if os.path.isdir(filepath):
+            filetype = self._get_builder_name(os.listdir(filepath)[0])
+            dataset = load_dataset(filetype, data_dir=filepath, split=split)
+        elif os.path.isfile(filepath):
+            filetype = self._get_builder_name(filepath)
+            dataset = load_dataset(filetype, data_files=filepath, split=split)
+        else:
+            raise ValueError(f"Can not load dataset from {filepath}.")
+
+        if streaming:
+            dataset = dataset.to_iterable_dataset()
+
+        return dataset
+
+
+@dataclass
+class DataIndexPlugin:
+    def adjust_data_index(
+        self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float]
+    ) -> list[tuple[str, int]]:
+        if size is not None:
+            data_index = self.adjust_by_size(data_index, size)
+
+        if weight is not None:
+            data_index = self.adjust_by_weight(data_index, weight)
+
+        return data_index
+
+    def adjust_by_size(self, data_index: list[tuple[str, int]], size: int) -> list[tuple[str, int]]:
+        raise NotImplementedError()
+
+    def adjust_by_weight(self, data_index: list[tuple[str, int]], weight: float) -> list[tuple[str, int]]:
+        raise NotImplementedError()
+
+
+@dataclass
+class DataGetItemPlugin:
+    datasets: dict[str, HFDataset]
+    data_index: list[tuple[str, int]]
+
+    def _get_by_index(self, index: int) -> dict:
+        dataset_name, sample_index = self.data_index[index]
+        return {"_dataset_name": dataset_name, **self.datasets[dataset_name][sample_index]}
+
+    def get_data(self, index: Union[slice, list[int]]) -> list[dict]:
+        if isinstance(index, slice):
+            return [self._get_by_index(i) for i in range(*index.indices(len(self.data_index)))]
+        elif isinstance(index, list):
+            return [self._get_by_index(i) for i in index]
+        else:
+            raise ValueError(f"Invalid index type {type(index)}.")

View File

@@ -33,7 +33,7 @@ def test_map_dataset(num_samples: int):
     indexes = random.choices(range(len(data_engine)), k=num_samples)
     for index in indexes:
         print(data_engine[index])
-        assert data_engine[index] == original_data[index]
+        assert data_engine[index] == {"_dataset_name": "default", **original_data[index]}
 
 
 if __name__ == "__main__":