diff --git a/src/llamafactory/v1/config/__init__.py b/src/llamafactory/v1/config/__init__.py index b9aceeb31..f334d0524 100644 --- a/src/llamafactory/v1/config/__init__.py +++ b/src/llamafactory/v1/config/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from .arg_parser import InputArgument, get_args -from .arg_utils import ModelClass, SampleBackend +from .arg_utils import BatchingStrategy, ModelClass, SampleBackend from .data_args import DataArguments from .model_args import ModelArguments from .sample_args import SampleArguments @@ -21,6 +21,7 @@ from .training_args import TrainingArguments __all__ = [ + "BatchingStrategy", "DataArguments", "InputArgument", "ModelArguments", diff --git a/src/llamafactory/v1/config/arg_utils.py b/src/llamafactory/v1/config/arg_utils.py index 5673bd687..db52f9700 100644 --- a/src/llamafactory/v1/config/arg_utils.py +++ b/src/llamafactory/v1/config/arg_utils.py @@ -50,6 +50,14 @@ class SampleBackend(StrEnum): VLLM = "vllm" +@unique +class BatchingStrategy(StrEnum): + NORMAL = "normal" + PADDING_FREE = "padding_free" + DYNAMIC_BATCHING = "dynamic_batching" + DYNAMIC_PADDING_FREE = "dynamic_padding_free" + + def _convert_str_dict(data: dict) -> dict: """Parse string representation inside the dictionary. diff --git a/src/llamafactory/v1/config/data_args.py b/src/llamafactory/v1/config/data_args.py index c1bd5f23f..f9b5d06a9 100644 --- a/src/llamafactory/v1/config/data_args.py +++ b/src/llamafactory/v1/config/data_args.py @@ -22,7 +22,3 @@ class DataArguments: default=None, metadata={"help": "Path to the dataset."}, ) - cutoff_len: int = field( - default=2048, - metadata={"help": "Cutoff length for the dataset."}, - ) diff --git a/src/llamafactory/v1/config/training_args.py b/src/llamafactory/v1/config/training_args.py index 9bee3095b..e937170f7 100644 --- a/src/llamafactory/v1/config/training_args.py +++ b/src/llamafactory/v1/config/training_args.py @@ -16,7 +16,7 @@ import os from dataclasses import dataclass, field from uuid import uuid4 -from .arg_utils import PluginConfig, get_plugin_config +from .arg_utils import BatchingStrategy, PluginConfig, get_plugin_config @dataclass @@ -29,18 +29,30 @@ class TrainingArguments: default=1, metadata={"help": "Micro batch size for training."}, ) - global_batch_size: int = field( - default=1, - metadata={"help": "Global batch size for training."}, + global_batch_size: int | None = field( + default=None, + metadata={"help": "Global batch size for training, default to DP size * micro batch size."}, ) learning_rate: float = field( default=1e-4, metadata={"help": "Learning rate for training."}, ) + cutoff_len: int = field( + default=2048, + metadata={"help": "Maximum sequence length for training."}, + ) bf16: bool = field( default=False, metadata={"help": "Use bf16 for training."}, ) + batching_strategy: BatchingStrategy = field( + default=BatchingStrategy.NORMAL, + metadata={"help": "Batching strategy for training."}, + ) + batching_workers: int = field( + default=16, + metadata={"help": "Number of workers for batching."}, + ) dist_config: PluginConfig | None = field( default=None, metadata={"help": "Distribution configuration for training."}, diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py index 96b2f5e2e..df9af26e4 100644 --- a/src/llamafactory/v1/core/base_trainer.py +++ b/src/llamafactory/v1/core/base_trainer.py @@ -29,7 +29,6 @@ Train Phase: from ..config.training_args import TrainingArguments from ..utils.types import HFModel, TorchDataset -from 
.utils.data_collator import DataCollator from .utils.rendering import Renderer @@ -45,7 +44,6 @@ class BaseTrainer: self.model = model self.renderer = renderer self.dataset = dataset - self.data_collator = DataCollator() self.optimizer = None self.lr_scheduler = None diff --git a/src/llamafactory/v1/core/data_engine.py b/src/llamafactory/v1/core/data_engine.py index 18f14f1b6..60a2667d0 100644 --- a/src/llamafactory/v1/core/data_engine.py +++ b/src/llamafactory/v1/core/data_engine.py @@ -82,14 +82,17 @@ class DataEngine(Dataset): def _load_dataset(self) -> None: """Load datasets according to dataset info.""" + is_streaming = [dataset_info.get("streaming", False) for dataset_info in self.dataset_infos.values()] + self.streaming = any(is_streaming) + if all(is_streaming) != any(is_streaming): + raise ValueError("All datasets must be streaming or non-streaming.") + for dataset_name, dataset_info in self.dataset_infos.items(): split = dataset_info.get("split", "train") - streaming = dataset_info.get("streaming", False) - self.streaming |= streaming if dataset_info.get("source", "hf_hub") == "hf_hub": from datasets import load_dataset - self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=streaming) + self.datasets[dataset_name] = load_dataset(dataset_info["path"], split=split, streaming=self.streaming) else: # data loader plugin from ..plugins.data_plugins.loader import DataLoaderPlugin @@ -98,8 +101,7 @@ class DataEngine(Dataset): def _build_data_index(self) -> None: """Build dataset index.""" for dataset_name, dataset in self.datasets.items(): - streaming = self.dataset_infos[dataset_name].get("streaming", False) - if streaming: + if self.streaming: data_index = [(dataset_name, -1) for _ in range(1000)] else: data_index = [(dataset_name, sample_index) for sample_index in range(len(dataset))] @@ -185,8 +187,8 @@ class DataEngine(Dataset): if __name__ == "__main__": """ - python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_sft_demo.yaml - python -m llamafactory.v1.core.data_engine --model none --dataset data/v1_dpo_demo.yaml + python -m llamafactory.v1.core.data_engine --dataset data/v1_sft_demo.yaml + python -m llamafactory.v1.core.data_engine --dataset data/v1_dpo_demo.yaml """ from ..config.arg_parser import get_args diff --git a/src/llamafactory/v1/core/utils/batching.py b/src/llamafactory/v1/core/utils/batching.py new file mode 100644 index 000000000..7f4c724c6 --- /dev/null +++ b/src/llamafactory/v1/core/utils/batching.py @@ -0,0 +1,244 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Batching utils supports stateful dataloader. + +1. Init stateful dataloader (tokenize) +2. Add to buffer +3. 
Yield batch indexes (micro batch * grad acc) + a) non pack + non dynamic + b) non pack + dynamic + c) pack + non dynamic + d) pack + dynamic +""" + +from collections.abc import Iterator +from typing import Any + +from torch.utils.data import default_collate +from torchdata.stateful_dataloader import StatefulDataLoader +from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler + +from ...accelerator.interface import DistributedInterface +from ...config import BatchingStrategy +from ...utils import logging +from ...utils.helper import pad_and_truncate +from ...utils.types import BatchInput, ModelInput, TorchDataset +from .rendering import Renderer + + +logger = logging.get_logger(__name__) + + +def default_collate_fn( + buffer: list[ModelInput], buffer_tokens: int, micro_batch_size: int, num_micro_batch: int, cutoff_len: int +) -> tuple[list[ModelInput], int, list[BatchInput]]: + batch_size = micro_batch_size * num_micro_batch + if len(buffer) < batch_size: + return buffer, buffer_tokens, None + + samples = buffer[:batch_size] + buffer = buffer[batch_size:] + buffer_tokens -= sum(len(sample["input_ids"]) for sample in samples) + + batch = [] + for i in range(num_micro_batch): + micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size] + batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len))) + + return buffer, buffer_tokens, batch + + +class BatchGenerator(Iterator): + def __init__( + self, + dataset: TorchDataset, + renderer: Renderer, + micro_batch_size: int = 1, + global_batch_size: int | None = None, + cutoff_len: int = 2048, + batching_workers: int = 0, + batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL, + pin_memory: bool = True, + drop_last: bool = True, + ) -> None: + self.dataset = dataset + self.renderer = renderer + + self.micro_batch_size = micro_batch_size + self.global_batch_size = global_batch_size + self.cutoff_len = cutoff_len + self.batching_workers = batching_workers + self.batching_strategy = batching_strategy + self.pin_memory = pin_memory + self.drop_last = drop_last + # TODO: support length and infinity + + dp_size = DistributedInterface().get_world_size("dp") + + if self.global_batch_size is None: + self.global_batch_size = dp_size * micro_batch_size + self.num_micro_batch = 1 + elif self.global_batch_size % (dp_size * micro_batch_size) == 0: + self.num_micro_batch = global_batch_size // dp_size // micro_batch_size + else: + raise ValueError( + "Global batch size must be divisible by DP size and micro batch size. " + f"Got {global_batch_size} % ({dp_size} * {micro_batch_size}) != 0." + ) + + if not self.drop_last: + raise ValueError("Drop last must be True.") + + self._init_data_provider() + + self._is_resuming: bool = False + self._data_iter = iter(self._data_provider) + self._buffer: list[ModelInput] = [] + self._buffer_tokens: int = 0 + self._max_buffer_tokens: int = self.micro_batch_size * self.num_micro_batch * self.cutoff_len + + logger.info_rank0( + f"Init unified data loader with global batch size {self.global_batch_size}, " + f"micro batch size {self.micro_batch_size}, " + f"num micro batch {self.num_micro_batch}, " + f"cutoff len {self.cutoff_len}, " + f"batching workers {self.batching_workers}, " + f"batching strategy {self.batching_strategy}." 
+ ) + + def _init_data_provider(self) -> None: + if len(self.dataset) != -1: + sampler = StatefulDistributedSampler( + self.dataset, + num_replicas=DistributedInterface().get_world_size("dp"), + rank=DistributedInterface().get_rank("dp"), + shuffle=True, + seed=0, + drop_last=self.drop_last, + ) + else: + raise NotImplementedError("Iterable dataset is not supported yet.") + + self._data_provider = StatefulDataLoader( + self.dataset, + batch_size=self.micro_batch_size * self.num_micro_batch, + sampler=sampler, + num_workers=self.batching_workers, + collate_fn=self.renderer.process_samples, + pin_memory=self.pin_memory, + drop_last=self.drop_last, + ) + if self.batching_strategy == BatchingStrategy.NORMAL: + self._length = len(self._data_provider) + else: + from ...plugins.trainer_plugins.batching import BatchingPlugin + + self._length = BatchingPlugin(self.batching_strategy).compute_length() + raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.") + + def __len__(self) -> int: + return self._length + + def __iter__(self): + if not self._is_resuming: + self._buffer.clear() + self._buffer_tokens = 0 + + self._data_iter = iter(self._data_provider) + self._is_resuming = False + return self + + def __next__(self): + batch = self._next_batch() + if batch is None: + raise StopIteration + + return batch + + def _next_batch(self) -> list[BatchInput] | None: + while self._buffer_tokens < self._max_buffer_tokens: + try: + samples: list[ModelInput] = next(self._data_iter) + except StopIteration: + break + + num_tokens = sum(len(sample["input_ids"]) for sample in samples) + self._buffer.extend(samples) + self._buffer_tokens += num_tokens + + return self._build_batch() + + def _build_batch(self) -> list[BatchInput] | None: + if self.batching_strategy == BatchingStrategy.NORMAL: + self._buffer, self._buffer_tokens, batch = default_collate_fn( + self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len + ) + return batch + else: + from ...plugins.trainer_plugins.batching import BatchingPlugin + + self._buffer, self._buffer_tokens, batch = BatchingPlugin(self.batching_strategy)( + self._buffer, self._buffer_tokens, self.micro_batch_size, self.num_micro_batch, self.cutoff_len + ) + return batch + + def state_dict(self) -> dict[str, Any]: + return { + "buffer": self._buffer, + "buffer_tokens": self._buffer_tokens, + "data_provider": self._data_provider.state_dict(), + } + + def load_state_dict(self, state: dict[str, Any]) -> None: + self._buffer = state["buffer"] + self._buffer_tokens = state["buffer_tokens"] + self._data_provider.load_state_dict(state["data_provider"]) + self._is_resuming = True + + def set_epoch(self, epoch: int) -> None: + if hasattr(self._data_provider.sampler, "set_epoch"): + self._data_provider.sampler.set_epoch(epoch) + + +if __name__ == "__main__": + """ + python -m llamafactory.v1.core.utils.batching \ + --model llamafactory/tiny-random-qwen2.5 \ + --dataset data/v1_sft_demo.yaml \ + --micro_batch_size 2 \ + --global_batch_size 4 \ + --batching_workers 0 + """ + from ...config.arg_parser import get_args + from ..data_engine import DataEngine + from ..model_engine import ModelEngine + + data_args, model_args, training_args, _ = get_args() + data_engine = DataEngine(data_args=data_args) + model_engine = ModelEngine(model_args=model_args) + batch_generator = BatchGenerator( + data_engine, + model_engine.renderer, + micro_batch_size=training_args.micro_batch_size, + global_batch_size=training_args.global_batch_size, + 
cutoff_len=training_args.cutoff_len, + batching_workers=training_args.batching_workers, + batching_strategy=training_args.batching_strategy, + ) + for batch in batch_generator: + print(batch) + print(len(batch)) + print(batch[0]["input_ids"].shape) + break diff --git a/src/llamafactory/v1/core/utils/data_loader.py b/src/llamafactory/v1/core/utils/data_loader.py deleted file mode 100644 index a3bb9bdbe..000000000 --- a/src/llamafactory/v1/core/utils/data_loader.py +++ /dev/null @@ -1,277 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import copy -import sys -from collections.abc import Generator, Iterator -from dataclasses import dataclass -from typing import Optional - -from torchdata.stateful_dataloader import StatefulDataLoader -from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler - -from ...utils.batching_queue import BaseBatchingQueue -from ...utils.logging import get_logger -from ...utils.types import Processor, TorchDataset -from .data_collator import DataCollator - - -logger = get_logger(__name__) - - -# base dataloader -class DistributedDataloader(StatefulDataLoader): - """Base Distributed DataLoader.""" - - dataset: "TorchDataset" - sampler: "StatefulDistributedSampler" - - def set_epoch(self, epoch: int) -> None: - if self.sampler is not None and hasattr(self.sampler, "set_epoch"): - self.sampler.set_epoch(epoch) - elif hasattr(self.dataset, "set_epoch"): - self.dataset.set_epoch(epoch) - - -@dataclass -class BaseDataLoader: - """Default DataLoader.""" - - processor: Processor - - def __init__(self, dataset: TorchDataset) -> None: - self.dataset = dataset - # guidlines: fetch until get fixed batchsize. - # save state_dict for buffer. - # resume with state - - # 1. Init stateful dataloader (tokenize) - # 2. Add to buffer (2 * max seq len per device) - # 3. 
Yield batch indexes (micro batch * grad acc) - # a ) non pack + non dynamic - # b ) non pack + dynamic - # c ) pack + non dynamic - # d ) pack + dynamic - - def init_dataloader(self) -> None: - ### init dataloader - pass - - def __iter__(self) -> Iterator: - pass - - def __next__(self) -> any: - pass - - -@dataclass -class DataLoader: - """Default DataLoader.""" - - processor: "Processor" - dataloader: "DistributedDataloader" - batching_queue: "BaseBatchingQueue" - collate_fn: "DataCollator" - num_micro_batch: int = 1 - length: int = 0 - drop_last: bool = True - - def __init__( - self, - dataloader: any, - collate_fn: "DataCollator", - num_micro_batch: int = 1, - length: int = 0, - drop_last: bool = True, - batching_queue: Optional["BaseBatchingQueue"] = None, - ) -> None: - self.batching_queue = batching_queue - self.num_micro_batch = num_micro_batch - self.step = 0 - self._collate_fn = collate_fn - self._dataloader = dataloader - self._drop_last = drop_last - self._data_iter: Iterator - self._resume = False - self._batch_data_iter: Generator - - if length > 0: - self._length = length - elif length == -1: - self._length = sys.maxsize - else: - self._length = len(self._dataloader) - - def __len__(self): - return self._length - - def __iter__(self) -> Iterator: - if not self._resume: - self.step = 0 - self._data_iter = iter(self._dataloader) - self._batch_data_iter = self.batch_data_generator() - self._resume = False - return self - - def __next__(self): - return next(self._batch_data_iter) # FIXME maybe we can move origin_batch_data_generator to here - - def origin_batch_data_generator(self): - """Standard pass-through generator if do not use batching queue.""" - while True: - if self._length > 0 and self.step >= self._length: - return - - try: - batch = [] - data = next(self._data_iter) - # split data into micro batches - for i in range(0, len(data), self.num_micro_batch): - micro_batch = data[i : i + self.num_micro_batch] - if self._collate_fn: - micro_batch = self._collate_fn(micro_batch) - batch.append(micro_batch) - yield batch - self.step += 1 - except StopIteration: - if self.step < self._length: - # Restart iterator to fill the requested length - self._data_iter = iter(self._dataloader) - try: - batch = [] - data = next(self._data_iter) - for i in range(0, len(data), self.num_micro_batch): - micro_batch = data[i : i + self.num_micro_batch] - if self._collate_fn: - micro_batch = self._collate_fn(micro_batch) - batch.append(micro_batch) - yield batch - self.step += 1 - except StopIteration: - return - else: - return - except Exception as e: - logger.error(f"DataLoader origin_batch_data_generator exception: {e}") - raise - - def batch_data_generator(self): - if self.batching_queue is None: - yield from self.origin_batch_data_generator() - return - - batch = [] - - while True: - if self._length and self.step >= self._length: - return - - if self.batching_queue.is_full_filled(): - micro_batch = self.batching_queue.get_micro_batch(self.step) - if self._collate_fn: - micro_batch = self._collate_fn(micro_batch) - batch.append(micro_batch) - if len(batch) == self.num_micro_batch: - yield batch - self.step += 1 - batch = [] - - try: - processing_item = next(self._data_iter) - except Exception as e: - if isinstance(e, StopIteration): - if self.step < self._length: - # call iter until reach length - self._data_iter = iter(self._dataloader) - processing_item = next(self._data_iter) - elif not self._drop_last and not self.batching_queue.empty(): - while not self.batching_queue.empty(): - 
micro_batch = self.batching_queue.get_micro_batch(self.step) - if self._collate_fn: - micro_batch = self._collate_fn(micro_batch) - batch.append(micro_batch) - if len(batch) == self.num_micro_batch: - yield batch - self.step += 1 - batch = [] - - while len(batch) < self.num_micro_batch: - padding_batch = copy.deepcopy(micro_batch) - padding_batch["is_padded"] = True - batch.append(padding_batch) - yield batch - self.step += 1 - return - else: - return - else: - logger.error(f"DataLoader iter data exception: {e}") - raise - - # put processing_item to buffer - if isinstance(processing_item, dict): - processing_item = [processing_item] - - for item in processing_item: - self.batching_queue.put_item(item) - - def state_dict(self): - # save state - state = self.__dict__.copy() - # remove internal fields - for k in list(state.keys()): - if k.startswith("_"): - del state[k] - - # save dataloader state - if hasattr(self._dataloader, "state_dict"): - state["dataloader_state"] = self._dataloader.state_dict() - elif hasattr(self._dataloader, "__getstate__"): - state["dataloader_state"] = self._dataloader.__getstate__() - - batching_strategy = getattr(self, "batching_strategy", None) - if batching_strategy and hasattr(batching_strategy, "state_dict"): - state["batching_strategy_state"] = batching_strategy.state_dict() - if "batching_strategy" in state: - del state["batching_strategy"] - - return copy.deepcopy(state) - - def load_state_dict(self, state: dict[str, any]): - if state["num_micro_batch"] != self.num_micro_batch: - logger.warning( - f"num_micro_batch changed: [ {state['num_micro_batch']} -> {self.num_micro_batch} ], will clear prefetch buffer" - ) - del state["num_micro_batch"] - self.__dict__.update(state) - self._resume = True - - if hasattr(self._dataloader, "load_state_dict"): - self._dataloader.load_state_dict(state["dataloader_state"]) - elif hasattr(self._dataloader, "__getstate__"): - self._dataloader.__setstate__(state["dataloader_state"]) - - if "batching_strategy_state" in state: - batching_strategy = getattr(self, "batching_strategy", None) - if batching_strategy: - batching_strategy.load_state_dict(state["batching_strategy_state"]) - del state["batching_strategy_state"] - - self._data_iter = iter(self._dataloader) - self._batch_data_iter = self.batch_data_generator() - - def set_epoch(self, epoch: int) -> None: - if hasattr(self._dataloader, "set_epoch"): - self._dataloader.set_epoch(epoch) diff --git a/src/llamafactory/v1/core/utils/rendering.py b/src/llamafactory/v1/core/utils/rendering.py index 15ca4c6b9..b4c0b02d6 100644 --- a/src/llamafactory/v1/core/utils/rendering.py +++ b/src/llamafactory/v1/core/utils/rendering.py @@ -12,10 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Rendering utils. + +How to use: +renderer = Renderer(template, processor) +renderer.render_messages(messages: list[Message], tools: str | None) -> ModelInputs +renderer.parse_message(text: str) -> Message +renderer.process_samples(samples: list[Sample]) -> list[ModelInput] +""" + +import numpy as np from ...utils.constants import IGNORE_INDEX from ...utils.helper import get_tokenizer -from ...utils.types import Message, ModelInput, Processor +from ...utils.types import Message, ModelInput, Processor, Sample def render_chatml_messages( @@ -64,7 +74,7 @@ def render_chatml_messages( def parse_chatml_message(generated_text: str) -> Message: - """Parse a message in ChatML format. Supports interleaved reasoning and tool calls. 
+ """Parse a message in ChatML format. Args: generated_text (str): The generated text in ChatML format. @@ -83,6 +93,16 @@ class Renderer: def render_messages( self, messages: list[Message], tools: str | None = None, is_generate: bool = False ) -> ModelInput: + """Apply template to messages and convert them to model input. + + Args: + messages (list[Message]): The messages to render. + tools (str | None, optional): The tools to use. Defaults to None. + is_generate (bool, optional): Whether to render for generation. Defaults to False. + + Returns: + ModelInput: The rendered model input. + """ if self.template == "chatml": return render_chatml_messages(self.processor, messages, tools, is_generate) else: @@ -91,9 +111,59 @@ class Renderer: return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate) def parse_message(self, generated_text: str) -> Message: + """Parse a message in the template format. + + Args: + generated_text (str): The generated text in the template format. + + Returns: + Message: The parsed message. + """ if self.template == "chatml": return parse_chatml_message(generated_text) else: from ...plugins.model_plugins.rendering import RenderingPlugin return RenderingPlugin(self.template).parse_message(generated_text) + + def process_samples(self, samples: list[Sample]) -> list[ModelInput]: + """Process samples to model input. + + Args: + samples (list[Sample]): The samples to process. + + Returns: + list[ModelInput]: The processed model inputs. + """ + model_inputs = [] + for sample in samples: + if "messages" in sample: + model_input = self.render_messages(sample["messages"], sample.get("tools")) + elif "chosen_messages" in sample and "rejected_messages" in sample: + chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools")) + rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools")) + chosen_input["token_type_ids"] = [0] * len(chosen_input["input_ids"]) + rejected_input["token_type_ids"] = [1] * len(rejected_input["input_ids"]) + model_input = ModelInput( + input_ids=chosen_input["input_ids"] + rejected_input["input_ids"], + attention_mask=chosen_input["attention_mask"] + rejected_input["attention_mask"], + labels=chosen_input["labels"] + rejected_input["labels"], + loss_weights=chosen_input["loss_weights"] + rejected_input["loss_weights"], + token_type_ids=chosen_input["token_type_ids"] + rejected_input["token_type_ids"], + ) + if "position_ids" in chosen_input: + model_input["position_ids"] = np.concatenate( + [chosen_input["position_ids"], rejected_input["position_ids"]], axis=-1 + ) + else: + raise ValueError("No valid messages or chosen_messages/rejected_messages found in sample.") + + if "extra_info" in sample: + model_input["extra_info"] = sample["extra_info"] + + if "_dataset_name" in sample: + model_input["_dataset_name"] = sample["_dataset_name"] + + model_inputs.append(model_input) + + return model_inputs diff --git a/src/llamafactory/v1/plugins/data_plugins/converter.py b/src/llamafactory/v1/plugins/data_plugins/converter.py index 4dfcb316c..7075fe5dc 100644 --- a/src/llamafactory/v1/plugins/data_plugins/converter.py +++ b/src/llamafactory/v1/plugins/data_plugins/converter.py @@ -32,7 +32,8 @@ class AlpacaSample(TypedDict, total=False): SharegptMessage = TypedDict( - "SharegptMessage", {"from": Literal["human", "gpt", "system", "function_call", "observation"], "value": str} + "SharegptMessage", + {"from": Literal["human", "gpt", "system", "function_call", "observation"], 
"value": str}, ) @@ -118,15 +119,8 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample: "observation": "tool", "function_call": "assistant", } + sample = {} messages = [] - tools = raw_sample.get("tools") - if tools: - try: - tools: list[dict[str, Any]] = json.loads(tools) - except json.JSONDecodeError: - logger.warning_rank0(f"Invalid tools format: {str(tools)}") - tools = [] - for message in raw_sample.get("conversations", []): tag = message["from"] if tag not in tag_mapping: @@ -157,10 +151,17 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample: } ) + sample["messages"] = messages + + tools = raw_sample.get("tools") if tools: - return {"messages": messages, "tools": json.dumps(tools)} - else: - return {"messages": messages} + try: + tools: list[dict[str, Any]] = json.loads(tools) + sample["tools"] = json.dumps(tools) + except json.JSONDecodeError: + logger.warning_rank0(f"Invalid tools format: {str(tools)}") + + return sample @DataConverterPlugin("pair").register() @@ -179,17 +180,44 @@ def pair_converter(raw_sample: PairSample) -> DPOSample: def process_message(raw_messages: list[OpenaiMessage]): messages = [] for message in raw_messages: - messages.append( - { - "role": message["role"], - "content": [{"type": "text", "value": message["content"]}], - "loss_weight": 1.0 if message["role"] == "assistant" else 0.0, - } - ) + if message["role"] == "tool": + try: + tool_calls: ToolCall | list[ToolCall] = json.loads(message["content"]) + except json.JSONDecodeError: + logger.warning_rank0(f"Invalid tool call format: {str(message['content'])}") + continue + + if not isinstance(tool_calls, list): + tool_calls = [tool_calls] + + messages.append( + { + "role": message["role"], + "content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls], + "loss_weight": 1.0 if message["role"] == "assistant" else 0.0, + } + ) + else: + messages.append( + { + "role": message["role"], + "content": [{"type": "text", "value": message["content"]}], + "loss_weight": 1.0 if message["role"] == "assistant" else 0.0, + } + ) return messages - chosen_messages = process_message(raw_sample.get("chosen", [])) - rejected_messages = process_message(raw_sample.get("rejected", [])) + sample = {} + sample["chosen_messages"] = process_message(raw_sample.get("chosen", [])) + sample["rejected_messages"] = process_message(raw_sample.get("rejected", [])) - return {"chosen_messages": chosen_messages, "rejected_messages": rejected_messages} + tools = raw_sample.get("tools") + if tools: + try: + tools: list[dict[str, Any]] = json.loads(tools) + sample["tools"] = json.dumps(tools) + except json.JSONDecodeError: + logger.warning_rank0(f"Invalid tools format: {str(tools)}") + + return sample diff --git a/src/llamafactory/v1/plugins/model_plugins/rendering.py b/src/llamafactory/v1/plugins/model_plugins/rendering.py index c341f7a42..1ada523e9 100644 --- a/src/llamafactory/v1/plugins/model_plugins/rendering.py +++ b/src/llamafactory/v1/plugins/model_plugins/rendering.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import json import re @@ -51,12 +50,16 @@ def _update_model_input( @RenderingPlugin("qwen3_nothink").register("render_messages") -def render_qwen_messages( +def render_qwen3_nothink_messages( processor: Processor, messages: list[Message], tools: str | None = None, is_generate: bool = False, ) -> ModelInput: + """Render messages in the Qwen3 nothink template format. 
+ + See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507 + """ input_ids, labels, loss_weights = [], [], [] temp_str, temp_weight = "", 0.0 if tools: @@ -179,7 +182,15 @@ def render_qwen_messages( @RenderingPlugin("qwen3_nothink").register("parse_message") -def parse_qwen_message(generated_text: str) -> Message: +def parse_qwen3_nothink_message(generated_text: str) -> Message: + """Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls. + + Args: + generated_text (str): The generated text in the Qwen3 nothink template format. + + Returns: + Message: The parsed message. + """ pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*\s*", re.DOTALL) content = [] last_end = 0 diff --git a/src/llamafactory/v1/plugins/trainer_plugins/batching.py b/src/llamafactory/v1/plugins/trainer_plugins/batching.py new file mode 100644 index 000000000..fa09dcff4 --- /dev/null +++ b/src/llamafactory/v1/plugins/trainer_plugins/batching.py @@ -0,0 +1,19 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils.plugin import BasePlugin + + +class BatchingPlugin(BasePlugin): + pass diff --git a/src/llamafactory/v1/trainers/sft_trainer.py b/src/llamafactory/v1/trainers/sft_trainer.py index 2cb8d3a3c..3a91c4f4c 100644 --- a/src/llamafactory/v1/trainers/sft_trainer.py +++ b/src/llamafactory/v1/trainers/sft_trainer.py @@ -14,7 +14,7 @@ from ..accelerator.interface import DistributedInterface -from ..config.arg_parser import get_args +from ..config import InputArgument, get_args from ..core.base_trainer import BaseTrainer from ..core.data_engine import DataEngine from ..core.model_engine import ModelEngine @@ -24,15 +24,15 @@ class SFTTrainer(BaseTrainer): pass -def run_sft(user_args): - model_args, data_args, training_args, _ = get_args(user_args) +def run_sft(args: InputArgument = None): + model_args, data_args, training_args, _ = get_args(args) DistributedInterface(training_args.dist_config) data_engine = DataEngine(data_args) model_engine = ModelEngine(model_args) trainer = SFTTrainer( args=training_args, model=model_engine.model, - processor=model_engine.processor, + renderer=model_engine.renderer, dataset=data_engine, ) trainer.fit() diff --git a/src/llamafactory/v1/utils/batching_queue.py b/src/llamafactory/v1/utils/batching_queue.py deleted file mode 100644 index ce71a6d29..000000000 --- a/src/llamafactory/v1/utils/batching_queue.py +++ /dev/null @@ -1,220 +0,0 @@ -# Copyright 2025 Bytedance Ltd. and the LlamaFactory team. -# -# This code is inspired by the Bytedance's VeOmni library. -# https://github.com/ByteDance-Seed/VeOmni/blob/v0.1.4/veomni/data/dynamic_batching.py -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod - - -class DynamicBatchSizeBuffer: - """A buffer to store samples for dynamic batch size.""" - - def __init__(self): - self._buffer: list[dict[str, any]] = [] - self._buffer_sample_lengths: list[int] = [] - self._deleted_indices: set[int] = set() - self._current_index: int = 0 - self._total_token_count: int = 0 - - def append(self, item: dict[str, any]) -> None: - """Append a sample to the buffer. - - Args: - item: A sample to append to the buffer. - The sample should be a dict with the following keys: - - input_ids: torch.Tensor of shape (seq_len, ) - - attention_mask: torch.Tensor of shape (seq_len, ) - """ - self._buffer.append(item) - sample_length = int(item["attention_mask"].sum().item()) - self._buffer_sample_lengths.append(sample_length) - self._total_token_count += sample_length - - def get_samples(self, max_tokens_per_iteration: int, force: bool = True) -> list[dict[str, any]]: - """Get samples from the buffer that fit within the token budget. - - Args: - max_tokens_per_iteration: Maximum number of tokens to retrieve. - force: If True, the first available sample will be returned even - if it exceeds the token budget. - - Returns: - A list of samples that fit within the token budget. - - Raises: - AssertionError: If no samples are found (should not happen in normal operation). - """ - cum_seq_len = 0 - samples = [] - - while self._current_index < len(self._buffer) and cum_seq_len < max_tokens_per_iteration: - if self._current_index in self._deleted_indices: - self._current_index += 1 - continue - - seq_len = self._buffer_sample_lengths[self._current_index] - remaining_tokens = max_tokens_per_iteration - cum_seq_len - - # Check if we can add this sample - can_add = (force and cum_seq_len == 0) or (seq_len <= remaining_tokens) - - if can_add: - cum_seq_len += seq_len - samples.append(self._buffer[self._current_index]) - self._deleted_indices.add(self._current_index) - - self._current_index += 1 - - assert len(samples) > 0, "No samples found in buffer" - return samples - - def __len__(self) -> int: - """Return the number of samples in the buffer.""" - return len(self._buffer) - - @property - def total_token_count(self) -> int: - """Return the total number of tokens in the buffer.""" - return self._total_token_count - - def flush(self) -> None: - tokens_to_remove = sum(self._buffer_sample_lengths[idx] for idx in self._deleted_indices) - self._total_token_count -= tokens_to_remove - - buffer_length = len(self._buffer) - self._buffer = [self._buffer[idx] for idx in range(buffer_length) if idx not in self._deleted_indices] - self._buffer_sample_lengths = [ - self._buffer_sample_lengths[idx] for idx in range(buffer_length) if idx not in self._deleted_indices - ] - - self._current_index = 0 - self._deleted_indices.clear() - - -class BaseBatchingQueue(ABC): - """Base class for batching queue.""" - - @abstractmethod - def is_full_filled(self) -> bool: - raise NotImplementedError("Subclasses must implement `is_full_filled`") - - @abstractmethod - def put_item(self, item: dict[str, any]) -> None: - raise NotImplementedError("Subclasses must 
implement `put_item`") - - @abstractmethod - def get_micro_batch(self, step: int) -> list[dict[str, any]]: - raise NotImplementedError("Subclasses must implement `get_micro_batch`") - - @abstractmethod - def empty(self) -> bool: - raise NotImplementedError("Subclasses must implement `empty`") - - -class IdentityPacker: - def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken): - self.token_micro_bsz = token_micro_bsz - self.bsz_warmup_steps = bsz_warmup_steps - self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken - - def __call__(self, samples): - return samples - - def get_token_num_to_request(self, cur_step, warmup): - return ( - (self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps - + self.bsz_warmup_init_mbtoken - if warmup - else self.token_micro_bsz - ) - - -class TextBatchingQueue(BaseBatchingQueue): - """Batching text queue for text data.""" - - def __init__( - self, - token_micro_bsz, - buffer_size: int = 500, - bsz_warmup_steps: int = -1, - bsz_warmup_init_mbtoken: int = 200, - ) -> None: - super().__init__() - self._step = 0 - self.token_micro_bsz = token_micro_bsz - self.bsz_warmup_steps = bsz_warmup_steps - self.buffer_size = buffer_size # minimum samples in buffer - self.buffer = DynamicBatchSizeBuffer() - self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken # training warmup args - assert self.bsz_warmup_init_mbtoken >= 0 - - self.packer = IdentityPacker( - token_micro_bsz=token_micro_bsz, - bsz_warmup_steps=bsz_warmup_steps, - bsz_warmup_init_mbtoken=bsz_warmup_init_mbtoken, - ) - - def is_full_filled(self) -> bool: - return len(self.buffer) >= self.buffer_size and self.buffer.total_token_count >= self.token_micro_bsz - - def put_item(self, item: dict[str, any]): - if len(item["input_ids"]) == 1: - print("WARNING: EMPTY STRING.") - return - self.buffer.append(item) - - def get_token_num_to_request(self): - if self.packer is not None: - warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0 - return self.packer.get_token_num_to_request(self._step, warmup=warmup) - else: - return self.get_cur_token_micro_bsz() - - def get_cur_token_micro_bsz(self): - warmup = self._step <= self.bsz_warmup_steps and self.bsz_warmup_steps > 0 - if warmup: - return ( - self.token_micro_bsz - self.bsz_warmup_init_mbtoken - ) * self._step // self.bsz_warmup_steps + self.bsz_warmup_init_mbtoken - else: - return self.token_micro_bsz - - def get_micro_batch(self, step) -> any: - """Get a micro batch from the buffer according to the current step. - - Args: - step: the current step. - - Returns: - data: a list of samples. - """ - self._step = step - n_token_per_iter = self.get_token_num_to_request() - cur_token_micro_bsz = self.get_cur_token_micro_bsz() - assert cur_token_micro_bsz % n_token_per_iter == 0, ( - "The token num to get for each request should be divisible by token micro bsz." - ) - n_iter = int(cur_token_micro_bsz // n_token_per_iter) - data = [] - for _ in range(n_iter): - samples = self.buffer.get_samples(n_token_per_iter) - if self.packer: - samples = self.packer(samples) # maybe packed into one sample, but wrapped in list. - data.extend(samples) - self.buffer.flush() # remove the selected samples. 
- return data - - def empty(self) -> bool: - return len(self.buffer) == 0 diff --git a/src/llamafactory/v1/utils/dtype.py b/src/llamafactory/v1/utils/dtype.py index f3f262007..331c9bddf 100644 --- a/src/llamafactory/v1/utils/dtype.py +++ b/src/llamafactory/v1/utils/dtype.py @@ -32,8 +32,8 @@ class DtypeRegistry: class DtypeInterface: """Type of precision used.""" - _is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface.current_accelerator) - _is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface.current_accelerator) + _is_fp16_available = is_torch_fp16_available_on_device(DistributedInterface().current_device) + _is_bf16_available = is_torch_bf16_available_on_device(DistributedInterface().current_device) _is_fp32_available = True @staticmethod diff --git a/src/llamafactory/v1/utils/helper.py b/src/llamafactory/v1/utils/helper.py index 7cbd9f336..8b94e71a7 100644 --- a/src/llamafactory/v1/utils/helper.py +++ b/src/llamafactory/v1/utils/helper.py @@ -12,9 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import torch from transformers import PreTrainedTokenizer -from .types import Processor +from .constants import IGNORE_INDEX +from .types import BatchInput, ModelInput, Processor, Tensor + + +def is_tokenizer(processor: Processor) -> bool: + """Check if processor is tokenizer. + + Args: + processor: Processor. + + Returns: + Whether processor is tokenizer. + """ + return not hasattr(processor, "tokenizer") def get_tokenizer(processor: Processor) -> PreTrainedTokenizer: @@ -27,3 +42,34 @@ def get_tokenizer(processor: Processor) -> PreTrainedTokenizer: Tokenizer. """ return processor.tokenizer if hasattr(processor, "tokenizer") else processor + + +def _pad_and_truncate(tensor: Tensor, max_seqlen: int, pad_value: int = 0) -> Tensor: + if tensor.shape[-1] >= max_seqlen: + return tensor[..., :max_seqlen] + + pad_shape = list(tensor.shape) + pad_shape[-1] = max_seqlen - tensor.shape[-1] + pad_tensor = torch.full(pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device) + return torch.cat([tensor, pad_tensor], dim=-1) + + +def pad_and_truncate(samples: list[ModelInput], max_seqlen: int) -> list[BatchInput]: + max_length = min(max(len(sample["input_ids"]) for sample in samples), max_seqlen) + padded_samples = [] + for sample in samples: + padded_sample = {} + for key, value in sample.items(): + if "label" in key: + pad_value = IGNORE_INDEX + else: + pad_value = 0 + + if not isinstance(value, str): + padded_sample[key] = _pad_and_truncate(torch.tensor(value), max_length, pad_value) + else: + padded_sample[key] = value + + padded_samples.append(padded_sample) + + return padded_samples diff --git a/src/llamafactory/v1/utils/types.py b/src/llamafactory/v1/utils/types.py index 25699d7fc..fdd2509b3 100644 --- a/src/llamafactory/v1/utils/types.py +++ b/src/llamafactory/v1/utils/types.py @@ -144,3 +144,20 @@ class ModelInput(TypedDict, total=False): """Loss weight for each token, default to 1.0.""" position_ids: NotRequired[list[int] | list[list[int]]] """Position ids for the model (optional).""" + token_type_ids: NotRequired[list[int]] + """Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages.""" + + +class BatchInput(TypedDict, total=False): + input_ids: Tensor + """Input ids for the model.""" + attention_mask: Tensor + """Attention mask for the model.""" + labels: Tensor + """Labels for the model.""" + loss_weights: Tensor + """Loss weight for each 
token, default to 1.0.""" + position_ids: NotRequired[Tensor] + """Position ids for the model (optional).""" + token_type_ids: NotRequired[Tensor] + """Token type ids used in DPO, 0 represents the chosen messages, 1 represents the rejected messages.""" diff --git a/tests/version.txt b/tests/version.txt index 58a2ff1ca..e082f373b 100644 --- a/tests/version.txt +++ b/tests/version.txt @@ -1,2 +1,2 @@ # change if test fails or cache is outdated -0.9.5.103 +0.9.5.104 diff --git a/tests_v1/accelerator/test_interface.py b/tests_v1/accelerator/test_interface.py index 8bce16b56..38ec83b39 100644 --- a/tests_v1/accelerator/test_interface.py +++ b/tests_v1/accelerator/test_interface.py @@ -56,4 +56,5 @@ def test_all_device(): @pytest.mark.require_distributed(2) def test_multi_device(): master_port = find_available_port() - mp.spawn(_all_reduce_tests, args=(2, master_port), nprocs=2) + world_size = 2 + mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size) diff --git a/tests_v1/core/test_model_loader.py b/tests_v1/core/test_model_loader.py index 5f0150b63..bf426c93b 100644 --- a/tests_v1/core/test_model_loader.py +++ b/tests_v1/core/test_model_loader.py @@ -14,28 +14,24 @@ import torch -from llamafactory.v1.config.model_args import ModelArguments, PluginConfig +from llamafactory.v1.config.model_args import ModelArguments from llamafactory.v1.core.model_engine import ModelEngine def test_tiny_qwen(): - from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast - - model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5") + model_args = ModelArguments(model="llamafactory/tiny-random-qwen3") model_engine = ModelEngine(model_args) - assert isinstance(model_engine.processor, Qwen2TokenizerFast) - assert isinstance(model_engine.model_config, Qwen2Config) - assert isinstance(model_engine.model, Qwen2ForCausalLM) + assert "Qwen2Tokenizer" in model_engine.processor.__class__.__name__ + assert "Qwen3Config" in model_engine.model_config.__class__.__name__ + assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__ assert model_engine.model.dtype == torch.bfloat16 def test_tiny_qwen_with_kernel_plugin(): - from transformers import Qwen2ForCausalLM - from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import npu_rms_norm_forward model_args = ModelArguments( - model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto") + model="llamafactory/tiny-random-qwen3", kernel_config={"name": "auto", "include_kernels": "auto"} ) model_engine = ModelEngine(model_args) # test enable apply kernel plugin @@ -44,7 +40,7 @@ def test_tiny_qwen_with_kernel_plugin(): else: assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__ - assert isinstance(model_engine.model, Qwen2ForCausalLM) + assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__ if __name__ == "__main__": diff --git a/tests_v1/core/utils/test_batching.py b/tests_v1/core/utils/test_batching.py new file mode 100644 index 000000000..9d3461337 --- /dev/null +++ b/tests_v1/core/utils/test_batching.py @@ -0,0 +1,49 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArguments +from llamafactory.v1.core.data_engine import DataEngine +from llamafactory.v1.core.model_engine import ModelEngine +from llamafactory.v1.core.utils.batching import BatchGenerator + + +def test_normal_batching(): + data_args = DataArguments(dataset="llamafactory/v1-sft-demo") + data_engine = DataEngine(data_args=data_args) + model_args = ModelArguments(model="llamafactory/tiny-random-qwen3") + model_engine = ModelEngine(model_args=model_args) + training_args = TrainingArguments( + micro_batch_size=4, + global_batch_size=8, + cutoff_len=10, + batching_workers=0, + batching_strategy="normal", + ) + batch_generator = BatchGenerator( + data_engine, + model_engine.renderer, + micro_batch_size=training_args.micro_batch_size, + global_batch_size=training_args.global_batch_size, + cutoff_len=training_args.cutoff_len, + batching_workers=training_args.batching_workers, + batching_strategy=training_args.batching_strategy, + ) + assert len(batch_generator) == len(data_engine) // training_args.global_batch_size + batch = next(iter(batch_generator)) + assert len(batch) == 2 + assert batch[0]["input_ids"].shape == (4, 10) + + +if __name__ == "__main__": + test_normal_batching() diff --git a/tests_v1/core/utils/test_data_loader.py b/tests_v1/core/utils/test_data_loader.py deleted file mode 100644 index cbddf0887..000000000 --- a/tests_v1/core/utils/test_data_loader.py +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Integration tests for DataLoader with different combinations of packing and dynamic batching. - -Tests the 4 scenarios: -a) non pack + non dynamic. -b) non pack + dynamic. -c) pack + non dynamic. -d) pack + dynamic. 
-""" - -# import torch -# from torch.utils.data import DataLoader as TorchDataLoader -# from torch.utils.data import Dataset -# from transformers import AutoTokenizer - -# from llamafactory.v1.config.data_args import DataArguments -# from llamafactory.v1.core.data_engine import DataEngine -# from llamafactory.v1.core.utils.data_collator import DefaultCollator -# from llamafactory.v1.core.utils.data_loader import DataLoader -# from llamafactory.v1.plugins.data_plugins.rendering import QwenTemplate -# from llamafactory.v1.utils.batching_queue import TextBatchingQueue - - -# class TensorDataset(Dataset): -# """Wrapper dataset that converts DataEngine samples to tensor format.""" - -# def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None): -# self.data_engine = data_engine -# self.processor = processor -# self.template = template -# self.max_samples = max_samples or len(data_engine) -# self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor - -# def __len__(self): -# return min(self.max_samples, len(self.data_engine)) - -# def __getitem__(self, idx): -# # Get sample from DataEngine -# sample = self.data_engine[idx] - -# # Extract messages from sample -# # DataEngine returns samples with format like {"messages": [...], ...} -# # For llamafactory/v1-sft-demo, the format should have "messages" field -# messages = None -# if "messages" in sample: -# messages = sample["messages"] -# elif "conversations" in sample: -# messages = sample["conversations"] -# elif "conversation" in sample: -# messages = sample["conversation"] -# else: -# # Try to find message-like fields (skip _dataset_name) -# for key, value in sample.items(): -# if key.startswith("_"): -# continue -# if isinstance(value, list) and len(value) > 0: -# # Check if it looks like a message list -# if isinstance(value[0], dict) and "role" in value[0]: -# messages = value -# break - -# if messages is None: -# raise ValueError(f"Could not find messages in sample: {list(sample.keys())}") - -# # Encode messages using template -# encoded = self.template.encode_messages(self.tokenizer, messages) - -# # Convert to tensors -# return { -# "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long), -# "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long), -# "labels": torch.tensor(encoded["labels"], dtype=torch.long), -# } - - -# def create_real_dataset(max_samples: int = 20, batch_size: int = 4): -# """Create a real dataset using DataEngine.""" -# data_args = DataArguments(dataset="llamafactory/v1-sft-demo") -# data_engine = DataEngine(data_args) - -# # Create processor and template -# processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5") -# template = QwenTemplate() - -# # Create tensor dataset -# raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples) - -# # Create torch DataLoader -# torch_dataloader = TorchDataLoader( -# raw_data_dataset, -# batch_size=batch_size, -# shuffle=False, -# collate_fn=lambda x: x, -# ) - -# return torch_dataloader, processor, template - - -# class TestDataLoaderNonPackNonDynamic: -# """Test case a) non pack + non dynamic.""" - -# def test_basic_functionality(self): -# """Test DataLoader without packing and without dynamic batching.""" -# # Create real dataset -# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8) - -# # Create collator (non-packing) -# collator = DefaultCollator(processor=processor, template=template) - -# # 
Create DataLoader without batching_queue (non-dynamic) -# data_loader = DataLoader( -# dataloader=torch_dataloader, -# collate_fn=collator, -# num_micro_batch=1, -# batching_queue=None, -# ) - -# # Iterate and check results -# batches = list(iter(data_loader)) -# assert len(batches) > 0 - -# # Check first batch -# one_batch = batches[0] -# micro_batches = one_batch[0] -# assert "input_ids" in micro_batches -# assert "attention_mask" in micro_batches -# assert "labels" in micro_batches -# assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1 -# assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len] - - -# class TestDataLoaderNonPackDynamic: -# """Test case b) non pack + dynamic.""" - -# def test_basic_functionality(self): -# """Test DataLoader without packing but with dynamic batching.""" -# # Create real dataset -# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8) -# collator = DefaultCollator(processor=processor, template=template) - -# # Create batching queue for dynamic batching -# batching_queue = TextBatchingQueue( -# token_micro_bsz=120, -# buffer_size=8, -# ) - -# data_loader = DataLoader( -# dataloader=torch_dataloader, -# collate_fn=collator, -# num_micro_batch=4, -# batching_queue=batching_queue, -# ) - -# # Iterate and check -# batches = list(iter(data_loader)) -# micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]] -# assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first) -# assert len(batches) > 0 diff --git a/tests_v1/core/utils/test_rendering.py b/tests_v1/core/utils/test_rendering.py index fa17660d6..40dd12532 100644 --- a/tests_v1/core/utils/test_rendering.py +++ b/tests_v1/core/utils/test_rendering.py @@ -184,6 +184,40 @@ def test_qwen3_nothink_rendering_remote(num_samples: int): assert v1_inputs["input_ids"][: len(prefix)] == prefix +def test_process_sft_samples(): + tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3") + renderer = Renderer(template="chatml", processor=tokenizer) + hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES) + + samples = [{"messages": V1_MESSAGES, "extra_info": "test", "_dataset_name": "default"}] + model_inputs = renderer.process_samples(samples) + assert len(model_inputs) == 1 + assert model_inputs[0]["input_ids"] == hf_inputs + assert model_inputs[0]["extra_info"] == "test" + assert model_inputs[0]["_dataset_name"] == "default" + + +def test_process_dpo_samples(): + tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3") + renderer = Renderer(template="chatml", processor=tokenizer) + hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES) + + samples = [ + { + "chosen_messages": V1_MESSAGES, + "rejected_messages": V1_MESSAGES, + "extra_info": "test", + "_dataset_name": "default", + } + ] + model_inputs = renderer.process_samples(samples) + assert len(model_inputs) == 1 + assert model_inputs[0]["input_ids"] == hf_inputs * 2 + assert model_inputs[0]["token_type_ids"] == [0] * len(hf_inputs) + [1] * len(hf_inputs) + assert model_inputs[0]["extra_info"] == "test" + assert model_inputs[0]["_dataset_name"] == "default" + + if __name__ == "__main__": test_chatml_rendering() test_chatml_parse() @@ -191,3 +225,5 @@ if __name__ == "__main__": test_qwen3_nothink_rendering() test_qwen3_nothink_parse() test_qwen3_nothink_rendering_remote(16) + test_process_sft_samples() + test_process_dpo_samples() diff --git a/tests_v1/plugins/model_plugins/test_init_plugin.py 
b/tests_v1/plugins/model_plugins/test_init_plugin.py index d13294c8e..1b1ec104c 100644 --- a/tests_v1/plugins/model_plugins/test_init_plugin.py +++ b/tests_v1/plugins/model_plugins/test_init_plugin.py @@ -21,7 +21,7 @@ from llamafactory.v1.core.model_engine import ModelEngine def test_init_on_meta(): _, model_args, *_ = get_args( dict( - model="llamafactory/tiny-random-qwen2.5", + model="llamafactory/tiny-random-qwen3", init_config={"name": "init_on_meta"}, ) ) @@ -32,7 +32,7 @@ def test_init_on_meta(): def test_init_on_rank0(): _, model_args, *_ = get_args( dict( - model="llamafactory/tiny-random-qwen2.5", + model="llamafactory/tiny-random-qwen3", init_config={"name": "init_on_rank0"}, ) ) @@ -46,7 +46,7 @@ def test_init_on_rank0(): def test_init_on_default(): _, model_args, *_ = get_args( dict( - model="llamafactory/tiny-random-qwen2.5", + model="llamafactory/tiny-random-qwen3", init_config={"name": "init_on_default"}, ) ) diff --git a/tests_v1/plugins/model_plugins/test_kernel_plugin.py b/tests_v1/plugins/model_plugins/test_kernel_plugin.py index f087a822f..90524b7ba 100644 --- a/tests_v1/plugins/model_plugins/test_kernel_plugin.py +++ b/tests_v1/plugins/model_plugins/test_kernel_plugin.py @@ -43,7 +43,7 @@ def test_apply_kernel(mock_get_accelerator: MagicMock): reload_kernels() from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels - model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5") + model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3") original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward original_swiglu_forward = model.model.layers[0].mlp.forward model = apply_default_kernels(model=model, include_kernels="npu_fused_rmsnorm") @@ -62,7 +62,7 @@ def test_apply_all_kernels(mock_get_accelerator: MagicMock): reload_kernels() from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_default_kernels - model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5") + model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen3") original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward original_swiglu_forward = model.model.layers[0].mlp.forward diff --git a/tests_v1/utils/test_batching_queue.py b/tests_v1/utils/test_batching_queue.py deleted file mode 100644 index 0f9a68224..000000000 --- a/tests_v1/utils/test_batching_queue.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch - -from llamafactory.v1.utils.batching_queue import DynamicBatchSizeBuffer, TextBatchingQueue - - -def create_sample(length: int): - """Helper to create a mock sample with a specific token length.""" - return {"input_ids": torch.ones(length), "attention_mask": torch.ones(length)} - - -class TestDynamicBatchSizeBuffer: - def test_append_and_token_count(self): - buffer = DynamicBatchSizeBuffer() - buffer.append(create_sample(10)) - buffer.append(create_sample(20)) - - assert len(buffer) == 2 - assert buffer.total_token_count == 30 - - def test_get_samples_within_budget(self): - buffer = DynamicBatchSizeBuffer() - buffer.append(create_sample(10)) - buffer.append(create_sample(10)) - buffer.append(create_sample(50)) # This one is large - - # Request 25 tokens. Should get the first two (20 tokens total) - samples = buffer.get_samples(max_tokens_per_iteration=25) - assert len(samples) == 2 - - def test_force_return_first_sample(self): - buffer = DynamicBatchSizeBuffer() - buffer.append(create_sample(100)) - - # Even though budget is 50, force=True (default) should return the 100-token sample - samples = buffer.get_samples(max_tokens_per_iteration=50, force=True) - assert len(samples) == 1 - assert len(samples[0]["input_ids"]) == 100 - - def test_flush_removes_used_samples(self): - buffer = DynamicBatchSizeBuffer() - buffer.append(create_sample(10)) - buffer.append(create_sample(20)) - - # Take the first sample - buffer.get_samples(max_tokens_per_iteration=15) - buffer.flush() - - assert len(buffer) == 1 - assert buffer.total_token_count == 20 - # The remaining sample should now be at the start - remaining = buffer.get_samples(max_tokens_per_iteration=50) - assert len(remaining[0]["input_ids"]) == 20 - - -class TestTextBatchingQueue: - def test_is_full_filled(self): - queue = TextBatchingQueue(token_micro_bsz=100, buffer_size=2) - - queue.put_item(create_sample(10)) - assert not queue.is_full_filled() # Only 1 sample, buffer_size=2 - - queue.put_item(create_sample(10)) - assert not queue.is_full_filled() # 2 samples, but only 20 tokens (min 100) - - queue.put_item(create_sample(90)) - assert queue.is_full_filled() # Meets both conditions - - def test_warmup_logic(self): - # token_micro_bsz=1000, starts at 200, reaches 1000 at step 10 - queue = TextBatchingQueue(token_micro_bsz=1000, bsz_warmup_steps=10, bsz_warmup_init_mbtoken=200) - - # Step 0: should be init value - assert queue.get_cur_token_micro_bsz() == 200 - - # Step 5: halfway through warmup (200 + (800 * 5/10)) = 600 - queue._step = 5 - assert queue.get_cur_token_micro_bsz() == 600 - - # Step 11: past warmup - queue._step = 11 - assert queue.get_cur_token_micro_bsz() == 1000 - - def test_get_micro_batch_integration(self): - queue = TextBatchingQueue(token_micro_bsz=50, buffer_size=1) - queue.put_item(create_sample(20)) - queue.put_item(create_sample(20)) - queue.put_item(create_sample(20)) - - # At step 0 (warmup not triggered as bsz_warmup_steps is -1 default), - # it should take samples up to 50 tokens. - batch = queue.get_micro_batch(step=0) - - assert len(batch) == 2 - assert queue.empty() is False - - batch_2 = queue.get_micro_batch(step=1) - assert len(batch_2) == 1 - assert queue.empty() is True
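Note on the new plugin hook: this patch adds BatchingPlugin as a bare stub, while BatchGenerator._build_batch already invokes it for every non-NORMAL strategy with the arguments (buffer, buffer_tokens, micro_batch_size, num_micro_batch, cutoff_len) and expects (buffer, buffer_tokens, batch) in return. The sketch below only illustrates that implied contract; it mirrors default_collate_fn with a length sort added, and the function name, the sorting heuristic, and how such a callable would eventually be wired into BatchingPlugin are assumptions, since registration is not shown here and the non-NORMAL path still raises NotImplementedError.

# Hypothetical sketch of a DYNAMIC_BATCHING-style collate callable matching the
# call site in BatchGenerator._build_batch. Only the signature is taken from the
# patch; the name and the length-sorting heuristic are illustrative.
from torch.utils.data import default_collate

from llamafactory.v1.utils.helper import pad_and_truncate
from llamafactory.v1.utils.types import BatchInput, ModelInput


def length_sorted_collate_fn(
    buffer: list[ModelInput],
    buffer_tokens: int,
    micro_batch_size: int,
    num_micro_batch: int,
    cutoff_len: int,
) -> tuple[list[ModelInput], int, list[BatchInput] | None]:
    batch_size = micro_batch_size * num_micro_batch
    if len(buffer) < batch_size:
        # Keep buffering until a full global batch worth of samples is available.
        return buffer, buffer_tokens, None

    # Group samples of similar length so each micro batch wastes fewer pad tokens.
    buffer.sort(key=lambda sample: len(sample["input_ids"]))
    samples, buffer = buffer[:batch_size], buffer[batch_size:]
    buffer_tokens -= sum(len(sample["input_ids"]) for sample in samples)

    batch = []
    for i in range(num_micro_batch):
        micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
        batch.append(default_collate(pad_and_truncate(micro_batch, cutoff_len)))

    return buffer, buffer_tokens, batch

Whether a non-NORMAL plugin must also provide compute_length(), as the call in _init_data_provider suggests, is left open by this patch.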