mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	Compare commits
	
		
			3 Commits
		
	
	
		
			48974783da
			...
			3dbca4b533
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					3dbca4b533 | ||
| 
						 | 
					9d1acbc191 | ||
| 
						 | 
					52e46e162e | 
							
								
								
									
										3
									
								
								.github/workflows/tests.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/tests.yml
									
									
									
									
										vendored
									
									
								
							@ -43,6 +43,9 @@ jobs:
 | 
				
			|||||||
          - python: "3.9"
 | 
					          - python: "3.9"
 | 
				
			||||||
            os: "ubuntu-latest"
 | 
					            os: "ubuntu-latest"
 | 
				
			||||||
            transformers: "4.53.0"
 | 
					            transformers: "4.53.0"
 | 
				
			||||||
 | 
					        exclude:  # exclude python 3.9 on macos
 | 
				
			||||||
 | 
					          - python: "3.9"
 | 
				
			||||||
 | 
					            os: "macos-latest"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    runs-on: ${{ matrix.os }}
 | 
					    runs-on: ${{ matrix.os }}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										50
									
								
								data/reason_tool_use_demo_50.jsonl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								data/reason_tool_use_demo_50.jsonl
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							@ -12,8 +12,25 @@
 | 
				
			|||||||
# See the License for the specific language governing permissions and
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import Any
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..config.training_args import TrainingArguments
 | 
					from ..config.training_args import TrainingArguments
 | 
				
			||||||
