Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-11-04 09:52:14 +08:00

[v1] init data plugins (#9248)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

This commit is contained in:
parent 1c35db60d6
commit 9687b71d3a
@@ -13,8 +13,8 @@
 # limitations under the License.
 
 import os
-from collections.abc import AsyncIterator, Iterator
-from typing import Literal, Optional
+from collections.abc import AsyncIterable, Iterable
+from typing import Any, Union
 
 from datasets import load_dataset
 from huggingface_hub import hf_hub_download
@@ -22,84 +22,22 @@ from omegaconf import OmegaConf
 from torch.utils.data import Dataset
 
 from ..config.data_args import DataArguments
-from ..extras.types import DatasetInfo, HFDataset, Processor
+from ..extras.types import DatasetInfo, HFDataset, Processor, Tensor
+from ..plugins.data_plugins.loader import DataGetItemPlugin, DataIndexPlugin, DataLoaderPlugin
 
 
 class DataCollator:
     """Default Data collator."""
 
-    def __init__(self, processor: Processor) -> None:
+    def __init__(self, processor: Processor, dataset_info: dict[str, DatasetInfo]) -> None:
         self.processor = processor
+        self.dataset_info = dataset_info
+
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
+        pass
 
 
-class DatasetPathMixin:
-    """Path utilities."""
-
-    args: DataArguments
-    """Data arguments."""
-
-    def _abspath(self, path: str, dataset_dir: Optional[str] = None) -> str:
-        """Get absolute path of dataset.
-
-        Args:
-            path (str): Dataset path.
-            dataset_dir (Optional[str], optional): Dataset directory. Defaults to None.
-
-        Returns:
-            str: Absolute path of dataset.
-        """
-        dataset_dir = dataset_dir or self.args.dataset_dir
-        return os.path.abspath(os.path.expanduser(os.path.join(dataset_dir, path)))
-
-    def _exists(self, path: str, dataset_dir: Optional[str] = None) -> bool:
-        """Check if dataset exists.
-
-        Args:
-            path (str): Dataset path.
-            dataset_dir (Optional[str], optional): Dataset directory. Defaults to None.
-
-        Returns:
-            bool: Whether dataset exists.
-        """
-        return os.path.exists(self._abspath(path, dataset_dir))
-
-    def _isfile(self, path: str, dataset_dir: Optional[str] = None) -> bool:
-        """Check if dataset is a file.
-
-        Args:
-            path (str): Dataset path.
-            dataset_dir (Optional[str], optional): Dataset directory. Defaults to None.
-
-        Returns:
-            bool: Whether dataset is a file.
-        """
-        return os.path.isfile(self._abspath(path, dataset_dir))
-
-    def _isdir(self, path: str, dataset_dir: Optional[str] = None) -> bool:
-        """Check if dataset is a directory.
-
-        Args:
-            path (str): Dataset path.
-            dataset_dir (Optional[str], optional): Dataset directory. Defaults to None.
-
-        Returns:
-            bool: Whether dataset is a directory.
-        """
-        return os.path.isdir(self._abspath(path, dataset_dir))
-
-    def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
-        """Get dataset builder name.
-
-        Args:
-            path (str): Dataset path.
-
-        Returns:
-            Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
-        """
-        return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
-
-
-class DataEngine(Dataset, DatasetPathMixin):
+class DataEngine(Dataset):
     """Data engine."""
 
     def __init__(self, data_args: DataArguments) -> None:
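Note that DataCollator.__call__ is still a stub in this commit. A minimal sketch of what a batching implementation could look like, assuming each feature carries an "input_ids" list and right-padding with 0 (both assumptions, not part of the diff):

import torch

def collate(features: list[dict], pad_id: int = 0) -> dict[str, torch.Tensor]:
    # Pad every sequence to the length of the longest one in the batch.
    max_len = max(len(f["input_ids"]) for f in features)
    input_ids = torch.full((len(features), max_len), pad_id, dtype=torch.long)
    for i, f in enumerate(features):
        input_ids[i, : len(f["input_ids"])] = torch.tensor(f["input_ids"], dtype=torch.long)
    return {"input_ids": input_ids}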
@@ -113,19 +51,27 @@ class DataEngine(Dataset, DatasetPathMixin):
         """Whether dataset is streaming."""
         self.data_index: list[tuple[str, int]] = []
         """List of (dataset_name, sample_index)"""
+        self.data_loader_plugin = DataLoaderPlugin(args=self.args)
+        """Data loader plugin."""
+        self.data_index_plugin = DataIndexPlugin()
+        """Data index plugin."""
+        self.data_getitem_plugin = DataGetItemPlugin(datasets=self.datasets, data_index=self.data_index)
+        """Data getitem plugin."""
         self.get_dataset_info()
         self.load_dataset()
         self.build_data_index()
 
     def get_dataset_info(self) -> None:
         """Get dataset info."""
-        if self.args.dataset.endswith(".yaml") and self._isfile(self.args.dataset):  # local file
-            self.dataset_info = OmegaConf.load(self._abspath(self.args.dataset))
+        if self.args.dataset.endswith(".yaml") and os.path.isfile(
+            os.path.join(self.args.dataset_dir, self.args.dataset)
+        ):  # local file
+            self.dataset_info = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
         elif self.args.dataset.endswith(".yaml"):  # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
             repo_id, filename = os.path.split(self.args.dataset)
             filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
             self.dataset_info = OmegaConf.load(filepath)
-        elif self._exists(self.args.dataset):  # local file(s)
+        elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)):  # local file(s)
             self.dataset_info = {"default": {"file_name": self.args.dataset}}
         else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
             self.dataset_info = {"default": {"hf_hub_url": self.args.dataset}}
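The resolution order above accepts three forms of `dataset`: a local YAML file, a Hub-hosted YAML URI, or a bare path/repo id that falls back to a single "default" entry. A standalone sketch of the same logic (the function name is ours; the calls mirror the diff):

import os
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf

def resolve_dataset_info(dataset: str, dataset_dir: str):
    local = os.path.join(dataset_dir, dataset)
    if dataset.endswith(".yaml") and os.path.isfile(local):
        return OmegaConf.load(local)  # local YAML file
    elif dataset.endswith(".yaml"):  # hub URI, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
        repo_id, filename = os.path.split(dataset)
        return OmegaConf.load(hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset"))
    elif os.path.exists(local):
        return {"default": {"file_name": dataset}}  # local file or directory
    else:
        return {"default": {"hf_hub_url": dataset}}  # hub dataset id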
@@ -133,37 +79,39 @@ class DataEngine(Dataset, DatasetPathMixin):
     def load_dataset(self) -> None:
         """Load dataset from dataset info."""
         for key, value in self.dataset_info.items():
-            dataset_dir = value.get("dataset_dir", self.args.dataset_dir)
             split = value.get("split", "train")
             streaming = value.get("streaming", False)
             self.streaming |= streaming
             if "hf_hub_url" in value:
                 self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
-            elif "file_name" in value:
-                filepath = self._abspath(value["file_name"], dataset_dir)
-                if os.path.isdir(filepath):
-                    filetype = self._get_builder_name(os.listdir(filepath)[0])
-                    self.datasets[key] = load_dataset(filetype, data_dir=filepath, split=split)
-                elif os.path.isfile(filepath):
-                    filetype = self._get_builder_name(filepath)
-                    self.datasets[key] = load_dataset(filetype, data_files=filepath, split=split)
-                else:
-                    raise ValueError(f"Can not load dataset {key} from {filepath}.")
-
-                if streaming:
-                    self.datasets[key] = self.datasets[key].to_iterable_dataset()
-            else:
-                # TODO: support dataset loader plugins
-                raise ValueError(f"Dataset {key} is not supported.")
+            else:  # data loader plugin
+                self.datasets[key] = self.data_loader_plugin.auto_load_data(value)
 
     def build_data_index(self) -> None:
         """Build dataset index."""
         for dataset_name, dataset in self.datasets.items():
+            size = self.dataset_info[dataset_name].get("size")
+            weight = self.dataset_info[dataset_name].get("weight")
             if self.streaming:
-                self.data_index.append((dataset_name, -1))
+                data_index = [(dataset_name, -1) for _ in range(1000)]
             else:
-                # TODO: add sample_num, weight
-                self.data_index.extend([(dataset_name, sample_index) for sample_index in range(len(dataset))])
+                data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
+
+            if size or weight:  # data index plugin
+                data_index = self.data_index_plugin.adjust_data_index(data_index, size, weight)
+
+            self.data_index.extend(data_index)
 
+    def get_data_collator(self, processor: Processor) -> DataCollator:
+        """Get data collator.
+
+        Args:
+            processor (Processor): Processor.
+
+        Returns:
+            DataCollator: Data collator.
+        """
+        return DataCollator(processor=processor, dataset_info=self.dataset_info)
+
     def __len__(self) -> int:
         """Get dataset length.
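build_data_index now materializes a per-dataset index list before the optional size/weight pass, so DataIndexPlugin operates on a plain list of (dataset_name, sample_index) tuples. A toy illustration of the index shape (dataset names and sizes are illustrative):

datasets = {"alpaca": range(3), "code": range(2)}  # stand-ins for loaded map-style datasets

data_index: list[tuple[str, int]] = []
for name, dataset in datasets.items():
    data_index.extend((name, i) for i in range(len(dataset)))

print(data_index)
# [('alpaca', 0), ('alpaca', 1), ('alpaca', 2), ('code', 0), ('code', 1)]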
@@ -176,7 +124,7 @@ class DataEngine(Dataset, DatasetPathMixin):
         else:
             return len(self.data_index)
 
-    def __getitem__(self, index: int) -> dict:
+    def __getitem__(self, index: Union[int, slice, list[int]]) -> Union[dict, list[dict]]:
         """Get dataset item.
 
         Args:
@@ -185,21 +133,39 @@ class DataEngine(Dataset, DatasetPathMixin):
         Returns:
             dict: Dataset item.
         """
-        dataset_name, sample_index = self.data_index[index]
-        return self.datasets[dataset_name][sample_index]
+        if self.streaming:
+            raise ValueError("Streaming dataset does not support index access.")
 
-    def __iter__(self) -> Iterator:
+        if isinstance(index, int):
+            dataset_name, sample_index = self.data_index[index]
+            return {"_dataset_name": dataset_name, **self.datasets[dataset_name][sample_index]}
+        else:
+            return self.data_getitem_plugin.get_data(index)
+
+    def __iter__(self) -> Iterable:
         """Get dataset iterator.
 
         Returns:
-            Iterator: Dataset iterator.
+            Iterable: Dataset iterator.
         """
+        if self.streaming:
+            pass
+        else:
+            # TODO: add shuffle here
+            pass
+
         raise NotImplementedError()
 
-    def __aiter__(self) -> AsyncIterator:
+    def __aiter__(self) -> AsyncIterable:
         """Get dataset async iterator.
 
         Returns:
-            AsyncIterator: Dataset async iterator.
+            AsyncIterable: Dataset async iterator.
         """
+        if self.streaming:
+            pass
+        else:
+            # TODO: add shuffle here
+            pass
+
         raise NotImplementedError()
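With this change, an int index returns a single sample tagged with its source dataset, while slices and lists of ints are delegated to DataGetItemPlugin and return lists. Assuming a constructed, non-streaming engine (construction elided; the variable name is ours), access looks like:

sample = engine[0]       # {"_dataset_name": "default", ...}, a single dict
window = engine[0:2]     # list of dicts via DataGetItemPlugin
picked = engine[[1, 3]]  # arbitrary index list, same plugin path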
@@ -16,21 +16,20 @@ from typing import TYPE_CHECKING, NotRequired, TypedDict, Union
 
 
 if TYPE_CHECKING:
-    from datasets import Dataset as HFArrowDataset
-    from datasets import IterableDataset as HFIterableDataset
-    from torch.utils.data import DataLoader as TorchDataLoader
-    from torch.utils.data import Dataset as TorchArrowDataset
-    from torch.utils.data import IterableDataset as TorchIterableDataset
-    from transformers import DataCollator as HFDataCollator
-    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+    import datasets
+    import torch
+    import torch.utils.data
+    import transformers
 
-    TorchDataset = Union[TorchArrowDataset, TorchIterableDataset]
-    HFDataset = Union[HFArrowDataset, HFIterableDataset]
-    DataCollator = HFDataCollator
-    DataLoader = TorchDataLoader
-    Model = PreTrainedModel
-    Processor = Union[PreTrainedTokenizer, ProcessorMixin]
+    Tensor = torch.Tensor
+    TorchDataset = Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset]
+    HFDataset = Union[datasets.Dataset, datasets.IterableDataset]
+    DataCollator = transformers.DataCollator
+    DataLoader = torch.utils.data.DataLoader
+    Model = transformers.PreTrainedModel
+    Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
 else:
+    Tensor = None
     TorchDataset = None
     HFDataset = None
    DataCollator = None
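The rewrite swaps per-name imports for whole-module imports under TYPE_CHECKING, so heavy dependencies are only imported by type checkers; at runtime the aliases collapse to None and can only appear in (string) annotations. The same pattern in isolation:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch

    Tensor = torch.Tensor
else:
    Tensor = None  # placeholder: fine inside string annotations, unusable as a class

def scale(x: "Tensor", k: float) -> "Tensor":
    # String annotations are never evaluated at runtime, so Tensor = None is safe here.
    return x * k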
@@ -45,14 +44,14 @@ class DatasetInfo(TypedDict, total=False):
     file_name: NotRequired[str]
     """Local file path."""
     dataset_dir: NotRequired[str]
-    """Dataset directory."""
+    """Dataset directory, default to args.dataset_dir."""
     split: NotRequired[str]
-    """Dataset split."""
+    """Dataset split, default to "train"."""
     converter: NotRequired[str]
-    """Dataset converter."""
-    num_samples: NotRequired[int]
-    """Number of samples."""
+    """Dataset converter, default to None."""
+    size: NotRequired[int]
+    """Number of samples, default to all samples."""
     weight: NotRequired[float]
-    """Dataset weight."""
+    """Dataset weight, default to 1.0."""
     streaming: NotRequired[bool]
-    """Is streaming dataset."""
+    """Is streaming dataset, default to False."""
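Put together, a dataset_info mapping that exercises the renamed size field could look like this (the file name and values are illustrative):

dataset_info = {
    "default": {
        "file_name": "sft_demo.json",  # hypothetical local file
        "split": "train",
        "size": 1000,      # number of samples to draw (semantics left to DataIndexPlugin)
        "weight": 1.0,
        "streaming": False,
    }
}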
@@ -0,0 +1,102 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from dataclasses import dataclass
+from typing import Literal, Optional, Union
+
+from datasets import load_dataset
+
+from ...config.data_args import DataArguments
+from ...extras.types import DatasetInfo, HFDataset
+
+
+@dataclass
+class DataLoaderPlugin:
+    args: DataArguments
+
+    def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
+        """Get dataset builder name.
+
+        Args:
+            path (str): Dataset path.
+
+        Returns:
+            Literal["arrow", "csv", "json", "parquet", "text"]: Dataset builder name.
+        """
+        return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")
+
+    def auto_load_data(self, dataset_info: DatasetInfo) -> HFDataset:
+        dataset_dir = dataset_info.get("dataset_dir", self.args.dataset_dir)
+        split = dataset_info.get("split", "train")
+        streaming = dataset_info.get("streaming", False)
+        if "file_name" in dataset_info:
+            filepath = os.path.join(dataset_dir, dataset_info["file_name"])
+            return self.load_data_from_file(filepath, split, streaming)
+        else:
+            raise NotImplementedError()
+
+    def load_data_from_file(self, filepath: str, split: str, streaming: bool) -> HFDataset:
+        if os.path.isdir(filepath):
+            filetype = self._get_builder_name(os.listdir(filepath)[0])
+            dataset = load_dataset(filetype, data_dir=filepath, split=split)
+        elif os.path.isfile(filepath):
+            filetype = self._get_builder_name(filepath)
+            dataset = load_dataset(filetype, data_files=filepath, split=split)
+        else:
+            raise ValueError(f"Can not load dataset from {filepath}.")
+
+        if streaming:
+            dataset = dataset.to_iterable_dataset()
+
+        return dataset
+
+
+@dataclass
+class DataIndexPlugin:
+    def adjust_data_index(
+        self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float]
+    ) -> list[tuple[str, int]]:
+        if size is not None:
+            data_index = self.adjust_by_size(data_index, size)
+
+        if weight is not None:
+            data_index = self.adjust_by_weight(data_index, weight)
+
+        return data_index
+
+    def adjust_by_size(self, data_index: list[tuple[str, int]], size: int) -> list[tuple[str, int]]:
+        raise NotImplementedError()
+
+    def adjust_by_weight(self, data_index: list[tuple[str, int]], weight: float) -> list[tuple[str, int]]:
+        raise NotImplementedError()
+
+
+@dataclass
+class DataGetItemPlugin:
+    datasets: dict[str, HFDataset]
+    data_index: list[tuple[str, int]]
+
+    def _get_by_index(self, index: int) -> dict:
+        dataset_name, sample_index = self.data_index[index]
+        return {"_dataset_name": dataset_name, **self.datasets[dataset_name][sample_index]}
+
+    def get_data(self, index: Union[slice, list[int]]) -> list[dict]:
+        if isinstance(index, slice):
+            return [self._get_by_index(i) for i in range(*index.indices(len(self.data_index)))]
+        elif isinstance(index, list):
+            return [self._get_by_index(i) for i in index]
+        else:
+            raise ValueError(f"Invalid index type {type(index)}.")
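_get_builder_name (carried over from the removed DatasetPathMixin) maps a file extension to a datasets builder; only jsonl and txt need renaming. A quick check of the mapping:

import os

def builder_name(path: str) -> str:
    # Same expression as _get_builder_name above.
    return os.path.splitext(path)[-1][1:].replace("jsonl", "json").replace("txt", "text")

assert builder_name("train.jsonl") == "json"
assert builder_name("corpus.txt") == "text"
assert builder_name("shard-0001.parquet") == "parquet"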
@@ -33,7 +33,7 @@ def test_map_dataset(num_samples: int):
     indexes = random.choices(range(len(data_engine)), k=num_samples)
     for index in indexes:
         print(data_engine[index])
-        assert data_engine[index] == original_data[index]
+        assert data_engine[index] == {"_dataset_name": "default", **original_data[index]}
 
 
 if __name__ == "__main__":
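For reference, DataGetItemPlugin.get_data normalizes slices with Python's built-in slice.indices, which clamps a slice to the index length before iterating:

s = slice(1, 10, 2)
print(s.indices(5))                # (1, 5, 2): start, stop, step clamped to length 5
print(list(range(*s.indices(5))))  # [1, 3]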