Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-11-04 09:52:14 +08:00

Compare commits

3 Commits

48974783da ... 3dbca4b533
| Author | SHA1 | Date |
|---|---|---|
|  | 3dbca4b533 |  |
|  | 9d1acbc191 |  |
|  | 52e46e162e |  |
.github/workflows/tests.yml (vendored): 3 changes
@@ -43,6 +43,9 @@ jobs:
          - python: "3.9"
            os: "ubuntu-latest"
            transformers: "4.53.0"
        exclude:  # exclude python 3.9 on macos
          - python: "3.9"
            os: "macos-latest"

    runs-on: ${{ matrix.os }}
data/reason_tool_use_demo_50.jsonl: 50 lines (new file)

File diff suppressed because one or more lines are too long
@@ -12,8 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from ..config.training_args import TrainingArguments
from ..extras.types import DataCollator, Model, Processor, TorchDataset
from ..extras.types import Model, Processor, Tensor, TorchDataset


class DataCollator:
    """Default Data collator."""

    def __init__(self, processor: Processor) -> None:
        self.processor = processor

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
        """Collate features into a batch."""
        for feature in features:
            pass

        # sft: messages
        # dpo: chosen_messages, rejected_messages


class BaseTrainer:
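The relocated DataCollator is still a stub: __call__ loops over features without building a batch yet, and the trailing comments note that SFT features carry messages while DPO features carry chosen_messages and rejected_messages. As an illustration only (the conversation text below is invented, not part of this change), the feature dicts it would receive look roughly like this:

# Illustration only: feature dicts as implied by the comments above and the Sample
# types added later in this change; the conversation content is made up.
sft_feature = {
    "messages": [
        {"role": "user", "content": [{"type": "text", "value": "Hi"}], "loss_weight": 0.0},
        {"role": "assistant", "content": [{"type": "text", "value": "Hello!"}], "loss_weight": 1.0},
    ]
}
dpo_feature = {
    "chosen_messages": sft_feature["messages"],
    "rejected_messages": [
        {"role": "user", "content": [{"type": "text", "value": "Hi"}], "loss_weight": 0.0},
        {"role": "assistant", "content": [{"type": "text", "value": "Go away."}], "loss_weight": 1.0},
    ],
}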
@@ -22,19 +22,7 @@ from omegaconf import OmegaConf
from torch.utils.data import Dataset

from ..config.data_args import DataArguments
from ..extras.types import DatasetInfo, HFDataset, Processor, Tensor
from ..plugins.data_plugins.loader import DataGetItemPlugin, DataIndexPlugin, DataLoaderPlugin


class DataCollator:
    """Default Data collator."""

    def __init__(self, processor: Processor, dataset_info: dict[str, DatasetInfo]) -> None:
        self.processor = processor
        self.dataset_info = dataset_info

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
        pass
from ..extras.types import DatasetInfo, HFDataset, Sample


class DataEngine(Dataset):
@@ -45,73 +33,78 @@ class DataEngine(Dataset):
        """Data arguments."""
        self.datasets: dict[str, HFDataset] = {}
        """Dict of (dataset_name, dataset)"""
        self.dataset_info: dict[str, DatasetInfo] = {}
        self.dataset_infos: dict[str, DatasetInfo] = {}
        """Dict of (dataset_name, dataset_info)"""
        self.streaming: bool = False
        """Whether dataset is streaming."""
        self.data_index: list[tuple[str, int]] = []
        """List of (dataset_name, sample_index)"""
        self.data_loader_plugin = DataLoaderPlugin(args=self.args)
        """Data loader plugin."""
        self.data_index_plugin = DataIndexPlugin()
        """Data index plugin."""
        self.data_getitem_plugin = DataGetItemPlugin(datasets=self.datasets, data_index=self.data_index)
        """Data getitem plugin."""
        self.streaming: bool = False
        """Whether dataset is streaming."""
        self.get_dataset_info()
        self.load_dataset()
        self.build_data_index()

    def get_dataset_info(self) -> None:
        """Get dataset info."""
        """Get dataset info from data arguments."""
        if self.args.dataset.endswith(".yaml") and os.path.isfile(
            os.path.join(self.args.dataset_dir, self.args.dataset)
        ):  # local file
            self.dataset_info = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
            self.dataset_infos = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
        elif self.args.dataset.endswith(".yaml"):  # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
            repo_id, filename = os.path.split(self.args.dataset)
            filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
            self.dataset_info = OmegaConf.load(filepath)
            self.dataset_infos = OmegaConf.load(filepath)
        elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)):  # local file(s)
            self.dataset_info = {"default": {"file_name": self.args.dataset}}
            self.dataset_infos = {"default": {"file_name": self.args.dataset}}
        else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
            self.dataset_info = {"default": {"hf_hub_url": self.args.dataset}}
            self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}

    def load_dataset(self) -> None:
        """Load dataset from dataset info."""
        for key, value in self.dataset_info.items():
        """Load datasets according to dataset info."""
        for key, value in self.dataset_infos.items():
            split = value.get("split", "train")
            streaming = value.get("streaming", False)
            self.streaming |= streaming
            if "hf_hub_url" in value:
                self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
            else:  # data loader plugin
                self.datasets[key] = self.data_loader_plugin.auto_load_data(value)
                from ..plugins.data_plugins.loader import DataLoaderPlugin

                self.datasets[key] = DataLoaderPlugin(args=self.args).auto_load_data(value)

    def build_data_index(self) -> None:
        """Build dataset index."""
        for dataset_name, dataset in self.datasets.items():
            size = self.dataset_info[dataset_name].get("size")
            weight = self.dataset_info[dataset_name].get("weight")
            size = self.dataset_infos[dataset_name].get("size")
            weight = self.dataset_infos[dataset_name].get("weight")
            if self.streaming:
                data_index = [(dataset_name, -1) for _ in range(1000)]
            else:
                data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]

            if size or weight:  # data index plugin
                data_index = self.data_index_plugin.adjust_data_index(data_index, size, weight)
                from ..plugins.data_plugins.loader import DataIndexPlugin

                data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)

            self.data_index.extend(data_index)

    def get_data_collator(self, processor: Processor) -> DataCollator:
        """Get data collator.
    def _convert_data_sample(self, raw_sample: dict[str, Any], dataset_name: str) -> Sample:
        """Convert dataset sample.

        Args:
            processor (Processor): Processor.
            raw_sample (dict[str, Any]): Raw dataset sample.
            dataset_name (str): Dataset name.

        Returns:
            DataCollator: Data collator.
            Sample: Dataset sample.
        """
        return DataCollator(processor=processor, dataset_info=self.dataset_info)
        converter = self.dataset_infos[dataset_name].get("converter")
        if converter is not None:
            from ..plugins.data_plugins.converter import get_converter

            return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
        else:
            return {"_dataset_name": dataset_name, **raw_sample}

    def __len__(self) -> int:
        """Get dataset length.
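get_dataset_info now fills self.dataset_infos from a local or Hub-hosted dataset_info.yaml, or synthesizes a single "default" entry from the dataset argument. As a rough sketch, the loaded mapping could look like the following; the first entry mirrors the names used in the new converter test, the second is invented, and only fields that appear in this diff (hf_hub_url, file_name, split, streaming, size, weight, converter) are used:

# Hypothetical result of OmegaConf.load on a dataset_info.yaml (illustration only).
dataset_infos = {
    "tiny_dataset": {
        "hf_hub_url": "llamafactory/tiny-supervised-dataset",
        "converter": "alpaca",  # dispatched through get_converter in _convert_data_sample
        "split": "train",  # "train" is also the default when omitted
    },
    "local_demo": {  # invented entry showing the local-file path
        "file_name": "reason_tool_use_demo_50.jsonl",
        "size": 32,  # size/weight trigger DataIndexPlugin.adjust_data_index
        "weight": 1.0,
    },
}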
@@ -124,23 +117,33 @@ class DataEngine(Dataset):
        else:
            return len(self.data_index)

    def __getitem__(self, index: Union[int, slice, list[int]]) -> Union[dict, list[dict]]:
    def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
        """Get dataset item.

        Args:
            index (int): Dataset index.

        Returns:
            dict: Dataset item.
            Sample: Dataset item.
        """
        if self.streaming:
            raise ValueError("Streaming dataset does not support index access.")

        if isinstance(index, int):
            dataset_name, sample_index = self.data_index[index]
            return {"_dataset_name": dataset_name, **self.datasets[dataset_name][sample_index]}
            return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
        else:
            return self.data_getitem_plugin.get_data(index)
            from ..plugins.data_plugins.loader import DataSelectorPlugin

            selected_index = DataSelectorPlugin(data_index=self.data_index).select(index)
            if isinstance(selected_index, list):
                return [
                    self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
                    for dataset_name, sample_index in selected_index
                ]
            else:
                dataset_name, sample_index = selected_index
                return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)

    def __iter__(self) -> Iterable:
        """Get dataset iterator.
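With these changes, __getitem__ accepts an int directly and defers slice or list access to DataSelectorPlugin, converting every looked-up record through _convert_data_sample. A minimal usage sketch, assuming the llamafactory/v1-sft-demo dataset used by the tests is reachable:

# Sketch only: exercising the indexing behavior described in this hunk.
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine

data_engine = DataEngine(data_args=DataArguments(dataset="llamafactory/v1-sft-demo"))
sample = data_engine[0]  # int -> a single Sample dict carrying "_dataset_name"
window = data_engine[0:4]  # slice -> list of Sample dicts via DataSelectorPlugin
picked = data_engine[[1, 3, 5]]  # list of ints -> list of Sample dicts
print(sample["_dataset_name"], len(window), len(picked))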
@@ -156,7 +159,7 @@ class DataEngine(Dataset):

        raise NotImplementedError()

    def __aiter__(self) -> AsyncIterable:
    async def __aiter__(self) -> AsyncIterable:
        """Get dataset async iterator.

        Returns:
@@ -169,3 +172,11 @@ class DataEngine(Dataset):
            pass

        raise NotImplementedError()


if __name__ == "__main__":
    from ..config.parser import get_args

    data_args, *_ = get_args()
    data_engine = DataEngine(data_args=data_args)
    print(data_engine[0])
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, NotRequired, TypedDict, Union
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union


if TYPE_CHECKING:
@@ -26,7 +26,8 @@ if TYPE_CHECKING:
    HFDataset = Union[datasets.Dataset, datasets.IterableDataset]
    DataCollator = transformers.DataCollator
    DataLoader = torch.utils.data.DataLoader
    Model = transformers.PreTrainedModel
    HFModel = transformers.PreTrainedModel
    DistModel = torch.nn.parallel.DistributedDataParallel
    Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
else:
    Tensor = None
@@ -34,7 +35,8 @@ else:
    HFDataset = None
    DataCollator = None
    DataLoader = None
    Model = None
    HFModel = None
    DistModel = None
    Processor = None


@@ -55,3 +57,37 @@ class DatasetInfo(TypedDict, total=False):
    """Dataset weight, default to 1.0."""
    streaming: NotRequired[bool]
    """Is streaming dataset, default to False."""


class Content(TypedDict):
    type: Literal["text", "tools", "reasoning", "tool_calls", "image_url"]
    value: str


class Message(TypedDict):
    role: Literal["system", "user", "assistant"]
    content: list[Content]
    loss_weight: float


class SFTSample(TypedDict):
    messages: list[Message]
    extra_info: NotRequired[str]
    _dataset_name: NotRequired[str]


class DPOSample(TypedDict):
    chosen_messages: list[Message]
    rejected_messages: list[Message]
    extra_info: NotRequired[str]
    _dataset_name: NotRequired[str]


Sample = Union[SFTSample, DPOSample]


class Model(TypedDict):
    hf_model: HFModel
    """HF model."""
    dist_model: DistModel
    """Distributed model."""
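The new Content, Message, and Sample types pin down the canonical record layout. For illustration (all text values invented), a DPOSample that satisfies these TypedDicts would look like:

# Illustrative DPOSample matching the TypedDicts above; values are made up.
prompt = {"role": "user", "content": [{"type": "text", "value": "Name a prime number."}], "loss_weight": 0.0}
example_dpo_sample = {
    "chosen_messages": [
        prompt,
        {"role": "assistant", "content": [{"type": "text", "value": "7"}], "loss_weight": 1.0},
    ],
    "rejected_messages": [
        prompt,
        {"role": "assistant", "content": [{"type": "text", "value": "8"}], "loss_weight": 1.0},
    ],
    "_dataset_name": "default",
}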
@@ -0,0 +1,71 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Callable, NotRequired, TypedDict

from ...extras.types import Sample, SFTSample


class AlpacaSample(TypedDict, total=False):
    system: NotRequired[str]
    instruction: NotRequired[str]
    input: NotRequired[str]
    output: NotRequired[str]


def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
    """Convert Alpaca sample to SFT sample.

    Args:
        raw_sample (AlpacaSample): Alpaca sample.

    Returns:
        SFTSample: SFT sample.
    """
    messages = []
    if "system" in raw_sample:
        messages.append(
            {"role": "system", "content": [{"type": "text", "value": raw_sample["system"]}], "loss_weight": 0.0}
        )

    if "instruction" in raw_sample or "input" in raw_sample:
        messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "value": raw_sample.get("instruction", "") + raw_sample.get("input", "")}
                ],
                "loss_weight": 0.0,
            }
        )

    if "output" in raw_sample:
        messages.append(
            {"role": "assistant", "content": [{"type": "text", "value": raw_sample["output"]}], "loss_weight": 1.0}
        )

    return {"messages": messages}


CONVERTERS = {
    "alpaca": alpaca_converter,
}


def get_converter(converter_name: str) -> Callable[[dict], Sample]:
    if converter_name not in CONVERTERS:
        raise ValueError(f"Converter {converter_name} not found.")

    return CONVERTERS[converter_name]
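A quick usage sketch for the new converter (the raw record is invented, but the call pattern matches how _convert_data_sample dispatches through get_converter):

# Illustration only: converting an Alpaca-format record with the new plugin.
from llamafactory.v1.plugins.data_plugins.converter import get_converter

raw_sample = {"instruction": "Translate to French: ", "input": "Hello", "output": "Bonjour"}
sft_sample = get_converter("alpaca")(raw_sample)
# sft_sample["messages"] now holds a user turn with the concatenated
# instruction + input ("Translate to French: Hello", loss_weight 0.0)
# followed by an assistant turn with the output ("Bonjour", loss_weight 1.0).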
@@ -15,7 +15,7 @@

import os
from dataclasses import dataclass
from typing import Literal, Optional, Union
from typing import Any, Literal, Optional, Union

from datasets import load_dataset

@@ -25,7 +25,10 @@ from ...extras.types import DatasetInfo, HFDataset

@dataclass
class DataLoaderPlugin:
    """Plugin for loading dataset."""

    args: DataArguments
    """Data arguments."""

    def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
        """Get dataset builder name.
@@ -66,9 +69,21 @@ class DataLoaderPlugin:

@dataclass
class DataIndexPlugin:
    """Plugin for adjusting dataset index."""

    def adjust_data_index(
        self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float]
    ) -> list[tuple[str, int]]:
        """Adjust dataset index by size and weight.

        Args:
            data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
            size (Optional[int]): Desired dataset size.
            weight (Optional[float]): Desired dataset weight.

        Returns:
            list[tuple[str, int]]: Adjusted dataset index.
        """
        if size is not None:
            data_index = self.adjust_by_size(data_index, size)

@@ -85,18 +100,24 @@ class DataIndexPlugin:


@dataclass
class DataGetItemPlugin:
    datasets: dict[str, HFDataset]
class DataSelectorPlugin:
    """Plugin for selecting dataset samples."""

    data_index: list[tuple[str, int]]
    """List of (dataset_name, sample_index)"""

    def _get_by_index(self, index: int) -> dict:
        dataset_name, sample_index = self.data_index[index]
        return {"_dataset_name": dataset_name, **self.datasets[dataset_name][sample_index]}
    def select(self, index: Union[slice, list[int], Any]) -> Union[tuple[str, int], list[tuple[str, int]]]:
        """Select dataset samples.

    def get_data(self, index: Union[slice, list[int]]) -> list[dict]:
        Args:
            index (Union[slice, list[int], Any]): Index of dataset samples.

        Returns:
            Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
        """
        if isinstance(index, slice):
            return [self._get_by_index(i) for i in range(*index.indices(len(self.data_index)))]
            return [self.data_index[i] for i in range(*index.indices(len(self.data_index)))]
        elif isinstance(index, list):
            return [self._get_by_index(i) for i in index]
            return [self.data_index[i] for i in index]
        else:
            raise ValueError(f"Invalid index type {type(index)}.")
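Unlike the removed DataGetItemPlugin, the renamed DataSelectorPlugin no longer touches the datasets at all: it only resolves positions into (dataset_name, sample_index) pairs, and DataEngine performs the lookup and conversion. A small sketch (the data_index contents are invented):

# Sketch only: DataSelectorPlugin resolves positions into (dataset_name, sample_index) pairs.
from llamafactory.v1.plugins.data_plugins.loader import DataSelectorPlugin

selector = DataSelectorPlugin(data_index=[("default", 0), ("default", 1), ("default", 2)])
print(selector.select(slice(0, 2)))  # [("default", 0), ("default", 1)]
print(selector.select([2, 0]))       # [("default", 2), ("default", 0)]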
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random

import pytest
@@ -22,14 +21,11 @@ from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine


TINY_DATA = os.getenv("TINY_DATA", "llamafactory/v1-sft-demo")


@pytest.mark.parametrize("num_samples", [16])
def test_map_dataset(num_samples: int):
    data_args = DataArguments(dataset=TINY_DATA)
    data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
    data_engine = DataEngine(data_args)
    original_data = load_dataset(TINY_DATA, split="train")
    original_data = load_dataset("llamafactory/v1-sft-demo", split="train")
    indexes = random.choices(range(len(data_engine)), k=num_samples)
    for index in indexes:
        print(data_engine[index])
tests_v1/plugins/data_plugins/test_converter.py: 52 lines (new file)

@@ -0,0 +1,52 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import pytest
from datasets import load_dataset

from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine


@pytest.mark.parametrize("num_samples", [16])
def test_alpaca_converter(num_samples: int):
    data_args = DataArguments(dataset="llamafactory/v1-sft-demo/dataset_info.yaml")
    data_engine = DataEngine(data_args)
    original_data = load_dataset("llamafactory/tiny-supervised-dataset", split="train")
    indexes = random.choices(range(len(data_engine)), k=num_samples)
    for index in indexes:
        print(data_engine[index])
        expected_data = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "value": original_data[index]["instruction"] + original_data[index]["input"]}
                    ],
                    "loss_weight": 0.0,
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "value": original_data[index]["output"]}],
                    "loss_weight": 1.0,
                },
            ]
        }
        assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}


if __name__ == "__main__":
    test_alpaca_converter(1)