from ..extras.types import DataCollator, Model, Processor, TorchDataset
 | 
					from ..extras.types import Model, Processor, Tensor, TorchDataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DataCollator:
 | 
				
			||||||
 | 
					    """Default Data collator."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, processor: Processor) -> None:
 | 
				
			||||||
 | 
					        self.processor = processor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
 | 
				
			||||||
 | 
					        """Collate features into a batch."""
 | 
				
			||||||
 | 
					        for feature in features:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # sft: messages
 | 
				
			||||||
 | 
					        # dpo: chosen_messages, rejected_messages
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BaseTrainer:
 | 
					class BaseTrainer:
 | 
				
			||||||
 | 
				
			|||||||
@ -22,19 +22,7 @@ from omegaconf import OmegaConf
 | 
				
			|||||||
from torch.utils.data import Dataset
 | 
					from torch.utils.data import Dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..config.data_args import DataArguments
 | 
					from ..config.data_args import DataArguments
 | 
				
			||||||
from ..extras.types import DatasetInfo, HFDataset, Processor, Tensor
 | 
					from ..extras.types import DatasetInfo, HFDataset, Sample
 | 
				
			||||||
from ..plugins.data_plugins.loader import DataGetItemPlugin, DataIndexPlugin, DataLoaderPlugin
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class DataCollator:
 | 
					 | 
				
			||||||
    """Default Data collator."""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, processor: Processor, dataset_info: dict[str, DatasetInfo]) -> None:
 | 
					 | 
				
			||||||
        self.processor = processor
 | 
					 | 
				
			||||||
        self.dataset_info = dataset_info
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Tensor]:
 | 
					 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DataEngine(Dataset):
 | 
					class DataEngine(Dataset):
 | 
				
			||||||
@ -45,73 +33,78 @@ class DataEngine(Dataset):
 | 
				
			|||||||
        """Data arguments."""
 | 
					        """Data arguments."""
 | 
				
			||||||
        self.datasets: dict[str, HFDataset] = {}
 | 
					        self.datasets: dict[str, HFDataset] = {}
 | 
				
			||||||
        """Dict of (dataset_name, dataset)"""
 | 
					        """Dict of (dataset_name, dataset)"""
 | 
				
			||||||
        self.dataset_info: dict[str, DatasetInfo] = {}
 | 
					        self.dataset_infos: dict[str, DatasetInfo] = {}
 | 
				
			||||||
        """Dict of (dataset_name, dataset_info)"""
 | 
					        """Dict of (dataset_name, dataset_info)"""
 | 
				
			||||||
        self.streaming: bool = False
 | 
					 | 
				
			||||||
        """Whether dataset is streaming."""
 | 
					 | 
				
			||||||
        self.data_index: list[tuple[str, int]] = []
 | 
					        self.data_index: list[tuple[str, int]] = []
 | 
				
			||||||
        """List of (dataset_name, sample_index)"""
 | 
					        """List of (dataset_name, sample_index)"""
 | 
				
			||||||
        self.data_loader_plugin = DataLoaderPlugin(args=self.args)
 | 
					        self.streaming: bool = False
 | 
				
			||||||
        """Data loader plugin."""
 | 
					        """Whether dataset is streaming."""
 | 
				
			||||||
        self.data_index_plugin = DataIndexPlugin()
 | 
					 | 
				
			||||||
        """Data index plugin."""
 | 
					 | 
				
			||||||
        self.data_getitem_plugin = DataGetItemPlugin(datasets=self.datasets, data_index=self.data_index)
 | 
					 | 
				
			||||||
        """Data getitem plugin."""
 | 
					 | 
				
			||||||
        self.get_dataset_info()
 | 
					        self.get_dataset_info()
 | 
				
			||||||
        self.load_dataset()
 | 
					        self.load_dataset()
 | 
				
			||||||
        self.build_data_index()
 | 
					        self.build_data_index()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_dataset_info(self) -> None:
 | 
					    def get_dataset_info(self) -> None:
 | 
				
			||||||
        """Get dataset info."""
 | 
					        """Get dataset info from data arguments."""
 | 
				
			||||||
        if self.args.dataset.endswith(".yaml") and os.path.isfile(
 | 
					        if self.args.dataset.endswith(".yaml") and os.path.isfile(
 | 
				
			||||||
            os.path.join(self.args.dataset_dir, self.args.dataset)
 | 
					            os.path.join(self.args.dataset_dir, self.args.dataset)
 | 
				
			||||||
        ):  # local file
 | 
					        ):  # local file
 | 
				
			||||||
            self.dataset_info = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
 | 
					            self.dataset_infos = OmegaConf.load(os.path.join(self.args.dataset_dir, self.args.dataset))
 | 
				
			||||||
        elif self.args.dataset.endswith(".yaml"):  # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
 | 
					        elif self.args.dataset.endswith(".yaml"):  # hf hub uri, e.g. llamafactory/v1-sft-demo/dataset_info.yaml
 | 
				
			||||||
            repo_id, filename = os.path.split(self.args.dataset)
 | 
					            repo_id, filename = os.path.split(self.args.dataset)
 | 
				
			||||||
            filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
 | 
					            filepath = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
 | 
				
			||||||
            self.dataset_info = OmegaConf.load(filepath)
 | 
					            self.dataset_infos = OmegaConf.load(filepath)
 | 
				
			||||||
        elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)):  # local file(s)
 | 
					        elif os.path.exists(os.path.join(self.args.dataset_dir, self.args.dataset)):  # local file(s)
 | 
				
			||||||
            self.dataset_info = {"default": {"file_name": self.args.dataset}}
 | 
					            self.dataset_infos = {"default": {"file_name": self.args.dataset}}
 | 
				
			||||||
        else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
 | 
					        else:  # hf hub dataset, e.g. llamafactory/v1-sft-demo
 | 
				
			||||||
            self.dataset_info = {"default": {"hf_hub_url": self.args.dataset}}
 | 
					            self.dataset_infos = {"default": {"hf_hub_url": self.args.dataset}}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def load_dataset(self) -> None:
 | 
					    def load_dataset(self) -> None:
 | 
				
			||||||
        """Load dataset from dataset info."""
 | 
					        """Load datasets according to dataset info."""
 | 
				
			||||||
        for key, value in self.dataset_info.items():
 | 
					        for key, value in self.dataset_infos.items():
 | 
				
			||||||
            split = value.get("split", "train")
 | 
					            split = value.get("split", "train")
 | 
				
			||||||
            streaming = value.get("streaming", False)
 | 
					            streaming = value.get("streaming", False)
 | 
				
			||||||
            self.streaming |= streaming
 | 
					            self.streaming |= streaming
 | 
				
			||||||
            if "hf_hub_url" in value:
 | 
					            if "hf_hub_url" in value:
 | 
				
			||||||
                self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
 | 
					                self.datasets[key] = load_dataset(value["hf_hub_url"], split=split, streaming=streaming)
 | 
				
			||||||
            else:  # data loader plugin
 | 
					            else:  # data loader plugin
 | 
				
			||||||
                self.datasets[key] = self.data_loader_plugin.auto_load_data(value)
 | 
					                from ..plugins.data_plugins.loader import DataLoaderPlugin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                self.datasets[key] = DataLoaderPlugin(args=self.args).auto_load_data(value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def build_data_index(self) -> None:
 | 
					    def build_data_index(self) -> None:
 | 
				
			||||||
        """Build dataset index."""
 | 
					        """Build dataset index."""
 | 
				
			||||||
        for dataset_name, dataset in self.datasets.items():
 | 
					        for dataset_name, dataset in self.datasets.items():
 | 
				
			||||||
            size = self.dataset_info[dataset_name].get("size")
 | 
					            size = self.dataset_infos[dataset_name].get("size")
 | 
				
			||||||
            weight = self.dataset_info[dataset_name].get("weight")
 | 
					            weight = self.dataset_infos[dataset_name].get("weight")
 | 
				
			||||||
            if self.streaming:
 | 
					            if self.streaming:
 | 
				
			||||||
                data_index = [(dataset_name, -1) for _ in range(1000)]
 | 
					                data_index = [(dataset_name, -1) for _ in range(1000)]
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
 | 
					                data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if size or weight:  # data index plugin
 | 
					            if size or weight:  # data index plugin
 | 
				
			||||||
                data_index = self.data_index_plugin.adjust_data_index(data_index, size, weight)
 | 
					                from ..plugins.data_plugins.loader import DataIndexPlugin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            self.data_index.extend(data_index)
 | 
					            self.data_index.extend(data_index)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_data_collator(self, processor: Processor) -> DataCollator:
 | 
					    def _convert_data_sample(self, raw_sample: dict[str, Any], dataset_name: str) -> Sample:
 | 
				
			||||||
        """Get data collator.
 | 
					        """Convert dataset sample.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Args:
 | 
					        Args:
 | 
				
			||||||
            processor (Processor): Processor.
 | 
					            raw_sample (dict[str, Any]): Raw dataset sample.
 | 
				
			||||||
 | 
					            dataset_name (str): Dataset name.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Returns:
 | 
					        Returns:
 | 
				
			||||||
            DataCollator: Data collator.
 | 
					            Sample: Dataset sample.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        return DataCollator(processor=processor, dataset_info=self.dataset_info)
 | 
					        converter = self.dataset_infos[dataset_name].get("converter")
 | 
				
			||||||
 | 
					        if converter is not None:
 | 
				
			||||||
 | 
					            from ..plugins.data_plugins.converter import get_converter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            return {"_dataset_name": dataset_name, **get_converter(converter)(raw_sample)}
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return {"_dataset_name": dataset_name, **raw_sample}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __len__(self) -> int:
 | 
					    def __len__(self) -> int:
 | 
				
			||||||
        """Get dataset length.
 | 
					        """Get dataset length.
 | 
				
			||||||
@ -124,23 +117,33 @@ class DataEngine(Dataset):
 | 
				
			|||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return len(self.data_index)
 | 
					            return len(self.data_index)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __getitem__(self, index: Union[int, slice, list[int]]) -> Union[dict, list[dict]]:
 | 
					    def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
 | 
				
			||||||
        """Get dataset item.
 | 
					        """Get dataset item.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Args:
 | 
					        Args:
 | 
				
			||||||
            index (int): Dataset index.
 | 
					            index (int): Dataset index.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Returns:
 | 
					        Returns:
 | 
				
			||||||
            dict: Dataset item.
 | 
					            Sample: Dataset item.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        if self.streaming:
 | 
					        if self.streaming:
 | 
				
			||||||
            raise ValueError("Streaming dataset does not support index access.")
 | 
					            raise ValueError("Streaming dataset does not support index access.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if isinstance(index, int):
 | 
					        if isinstance(index, int):
 | 
				
			||||||
            dataset_name, sample_index = self.data_index[index]
 | 
					            dataset_name, sample_index = self.data_index[index]
 | 
				
			||||||
            return {"_dataset_name": dataset_name, **self.datasets[dataset_name][sample_index]}
 | 
					            return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            return self.data_getitem_plugin.get_data(index)
 | 
					            from ..plugins.data_plugins.loader import DataSelectorPlugin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            selected_index = DataSelectorPlugin(data_index=self.data_index).select(index)
 | 
				
			||||||
 | 
					            if isinstance(selected_index, list):
 | 
				
			||||||
 | 
					                return [
 | 
				
			||||||
 | 
					                    self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
 | 
				
			||||||
 | 
					                    for dataset_name, sample_index in selected_index
 | 
				
			||||||
 | 
					                ]
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                dataset_name, sample_index = selected_index
 | 
				
			||||||
 | 
					                return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __iter__(self) -> Iterable:
 | 
					    def __iter__(self) -> Iterable:
 | 
				
			||||||
        """Get dataset iterator.
 | 
					        """Get dataset iterator.
 | 
				
			||||||
@ -156,7 +159,7 @@ class DataEngine(Dataset):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        raise NotImplementedError()
 | 
					        raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __aiter__(self) -> AsyncIterable:
 | 
					    async def __aiter__(self) -> AsyncIterable:
 | 
				
			||||||
        """Get dataset async iterator.
 | 
					        """Get dataset async iterator.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Returns:
 | 
					        Returns:
 | 
				
			||||||
@ -169,3 +172,11 @@ class DataEngine(Dataset):
 | 
				
			|||||||
            pass
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        raise NotImplementedError()
 | 
					        raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    from ..config.parser import get_args
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    data_args, *_ = get_args()
 | 
				
			||||||
 | 
					    data_engine = DataEngine(data_args=data_args)
 | 
				
			||||||
 | 
					    print(data_engine[0])
 | 
				
			||||||
 | 
				
			|||||||
@ -12,7 +12,7 @@
 | 
				
			|||||||
# See the License for the specific language governing permissions and
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from typing import TYPE_CHECKING, NotRequired, TypedDict, Union
 | 
					from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if TYPE_CHECKING:
 | 
					if TYPE_CHECKING:
 | 
				
			||||||
@ -26,7 +26,8 @@ if TYPE_CHECKING:
 | 
				
			|||||||
    HFDataset = Union[datasets.Dataset, datasets.IterableDataset]
 | 
					    HFDataset = Union[datasets.Dataset, datasets.IterableDataset]
 | 
				
			||||||
    DataCollator = transformers.DataCollator
 | 
					    DataCollator = transformers.DataCollator
 | 
				
			||||||
    DataLoader = torch.utils.data.DataLoader
 | 
					    DataLoader = torch.utils.data.DataLoader
 | 
				
			||||||
    Model = transformers.PreTrainedModel
 | 
					    HFModel = transformers.PreTrainedModel
 | 
				
			||||||
 | 
					    DistModel = torch.nn.parallel.DistributedDataParallel
 | 
				
			||||||
    Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
 | 
					    Processor = Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]
 | 
				
			||||||
else:
 | 
					else:
 | 
				
			||||||
    Tensor = None
 | 
					    Tensor = None
 | 
				
			||||||
@ -34,7 +35,8 @@ else:
 | 
				
			|||||||
    HFDataset = None
 | 
					    HFDataset = None
 | 
				
			||||||
    DataCollator = None
 | 
					    DataCollator = None
 | 
				
			||||||
    DataLoader = None
 | 
					    DataLoader = None
 | 
				
			||||||
    Model = None
 | 
					    HFModel = None
 | 
				
			||||||
 | 
					    DistModel = None
 | 
				
			||||||
    Processor = None
 | 
					    Processor = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -55,3 +57,37 @@ class DatasetInfo(TypedDict, total=False):
 | 
				
			|||||||
    """Dataset weight, default to 1.0."""
 | 
					    """Dataset weight, default to 1.0."""
 | 
				
			||||||
    streaming: NotRequired[bool]
 | 
					    streaming: NotRequired[bool]
 | 
				
			||||||
    """Is streaming dataset, default to False."""
 | 
					    """Is streaming dataset, default to False."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Content(TypedDict):
 | 
				
			||||||
 | 
					    type: Literal["text", "tools", "reasoning", "tool_calls", "image_url"]
 | 
				
			||||||
 | 
					    value: str
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Message(TypedDict):
 | 
				
			||||||
 | 
					    role: Literal["system", "user", "assistant"]
 | 
				
			||||||
 | 
					    content: list[Content]
 | 
				
			||||||
 | 
					    loss_weight: float
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SFTSample(TypedDict):
 | 
				
			||||||
 | 
					    messages: list[Message]
 | 
				
			||||||
 | 
					    extra_info: NotRequired[str]
 | 
				
			||||||
 | 
					    _dataset_name: NotRequired[str]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class DPOSample(TypedDict):
 | 
				
			||||||
 | 
					    chosen_messages: list[Message]
 | 
				
			||||||
 | 
					    rejected_messages: list[Message]
 | 
				
			||||||
 | 
					    extra_info: NotRequired[str]
 | 
				
			||||||
 | 
					    _dataset_name: NotRequired[str]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Sample = Union[SFTSample, DPOSample]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Model(TypedDict):
 | 
				
			||||||
 | 
					    hf_model: HFModel
 | 
				
			||||||
 | 
					    """HF model."""
 | 
				
			||||||
 | 
					    dist_model: DistModel
 | 
				
			||||||
 | 
					    """Distributed model."""
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,71 @@
 | 
				
			|||||||
 | 
					# Copyright 2025 the LlamaFactory team.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import Callable, NotRequired, TypedDict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ...extras.types import Sample, SFTSample
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AlpacaSample(TypedDict, total=False):
 | 
				
			||||||
 | 
					    system: NotRequired[str]
 | 
				
			||||||
 | 
					    instruction: NotRequired[str]
 | 
				
			||||||
 | 
					    input: NotRequired[str]
 | 
				
			||||||
 | 
					    output: NotRequired[str]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
 | 
				
			||||||
 | 
					    """Convert Alpaca sample to SFT sample.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        raw_sample (AlpacaSample): Alpaca sample.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns:
 | 
				
			||||||
 | 
					        SFTSample: SFT sample.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    messages = []
 | 
				
			||||||
 | 
					    if "system" in raw_sample:
 | 
				
			||||||
 | 
					        messages.append(
 | 
				
			||||||
 | 
					            {"role": "system", "content": [{"type": "text", "value": raw_sample["system"]}], "loss_weight": 0.0}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if "instruction" in raw_sample or "input" in raw_sample:
 | 
				
			||||||
 | 
					        messages.append(
 | 
				
			||||||
 | 
					            {
 | 
				
			||||||
 | 
					                "role": "user",
 | 
				
			||||||
 | 
					                "content": [
 | 
				
			||||||
 | 
					                    {"type": "text", "value": raw_sample.get("instruction", "") + raw_sample.get("input", "")}
 | 
				
			||||||
 | 
					                ],
 | 
				
			||||||
 | 
					                "loss_weight": 0.0,
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if "output" in raw_sample:
 | 
				
			||||||
 | 
					        messages.append(
 | 
				
			||||||
 | 
					            {"role": "assistant", "content": [{"type": "text", "value": raw_sample["output"]}], "loss_weight": 1.0}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return {"messages": messages}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					CONVERTERS = {
 | 
				
			||||||
 | 
					    "alpaca": alpaca_converter,
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_converter(converter_name: str) -> Callable[[dict], Sample]:
 | 
				
			||||||
 | 
					    if converter_name not in CONVERTERS:
 | 
				
			||||||
 | 
					        raise ValueError(f"Converter {converter_name} not found.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return CONVERTERS[converter_name]
 | 
				
			||||||
@ -15,7 +15,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
from dataclasses import dataclass
 | 
					from dataclasses import dataclass
 | 
				
			||||||
from typing import Literal, Optional, Union
 | 
					from typing import Any, Literal, Optional, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from datasets import load_dataset
 | 
					from datasets import load_dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -25,7 +25,10 @@ from ...extras.types import DatasetInfo, HFDataset
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
@dataclass
 | 
					@dataclass
 | 
				
			||||||
class DataLoaderPlugin:
 | 
					class DataLoaderPlugin:
 | 
				
			||||||
 | 
					    """Plugin for loading dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    args: DataArguments
 | 
					    args: DataArguments
 | 
				
			||||||
 | 
					    """Data arguments."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
 | 
					    def _get_builder_name(self, path: str) -> Literal["arrow", "csv", "json", "parquet", "text"]:
 | 
				
			||||||
        """Get dataset builder name.
 | 
					        """Get dataset builder name.
 | 
				
			||||||
@ -66,9 +69,21 @@ class DataLoaderPlugin:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
@dataclass
 | 
					@dataclass
 | 
				
			||||||
class DataIndexPlugin:
 | 
					class DataIndexPlugin:
 | 
				
			||||||
 | 
					    """Plugin for adjusting dataset index."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def adjust_data_index(
 | 
					    def adjust_data_index(
 | 
				
			||||||
        self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float]
 | 
					        self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float]
 | 
				
			||||||
    ) -> list[tuple[str, int]]:
 | 
					    ) -> list[tuple[str, int]]:
 | 
				
			||||||
 | 
					        """Adjust dataset index by size and weight.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
 | 
				
			||||||
 | 
					            size (Optional[int]): Desired dataset size.
 | 
				
			||||||
 | 
					            weight (Optional[float]): Desired dataset weight.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            list[tuple[str, int]]: Adjusted dataset index.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
        if size is not None:
 | 
					        if size is not None:
 | 
				
			||||||
            data_index = self.adjust_by_size(data_index, size)
 | 
					            data_index = self.adjust_by_size(data_index, size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -85,18 +100,24 @@ class DataIndexPlugin:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclass
 | 
					@dataclass
 | 
				
			||||||
class DataGetItemPlugin:
 | 
					class DataSelectorPlugin:
 | 
				
			||||||
    datasets: dict[str, HFDataset]
 | 
					    """Plugin for selecting dataset samples."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    data_index: list[tuple[str, int]]
 | 
					    data_index: list[tuple[str, int]]
 | 
				
			||||||
 | 
					    """List of (dataset_name, sample_index)"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _get_by_index(self, index: int) -> dict:
 | 
					    def select(self, index: Union[slice, list[int], Any]) -> Union[tuple[str, int], list[tuple[str, int]]]:
 | 
				
			||||||
        dataset_name, sample_index = self.data_index[index]
 | 
					        """Select dataset samples.
 | 
				
			||||||
        return {"_dataset_name": dataset_name, **self.datasets[dataset_name][sample_index]}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_data(self, index: Union[slice, list[int]]) -> list[dict]:
 | 
					        Args:
 | 
				
			||||||
 | 
					            index (Union[slice, list[int], Any]): Index of dataset samples.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
        if isinstance(index, slice):
 | 
					        if isinstance(index, slice):
 | 
				
			||||||
            return [self._get_by_index(i) for i in range(*index.indices(len(self.data_index)))]
 | 
					            return [self.data_index[i] for i in range(*index.indices(len(self.data_index)))]
 | 
				
			||||||
        elif isinstance(index, list):
 | 
					        elif isinstance(index, list):
 | 
				
			||||||
            return [self._get_by_index(i) for i in index]
 | 
					            return [self.data_index[i] for i in index]
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            raise ValueError(f"Invalid index type {type(index)}.")
 | 
					            raise ValueError(f"Invalid index type {type(index)}.")
 | 
				
			||||||
 | 
				
			|||||||
