From 9687b71d3ae2858c85f359741abe62824faf5de2 Mon Sep 17 00:00:00 2001 From: Yaowei Zheng Date: Thu, 9 Oct 2025 22:36:48 +0800 Subject: [PATCH] [v1] init data plugins (#9248) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- src/llamafactory/v1/core/data_engine.py | 170 +++++++----------- src/llamafactory/v1/extras/types.py | 39 ++-- .../v1/plugins/data_plugins/filter.py | 0 .../v1/plugins/data_plugins/loader.py | 102 +++++++++++ tests_v1/core/test_data_engine.py | 2 +- 5 files changed, 190 insertions(+), 123 deletions(-) delete mode 100644 src/llamafactory/v1/plugins/data_plugins/filter.py diff --git a/src/llamafactory/v1/core/data_engine.py b/src/llamafactory/v1/core/data_engine.py index 465b3352..3d0581b3 100644 --- a/src/llamafactory/v1/core/data_engine.py +++ b/src/llamafactory/v1/core/data_engine.py @@ -13,8 +13,8 @@ # limitations under the License. import os -from collections.abc import AsyncIterator, Iterator -from typing import Literal, Optional +from collections.abc import AsyncIterable, Iterable +from typing import Any, Union from datasets import load_dataset from huggingface_hub import hf_hub_download @@ -22,84 +22,22 @@ from omegaconf import OmegaConf from torch.utils.data import Dataset from ..config.data_args import DataArguments -from ..extras.types import DatasetInfo, HFDataset, Processor +from ..extras.types import DatasetInfo, HFDataset, Processor, Tensor +from ..plugins.data_plugins.loader import DataGetItemPlugin, DataIndexPlugin, DataLoaderPlugin class DataCollator: """Default Data collator.""" - def __init__(self, processor: Processor) -> None: + def __init__(self, processor: Processor, dataset_info: dict[str, DatasetInfo]) -> None: self.processor = processor + self.dataset_info = dataset_info + + def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]: + pass -class DatasetPathMixin: - """Path utilities.""" - - args: DataArguments - """Data arguments.""" - - def _abspath(self, path: str, dataset_dir: Optional[str] = None) -> str: - """Get absolute path of dataset. - - Args: - path (str): Dataset path. - dataset_dir (Optional[str], optional): Dataset directory. Defaults to None. - - Returns: - str: Absolute path of dataset. - """ - dataset_dir = dataset_dir or self.args.dataset_dir - return os.path.abspath(os.path.expanduser(os.path.join(dataset_dir, path))) - - def _exists(self, path: str, dataset_dir: Optional[str] = None) -> bool: - """Check if dataset exists. - - Args: - path (str): Dataset path. - dataset_dir (Optional[str], optional): Dataset directory. Defaults to None. - - Returns: - bool: Whether dataset exists. - """ - return os.path.exists(self._abspath(path, dataset_dir)) - - def _isfile(self, path: str, dataset_dir: Optional[str] = None) -> bool: - """Check if dataset is a file. - - Args: - path (str): Dataset path. - dataset_dir (Optional[str], optional): Dataset directory. Defaults to None. - - Returns: - bool: Whether dataset is a file. - """ - return os.path.isfile(self._abspath(path, dataset_dir)) - - def _isdir(self, path: str, dataset_dir: Optional[str] = None) -> bool: - """Check if dataset is a directory. - - Args: - path (str): Dataset path. - dataset_dir (Optional[str], optional): Dataset directory. Defaults to None. - - Returns: - bool: Whether dataset is a directory. - """ - return os.path.isdir(self._abspath(path, dataset_dir)) - - def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]: - """Get dataset builder name. - - Args: - path (str): Dataset path. - - Returns: - Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name. - """ - return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text") - - -class DataEngine(Dataset, DatasetPathMixin): +class DataEngine(Dataset): """Data engine.""" def __init__(self, data_args: DataArguments) -> None: @@ -113,19 +51,27 @@ class DataEngine(Dataset, DatasetPathMixin): """Whether dataset is streaming.""" self.data_index: list[tuple[str, int]] = [] """List of (dataset_name, sample_index)""" + self.data_loader_plugin = DataLoaderPlugin(args=self.args) + """Data loader plugin.""" + self.data_index_plugin = DataIndexPlugin() + """Data index plugin.""" + self.data_getitem_plugin = DataGetItemPlugin(datasets=self.datasets, data_index=self.data_index) + """Data getitem plugin.""" self.get_dataset_info() self.load_dataset() self.build_data_index() def get_dataset_info(self) -> None: """Get dataset info.""" - if self.args.dataset.endswith(".yaml") and self._isfile(self.args.dataset): # local file - self.dataset_info = OmegaConf.load(self._abspath(self.args.dataset)) + if self.args.dataset.endswith(".yaml") and os.path.isfile( + os.path.join(self.args.dataset_dir, self.args.dataset) + ): # local file + self.dataset_info = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset)) elif self.args.dataset.endswith(".yaml"): # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml repo_id, filename = os.path.split(self.args.dataset) filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset") self.dataset_info = OmegaConf.load(filepath) - elif self._exists(self.args.dataset): # local file(s) + elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)): # local file(s) self.dataset_info = {"default": {"file_name": self.args.dataset}} else: # hf hub dataset, e.g. llamafactory/v1-sft-demo self.dataset_info = {"default": {"hf_hub_url": self.args.dataset}} @@ -133,37 +79,39 @@ class DataEngine(Dataset, DatasetPathMixin): def load_dataset(self) -> None: """Load dataset from dataset info.""" for key, value in self.dataset_info.items(): - dataset_dir = value.get("dataset_dir", self.args.dataset_dir) split = value.get("split", "train") streaming = value.get("streaming", False) self.streaming |= streaming if "hf_hub_url" in value: self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming) - elif "file_name" in value: - filepath = self._abspath(value["file_name"], dataset_dir) - if os.path.isdir(filepath): - filetype = self._get_builder_name(os.listdir(filepath)[0]) - self.datasets[key] = load_dataset(filetype, data_dir=filepath, split=split) - elif os.path.isfile(filepath): - filetype = self._get_builder_name(filepath) - self.datasets[key] = load_dataset(filetype, data_files=filepath, split=split) - else: - raise ValueError(f"Can not load dataset {key} from {filepath}.") - - if streaming: - self.datasets[key] = self.datasets[key].to_iterable_dataset() - else: - # TODO: support dataset loader plugins - raise ValueError(f"Dataset {key} is not supported.") + else: # data loader plugin + self.datasets[key] = self.data_loader_plugin.auto_load_data(value) def build_data_index(self) -> None: """Build dataset index.""" for dataset_name, dataset in self.datasets.items(): + size = self.dataset_info[dataset_name].get("size") + weight = self.dataset_info[dataset_name].get("weight") if self.streaming: - self.data_index.append((dataset_name, -1)) + data_index = [(dataset_name, -1) for _ in range(1000)] else: - # TODO: add sample_num, weight - self.data_index.extend([(dataset_name, sample_index) for sample_index in range(len(dataset))]) + data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))] + + if size or weight: # data index plugin + data_index = self.data_index_plugin.adjust_data_index(data_index, size, weight) + + self.data_index.extend(data_index) + + def get_data_collator(self, processor: Processor) -> DataCollator: + """Get data collator. + + Args: + processor (Processor): Processor. + + Returns: + DataCollator: Data collator. + """ + return DataCollator(processor=processor, dataset_info=self.dataset_info) def __len__(self) -> int: """Get dataset length. @@ -176,7 +124,7 @@ class DataEngine(Dataset, DatasetPathMixin): else: return len(self.data_index) - def __getitem__(self, index: int) -> dict: + def __getitem__(self, index: Union[int, slice, list[int]]) -> Union[dict, list[dict]]: """Get dataset item. Args: @@ -185,21 +133,39 @@ class DataEngine(Dataset, DatasetPathMixin): Returns: dict: Dataset item. """ - dataset_name, sample_index = self.data_index[index] - return self.datasets[dataset_name][sample_index] + if self.streaming: + raise ValueError("Streaming dataset does not support index access.") - def __iter__(self) -> Iterator: + if isinstance(index, int): + dataset_name, sample_index = self.data_index[index] + return {"_dataset_name": dataset_name, **self.datasets[dataset_name][sample_index]} + else: + return self.data_getitem_plugin.get_data(index) + + def __iter__(self) -> Iterable: """Get dataset iterator. Returns: - Iterator: Dataset iterator. + Iterable: Dataset iterator. """ + if self.streaming: + pass + else: + # TODO: add shuffle here + pass + raise NotImplementedError() - def __aiter__(self) -> AsyncIterator: + def __aiter__(self) -> AsyncIterable: """Get dataset async iterator. Returns: - AsyncIterator: Dataset async iterator. + AsyncIterable: Dataset async iterator. """ + if self.streaming: + pass + else: + # TODO: add shuffle here + pass + raise NotImplementedError() diff --git a/src/llamafactory/v1/extras/types.py b/src/llamafactory/v1/extras/types.py index 2cedd63c..b1b4a097 100644 --- a/src/llamafactory/v1/extras/types.py +++ b/src/llamafactory/v1/extras/types.py @@ -16,21 +16,20 @@ from typing import TYPE_CHECKING, NotRequired, TypedDict, Union if TYPE_CHECKING: - from datasets import Dataset as HFArrowDataset - from datasets import IterableDataset as HFIterableDataset - from torch.utils.data import DataLoader as TorchDataLoader - from torch.utils.data import Dataset as TorchArrowDataset - from torch.utils.data import IterableDataset as TorchIterableDataset - from transformers import DataCollator as HFDataCollator - from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin + import datasets + import torch + import torch.utils.data + import transformers - TorchDataset = Union[TorchArrowDataset, TorchIterableDataset] - HFDataset = Union[HFArrowDataset, HFIterableDataset] - DataCollator = HFDataCollator - DataLoader = TorchDataLoader - Model = PreTrainedModel - Processor = Union[PreTrainedTokenizer, ProcessorMixin] + Tensor = torch.Tensor + TorchDataset = Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset] + HFDataset = Union[datasets.Dataset, datasets.IterableDataset] + DataCollator = transformers.DataCollator + DataLoader = torch.utils.data.DataLoader + Model = transformers.PreTrainedModel + Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin] else: + Tensor = None TorchDataset = None HFDataset = None DataCollator = None @@ -45,14 +44,14 @@ class DatasetInfo(TypedDict, total=False): file_name: NotRequired[str] """Local file path.""" dataset_dir: NotRequired[str] - """Dataset directory.""" + """Dataset directory, default to args.dataset_dir.""" split: NotRequired[str] - """Dataset split.""" + """Dataset split, default to "train".""" converter: NotRequired[str] - """Dataset converter.""" - num_samples: NotRequired[int] - """Number of samples.""" + """Dataset converter, default to None.""" + size: NotRequired[int] + """Number of samples, default to all samples.""" weight: NotRequired[float] - """Dataset weight.""" + """Dataset weight, default to 1.0.""" streaming: NotRequired[bool] - """Is streaming dataset.""" + """Is streaming dataset, default to False.""" diff --git a/src/llamafactory/v1/plugins/data_plugins/filter.py b/src/llamafactory/v1/plugins/data_plugins/filter.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/llamafactory/v1/plugins/data_plugins/loader.py b/src/llamafactory/v1/plugins/data_plugins/loader.py index e69de29b..626cc399 100644 --- a/src/llamafactory/v1/plugins/data_plugins/loader.py +++ b/src/llamafactory/v1/plugins/data_plugins/loader.py @@ -0,0 +1,102 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +from dataclasses import dataclass +from typing import Literal, Optional, Union + +from datasets import load_dataset + +from ...config.data_args import DataArguments +from ...extras.types import DatasetInfo, HFDataset + + +@dataclass +class DataLoaderPlugin: + args: DataArguments + + def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]: + """Get dataset builder name. + + Args: + path (str): Dataset path. + + Returns: + Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name. + """ + return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text") + + def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset: + dataset_dir = dataset_info.get("dataset_dir", self.args.dataset_dir) + split = dataset_info.get("split", "train") + streaming = dataset_info.get("streaming", False) + if "file_name" in dataset_info: + filepath = os.path.join(dataset_dir, dataset_info["file_name"]) + return self.load_data_from_file(filepath, split, streaming) + else: + raise NotImplementedError() + + def load_data_from_file(self, filepath: str, split: str, streaming: bool) -> HFDataset: + if os.path.isdir(filepath): + filetype = self._get_builder_name(os.listdir(filepath)[0]) + dataset = load_dataset(filetype, data_dir=filepath, split=split) + elif os.path.isfile(filepath): + filetype = self._get_builder_name(filepath) + dataset = load_dataset(filetype, data_files=filepath, split=split) + else: + raise ValueError(f"Can not load dataset from {filepath}.") + + if streaming: + dataset = dataset.to_iterable_dataset() + + return dataset + + +@dataclass +class DataIndexPlugin: + def adjust_data_index( + self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float] + ) -> list[tuple[str, int]]: + if size is not None: + data_index = self.adjust_by_size(data_index, size) + + if weight is not None: + data_index = self.adjust_by_weight(data_index, weight) + + return data_index + + def adjust_by_size(self, data_index: list[tuple[str, int]], size: int) -> list[tuple[str, int]]: + raise NotImplementedError() + + def adjust_by_weight(self, data_index: list[tuple[str, int]], weight: float) -> list[tuple[str, int]]: + raise NotImplementedError() + + +@dataclass +class DataGetItemPlugin: + datasets: dict[str, HFDataset] + data_index: list[tuple[str, int]] + + def _get_by_index(self, index: int) -> dict: + dataset_name, sample_index = self.data_index[index] + return {"_dataset_name": dataset_name, **self.datasets[dataset_name][sample_index]} + + def get_data(self, index: Union[slice, list[int]]) -> list[dict]: + if isinstance(index, slice): + return [self._get_by_index(i) for i in range(*index.indices(len(self.data_index)))] + elif isinstance(index, list): + return [self._get_by_index(i) for i in index] + else: + raise ValueError(f"Invalid index type {type(index)}.") diff --git a/tests_v1/core/test_data_engine.py b/tests_v1/core/test_data_engine.py index ffd8f80e..68c830b7 100644 --- a/tests_v1/core/test_data_engine.py +++ b/tests_v1/core/test_data_engine.py @@ -33,7 +33,7 @@ def test_map_dataset(num_samples: int): indexes = random.choices(range(len(data_engine)), k=num_samples) for index in indexes: print(data_engine[index]) - assert data_engine[index] == original_data[index] + assert data_engine[index] == {"_dataset_name": "default", **original_data[index]} if __name__ == "__main__":