@ -12,7 +12,6 @@
 | 
				
			|||||||
# See the License for the specific language governing permissions and
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
import random
 | 
					import random
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
@ -22,14 +21,11 @@ from llamafactory.v1.config.data_args import DataArguments
 | 
				
			|||||||
from llamafactory.v1.core.data_engine import DataEngine
 | 
					from llamafactory.v1.core.data_engine import DataEngine
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TINY_DATA = os.getenv("TINY_DATA", "llamafactory/v1-sft-demo")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@pytest.mark.parametrize("num_samples", [16])
 | 
					@pytest.mark.parametrize("num_samples", [16])
 | 
				
			||||||
def test_map_dataset(num_samples: int):
 | 
					def test_map_dataset(num_samples: int):
 | 
				
			||||||
    data_args = DataArguments(dataset=TINY_DATA)
 | 
					    data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
 | 
				
			||||||
    data_engine = DataEngine(data_args)
 | 
					    data_engine = DataEngine(data_args)
 | 
				
			||||||
    original_data = load_dataset(TINY_DATA, split="train")
 | 
					    original_data = load_dataset("llamafactory/v1-sft-demo", split="train")
 | 
				
			||||||
    indexes = random.choices(range(len(data_engine)), k=num_samples)
 | 
					    indexes = random.choices(range(len(data_engine)), k=num_samples)
 | 
				
			||||||
    for index in indexes:
 | 
					    for index in indexes:
 | 
				
			||||||
        print(data_engine[index])
 | 
					        print(data_engine[index])
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										52
									
								
								tests_v1/plugins/data_plugins/test_converter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								tests_v1/plugins/data_plugins/test_converter.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,52 @@
 | 
				
			|||||||
 | 
					# Copyright 2025 the LlamaFactory team.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||||
 | 
					# you may not use this file except in compliance with the License.
 | 
				
			||||||
 | 
					# You may obtain a copy of the License at
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||||
 | 
					# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||||
 | 
					# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||||
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import random
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					from datasets import load_dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from llamafactory.v1.config.data_args import DataArguments
 | 
				
			||||||
 | 
					from llamafactory.v1.core.data_engine import DataEngine
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize("num_samples", [16])
 | 
				
			||||||
 | 
					def test_alpaca_converter(num_samples: int):
 | 
				
			||||||
 | 
					    data_args = DataArguments(dataset="llamafactory/v1-sft-demo/dataset_info.yaml")
 | 
				
			||||||
 | 
					    data_engine = DataEngine(data_args)
 | 
				
			||||||
 | 
					    original_data = load_dataset("llamafactory/tiny-supervised-dataset", split="train")
 | 
				
			||||||
 | 
					    indexes = random.choices(range(len(data_engine)), k=num_samples)
 | 
				
			||||||
 | 
					    for index in indexes:
 | 
				
			||||||
 | 
					        print(data_engine[index])
 | 
				
			||||||
 | 
					        expected_data = {
 | 
				
			||||||
 | 
					            "messages": [
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "role": "user",
 | 
				
			||||||
 | 
					                    "content": [
 | 
				
			||||||
 | 
					                        {"type": "text", "value": original_data[index]["instruction"] + original_data[index]["input"]}
 | 
				
			||||||
 | 
					                    ],
 | 
				
			||||||
 | 
					                    "loss_weight": 0.0,
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    "role": "assistant",
 | 
				
			||||||
 | 
					                    "content": [{"type": "text", "value": original_data[index]["output"]}],
 | 
				
			||||||
 | 
					                    "loss_weight": 1.0,
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    test_alpaca_converter(1)
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user