diff --git a/examples/v1/train_batching_strategy/train_full_fsdp2_batching_normal.yaml b/examples/v1/train_batching_strategy/train_full_fsdp2_batching_normal.yaml new file mode 100644 index 000000000..e14d52fd7 --- /dev/null +++ b/examples/v1/train_batching_strategy/train_full_fsdp2_batching_normal.yaml @@ -0,0 +1,31 @@ +model: Qwen/Qwen3-0.6B +model_class: llm + +template: qwen3_nothink + + +kernel_config: + name: auto + include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null + +quant_config: null + +dist_config: + name: fsdp2 + dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp + +### data +train_dataset: data/v1_sft_demo.yaml + +### training +output_dir: outputs/test_fsdp2 +micro_batch_size: 2 +batching_strategy: normal + +cutoff_len: 2048 +learning_rate: 1.0e-4 +max_steps: 10 + +### sample +sample_backend: hf +max_new_tokens: 128 diff --git a/examples/v1/train_batching_strategy/train_full_fsdp2_dynamic_batching.yaml b/examples/v1/train_batching_strategy/train_full_fsdp2_dynamic_batching.yaml new file mode 100644 index 000000000..ff0dc2cc2 --- /dev/null +++ b/examples/v1/train_batching_strategy/train_full_fsdp2_dynamic_batching.yaml @@ -0,0 +1,30 @@ +model: Qwen/Qwen3-0.6B +model_class: llm + +template: qwen3_nothink + +kernel_config: + name: auto + include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null + +quant_config: null + +dist_config: + name: fsdp2 + dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp + +### data +train_dataset: data/v1_sft_demo.yaml + + +### training +output_dir: outputs/test_fsdp2 +micro_batch_size: 2 +batching_strategy: dynamic_batching +cutoff_len: 2048 +learning_rate: 1.0e-4 +max_steps: 10 + +### sample +sample_backend: hf +max_new_tokens: 128 diff --git a/examples/v1/train_batching_strategy/train_full_fsdp2_padding_free.yaml b/examples/v1/train_batching_strategy/train_full_fsdp2_padding_free.yaml new file mode 100644 index 
000000000..b841cca80 --- /dev/null +++ b/examples/v1/train_batching_strategy/train_full_fsdp2_padding_free.yaml @@ -0,0 +1,30 @@ +model: Qwen/Qwen3-0.6B +model_class: llm + +template: qwen3_nothink + +kernel_config: + name: auto + include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null + +quant_config: null + +dist_config: + name: fsdp2 + dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp + +### data +train_dataset: data/v1_sft_demo.yaml + +### training +output_dir: outputs/test_fsdp2 +micro_batch_size: 4 +batching_strategy: padding_free +flash_attn: flash_attention_2 # choice: eager/sdpa/flash_attention_2; padding_free requires flash_attention_2 +cutoff_len: 2048 +learning_rate: 1.0e-4 +max_steps: 10 + +### sample +sample_backend: hf +max_new_tokens: 128 diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index 3dc7fd730..b3826a7a9 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -20,7 +20,6 @@ import sys import time from collections import defaultdict from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass, field from datetime import timedelta from typing import TYPE_CHECKING, Any, Optional @@ -584,7 +583,7 @@ class ModuleProfilerCallback(TrainerCallback): if matched: logger.info_rank0( f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}" - + (f" ... (+{len(matched)-5} more)" if len(matched) > 5 else "") + + (f" ... 
(+{len(matched) - 5} more)" if len(matched) > 5 else "") ) else: logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}") @@ -616,7 +615,7 @@ class ModuleProfilerCallback(TrainerCallback): bwd = self._backward_times.get(name, []) fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0 bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0 - lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean+bwd_mean:.3f}") + lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean + bwd_mean:.3f}") logger.info_rank0("\n".join(lines)) self._forward_times.clear() diff --git a/src/llamafactory/v1/config/__init__.py b/src/llamafactory/v1/config/__init__.py index f334d0524..4bb5fc32c 100644 --- a/src/llamafactory/v1/config/__init__.py +++ b/src/llamafactory/v1/config/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from ..utils.types import AttentionFunction from .arg_parser import InputArgument, get_args from .arg_utils import BatchingStrategy, ModelClass, SampleBackend from .data_args import DataArguments @@ -21,6 +22,7 @@ from .training_args import TrainingArguments __all__ = [ + "AttentionFunction", "BatchingStrategy", "DataArguments", "InputArgument", diff --git a/src/llamafactory/v1/config/model_args.py b/src/llamafactory/v1/config/model_args.py index 4ab1e561f..7cdcd2d4c 100644 --- a/src/llamafactory/v1/config/model_args.py +++ b/src/llamafactory/v1/config/model_args.py @@ -15,6 +15,7 @@ from dataclasses import dataclass, field +from ..utils.types import AttentionFunction from .arg_utils import ModelClass, PluginConfig, get_plugin_config @@ -32,6 +33,12 @@ class ModelArguments: default=False, metadata={"help": "Trust remote code from Hugging Face."}, ) + flash_attn: AttentionFunction = field( + default=AttentionFunction.SDPA, + metadata={ + "help": "Attention implementation to use: eager, sdpa, or flash_attention_2. 
SDPA is the default implementation for models." + }, + ) model_class: ModelClass = field( default=ModelClass.LLM, metadata={"help": "Model class from Hugging Face."}, @@ -54,6 +61,12 @@ class ModelArguments: ) def __post_init__(self) -> None: + supported_flash_attn = [item.value for item in AttentionFunction] + if self.flash_attn not in supported_flash_attn: + raise ValueError( + f"Unsupported `flash_attn`: {self.flash_attn}. Supported values are: {supported_flash_attn}." + ) + self.init_config = get_plugin_config(self.init_config) self.peft_config = get_plugin_config(self.peft_config) self.kernel_config = get_plugin_config(self.kernel_config) diff --git a/src/llamafactory/v1/config/training_args.py b/src/llamafactory/v1/config/training_args.py index 8ede106f8..0b5fc1ff2 100644 --- a/src/llamafactory/v1/config/training_args.py +++ b/src/llamafactory/v1/config/training_args.py @@ -116,3 +116,9 @@ class TrainingArguments: self.dist_config = get_plugin_config(self.dist_config) self.optim_config = get_plugin_config(self.optim_config) self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config) + + if str(self.batching_strategy) == str(BatchingStrategy.DYNAMIC_BATCHING): + if self.max_steps is None or self.max_steps <= 0: + raise ValueError("`dynamic_batching` requires `max_steps` because it is step-driven.") + if self.save_epochs is not None: + raise ValueError("`save_epochs` is not supported with `dynamic_batching`; use `save_steps` instead.") diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py index ff1a60539..c2eb2bebd 100644 --- a/src/llamafactory/v1/core/base_trainer.py +++ b/src/llamafactory/v1/core/base_trainer.py @@ -34,7 +34,7 @@ import torch.nn.functional as F from ..accelerator.helper import ReduceOp from ..accelerator.interface import Dim, DistributedInterface -from ..config import TrainingArguments +from ..config import BatchingStrategy, TrainingArguments from ..utils import logging from 
..utils.callbacks import ( CallbackHandler, @@ -147,13 +147,19 @@ class BaseTrainer: from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin if model.config._attn_implementation != "flash_attention_2": - logger.warning_rank0( - "Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2." + raise ValueError( + "Sequence parallelism requires flash attention. Please set `flash_attn: flash_attention_2`." ) - model.config._attn_implementation = "flash_attention_2" + SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config) def _create_batch_generator(self) -> None: + if ( + self.args.batching_strategy == BatchingStrategy.PADDING_FREE + and getattr(self.model.config, "_attn_implementation", None) != "flash_attention_2" + ): + raise ValueError("`padding_free` requires `flash_attn: flash_attention_2`.") + self.train_batch_generator = BatchGenerator( dataset=self.train_dataset, renderer=self.renderer, @@ -237,6 +243,7 @@ class BaseTrainer: self.train_batch_generator.set_epoch(epoch) self.callback_handler.on_epoch_begin(self.args, self.state) + # BatchGenerator is an iterator; each loop step calls its __next__ to produce one optimizer step. 
for micro_batches in self.train_batch_generator: self.global_step += 1 diff --git a/src/llamafactory/v1/core/model_engine.py b/src/llamafactory/v1/core/model_engine.py index 24e97b0d7..2f6f30d2c 100644 --- a/src/llamafactory/v1/core/model_engine.py +++ b/src/llamafactory/v1/core/model_engine.py @@ -120,6 +120,7 @@ class ModelEngine: init_device = DistributedInterface().current_device init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device} + logger.info_rank0(f"Using attention implementation: {self.args.flash_attn}.") if self.args.quant_config is not None: from ..plugins.model_plugins.quantization import QuantizationPlugin @@ -164,6 +165,7 @@ class ModelEngine: self.args.model, config=self.model_config, dtype="auto", + attn_implementation=self.args.flash_attn, trust_remote_code=self.args.trust_remote_code, **init_kwargs, ) diff --git a/src/llamafactory/v1/core/utils/batching.py b/src/llamafactory/v1/core/utils/batching.py index 3b1ea5c8f..7ddd70c55 100644 --- a/src/llamafactory/v1/core/utils/batching.py +++ b/src/llamafactory/v1/core/utils/batching.py @@ -42,6 +42,8 @@ from .rendering import Renderer logger = logging.get_logger(__name__) +__all__ = ["BatchGenerator"] + def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None: micro_batch_size = batch_info["micro_batch_size"] @@ -102,19 +104,18 @@ class BatchGenerator(Iterator): if not self.drop_last: raise ValueError("Drop last must be True.") + self._batch_info: BatchInfo = { + "micro_batch_size": self.micro_batch_size, + "num_micro_batch": self.num_micro_batch, + "cutoff_len": self.cutoff_len, + } + self._init_data_provider() self._is_resuming: bool = False self._data_iter = iter(self._data_provider) self._buffer = StatefulBuffer() - self._batch_info: BatchInfo = { - "micro_batch_size": self.micro_batch_size, - "num_micro_batch": self.num_micro_batch, - "cutoff_len": self.cutoff_len, - "data_iter": self._data_iter, - } - logger.info_rank0( f"Init 
unified data loader with global batch size {self.global_batch_size}, " f"micro batch size {self.micro_batch_size}, " @@ -137,12 +138,19 @@ class BatchGenerator(Iterator): else: raise NotImplementedError("Iterable dataset is not supported yet.") + if self.batching_strategy == BatchingStrategy.NORMAL: + batch_size = self.micro_batch_size * self.num_micro_batch + else: + from ...plugins.trainer_plugins.batching import BatchingPlugin + + batch_size = BatchingPlugin(self.batching_strategy).get_data_provider_batch_size(self._batch_info) + generator_seed = torch.Generator() generator_seed.manual_seed(self.seed) self._data_provider = StatefulDataLoader( self.dataset, - batch_size=self.micro_batch_size * self.num_micro_batch, + batch_size=batch_size, sampler=sampler, num_workers=self.batching_workers, collate_fn=self.renderer.process_samples, @@ -156,8 +164,7 @@ class BatchGenerator(Iterator): else: from ...plugins.trainer_plugins.batching import BatchingPlugin - self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider) - raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.") + self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider, self._batch_info) def __len__(self) -> int: return self._length @@ -190,7 +197,7 @@ class BatchGenerator(Iterator): else: from ...plugins.trainer_plugins.batching import BatchingPlugin - BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info) + BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info, self._next_samples) def _generate_batch(self) -> list[BatchInput] | None: if self.batching_strategy == BatchingStrategy.NORMAL: @@ -200,6 +207,20 @@ class BatchGenerator(Iterator): return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info) + def _next_samples(self, restart: bool) -> list[ModelInput] | None: + try: + return next(self._data_iter) + except StopIteration: 
+ if not restart: + return None + + # Dynamic batching may restart the provider to fill one token-budgeted batch. + self._data_iter = iter(self._data_provider) + try: + return next(self._data_iter) + except StopIteration: + return None + def state_dict(self) -> dict[str, Any]: return { "buffer": self._buffer.state_dict(), diff --git a/src/llamafactory/v1/plugins/trainer_plugins/batching.py b/src/llamafactory/v1/plugins/trainer_plugins/batching.py index aef22eac2..7c945c020 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/batching.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/batching.py @@ -12,23 +12,197 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable +from math import ceil +from typing import Any + +import torch +from torch.utils.data import default_collate + +from ...utils.constants import IGNORE_INDEX +from ...utils.helper import pad_and_truncate from ...utils.objects import StatefulBuffer from ...utils.plugin import BasePlugin -from ...utils.types import BatchInfo, BatchInput, DataLoader +from ...utils.types import BatchInfo, BatchInput, DataLoader, ModelInput class BatchingPlugin(BasePlugin): - def compute_length(self, data_provider: DataLoader) -> int: + def get_data_provider_batch_size(self, batch_info: BatchInfo) -> int: + """Return the raw data provider batch size for this batching strategy.""" + return self["get_data_provider_batch_size"](batch_info) + + def compute_length(self, data_provider: DataLoader, batch_info: BatchInfo) -> int: """Compute the length of the batch generator. The approximate length is used to calculate the lr schedule. 
""" - raise NotImplementedError() + return self["compute_length"](data_provider, batch_info) - def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None: + def fill_buffer( + self, + buffer: StatefulBuffer, + batch_info: BatchInfo, + next_samples: Callable[[bool], list[ModelInput] | None], + ) -> None: """Fill the buffer with data.""" - raise NotImplementedError() + return self["fill_buffer"](buffer, batch_info, next_samples) def generate_batch(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None: """Generate a batch from the buffer.""" - raise NotImplementedError() + return self["generate_batch"](buffer, batch_info) + + +def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]: + """Return sample counts for micro batches formed by one padded-token budget.""" + budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"] + cutoff_len = batch_info["cutoff_len"] + sizes = [] + index = 0 + while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]: + max_sample_len = 0 + used = 0 + is_complete = False + while index + used < len(samples): + sample_len = min(len(samples[index + used]["input_ids"]), cutoff_len) + padded_tokens = max(max_sample_len, sample_len) * (used + 1) + if used > 0 and padded_tokens > budget: + is_complete = True + break + + max_sample_len = max(max_sample_len, sample_len) + used += 1 + if max_sample_len * used >= budget: + is_complete = True + break + + if used == 0 or not is_complete: + break + + sizes.append(used) + index += used + + return sizes + + +def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None: + """Pack fixed samples into one padding-free sequence without a token budget.""" + packed: dict[str, list[Any]] = {} + position_ids: list[int] = [] + + for sample_index, sample in enumerate(samples): + # Padding-free still truncates each sample by cutoff_len before packing + # all samples into one 
contiguous sequence. + sample_len = min(len(sample["input_ids"]), cutoff_len) + if sample_len <= 0: + continue + + for key, value in sample.items(): + if key in ("attention_mask", "position_ids") or isinstance(value, str): + continue + + if key not in packed: + packed[key] = [] + + sliced_value = list(value[:sample_len]) + if sample_index > 0 and sliced_value: + if key == "labels": + sliced_value[0] = IGNORE_INDEX + elif key == "loss_weights": + sliced_value[0] = 0.0 + + packed[key].extend(sliced_value) + + position_ids.extend(range(sample_len)) + + if not position_ids: + return None + + packed["position_ids"] = position_ids + packed["attention_mask"] = [1] * len(position_ids) + return {key: torch.tensor(value).unsqueeze(0) for key, value in packed.items()} + + +@BatchingPlugin("padding_free").register("get_data_provider_batch_size") +def get_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int: + return batch_info["micro_batch_size"] * batch_info["num_micro_batch"] + + +@BatchingPlugin("padding_free").register("compute_length") +def compute_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int: + return len(data_provider) + + +@BatchingPlugin("padding_free").register("fill_buffer") +def fill_padding_free_buffer( + buffer: StatefulBuffer, + batch_info: BatchInfo, + next_samples: Callable[[bool], list[ModelInput] | None], +) -> None: + while len(buffer) < batch_info["micro_batch_size"] * batch_info["num_micro_batch"]: + samples = next_samples(False) + if samples is None: + break + + buffer.put(samples) + + +@BatchingPlugin("padding_free").register("generate_batch") +def generate_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None: + micro_batch_size = batch_info["micro_batch_size"] + num_micro_batch = batch_info["num_micro_batch"] + cutoff_len = batch_info["cutoff_len"] + batch_size = micro_batch_size * num_micro_batch + if len(buffer) < batch_size: + return None + + samples = 
buffer.get(batch_size) + batch = [] + for i in range(num_micro_batch): + micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size] + packed_micro_batch = _pack_padding_free_samples(micro_batch, cutoff_len) + if packed_micro_batch is None: + return None + + batch.append(packed_micro_batch) + + return batch + + +@BatchingPlugin("dynamic_batching").register("get_data_provider_batch_size") +def get_dynamic_batching_data_provider_batch_size(batch_info: BatchInfo) -> int: + return 1 + + +@BatchingPlugin("dynamic_batching").register("compute_length") +def compute_dynamic_batching_length(data_provider: DataLoader, batch_info: BatchInfo) -> int: + batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"] + return ceil(len(data_provider) / batch_size) + + +@BatchingPlugin("dynamic_batching").register("fill_buffer") +def fill_dynamic_batching_buffer( + buffer: StatefulBuffer, + batch_info: BatchInfo, + next_samples: Callable[[bool], list[ModelInput] | None], +) -> None: + while len(_get_dynamic_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]: + samples = next_samples(True) + if samples is None: + break + + buffer.put(samples) + + +@BatchingPlugin("dynamic_batching").register("generate_batch") +def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None: + micro_batch_sample_counts = _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) + if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]: + return None + + batch = [] + cutoff_len = batch_info["cutoff_len"] + for num_samples in micro_batch_sample_counts: + samples = buffer.get(num_samples) + batch.append(default_collate(pad_and_truncate(samples, cutoff_len))) + + return batch diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py index 2a4a9e392..b4d0babf0 100644 --- 
a/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/hub.py @@ -61,6 +61,9 @@ def load_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt @DistributedPlugin("deepspeed").register() def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel: + if dist_config.get("cp_size", 1) > 1: + raise ValueError("CP currently requires `dist_config.name: fsdp2`.") + from .deepspeed import DeepSpeedEngine return DeepSpeedEngine( diff --git a/src/llamafactory/v1/utils/objects.py b/src/llamafactory/v1/utils/objects.py index 338f52365..fc7ea1da5 100644 --- a/src/llamafactory/v1/utils/objects.py +++ b/src/llamafactory/v1/utils/objects.py @@ -33,6 +33,10 @@ class StatefulBuffer: def size(self) -> int: return self._buffer_size + @property + def samples(self) -> list[ModelInput]: + return self._buffer + def put(self, samples: list[ModelInput]) -> None: """Add samples to the buffer.""" num_tokens = sum(len(sample["input_ids"]) for sample in samples) diff --git a/src/llamafactory/v1/utils/types.py b/src/llamafactory/v1/utils/types.py index b1f7d52cf..0cb57938b 100644 --- a/src/llamafactory/v1/utils/types.py +++ b/src/llamafactory/v1/utils/types.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from collections.abc import Iterator +from enum import StrEnum, unique from typing import TYPE_CHECKING, Any, Literal, NamedTuple, NotRequired, TypedDict, Union @@ -54,6 +54,13 @@ else: ProcessGroup = None +@unique +class AttentionFunction(StrEnum): + EAGER = "eager" + SDPA = "sdpa" + FLASH_ATTENTION_2 = "flash_attention_2" + + class DatasetInfo(TypedDict, total=False): path: str """Local file path.""" @@ -171,8 +178,6 @@ class BatchInfo(TypedDict): """Number of micro batches.""" cutoff_len: int """Cutoff length.""" - data_iter: Iterator[list[ModelInput]] - """Data iterator.""" class ModelOutput(NamedTuple): diff --git a/tests_v1/accelerator/test_interface.py b/tests_v1/accelerator/test_interface.py index d3838f8b1..38ec83b39 100644 --- a/tests_v1/accelerator/test_interface.py +++ b/tests_v1/accelerator/test_interface.py @@ -58,10 +58,3 @@ def test_multi_device(): master_port = find_available_port() world_size = 2 mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size) - - -if __name__ == "__main__": - """ - python tests_v1/accelerator/test_interface.py - """ - test_all_device() diff --git a/tests_v1/config/test_args_parser.py b/tests_v1/config/test_args_parser.py index db235ab54..d60bbeb17 100644 --- a/tests_v1/config/test_args_parser.py +++ b/tests_v1/config/test_args_parser.py @@ -70,13 +70,3 @@ def test_get_args_from_yaml(tmp_path: Path): assert training_args.bf16 is False assert training_args.dist_config is None assert sample_args.sample_backend == "hf" - - -if __name__ == "__main__": - """ - python -m tests_v1.config.test_args_parser - """ - import tempfile - - with tempfile.TemporaryDirectory() as tmp_dir: - test_get_args_from_yaml(tmp_path=Path(tmp_dir)) diff --git a/tests_v1/core/test_data_engine.py b/tests_v1/core/test_data_engine.py index 373069a66..aef337d43 100644 --- a/tests_v1/core/test_data_engine.py +++ b/tests_v1/core/test_data_engine.py @@ -30,10 +30,3 @@ def test_map_dataset(num_samples: int): for index in indexes: 
print(data_engine[index]) assert data_engine[index] == {"_dataset_name": "default", **original_data[index]} - - -if __name__ == "__main__": - """ - python -m tests_v1.core.test_data_engine - """ - test_map_dataset(1) diff --git a/tests_v1/core/test_model_loader.py b/tests_v1/core/test_model_loader.py index 6228a3699..96dbdd2c1 100644 --- a/tests_v1/core/test_model_loader.py +++ b/tests_v1/core/test_model_loader.py @@ -41,11 +41,3 @@ def test_tiny_qwen_with_kernel_plugin(): assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__ assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__ - - -if __name__ == "__main__": - """ - python -m tests_v1.core.test_model_loader - """ - test_tiny_qwen() - test_tiny_qwen_with_kernel_plugin() diff --git a/tests_v1/core/utils/test_batching.py b/tests_v1/core/utils/test_batching.py index 87e8a89cb..f379d0763 100644 --- a/tests_v1/core/utils/test_batching.py +++ b/tests_v1/core/utils/test_batching.py @@ -16,6 +16,159 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume from llamafactory.v1.core.data_engine import DataEngine from llamafactory.v1.core.model_engine import ModelEngine from llamafactory.v1.core.utils.batching import BatchGenerator +from llamafactory.v1.plugins.trainer_plugins.batching import BatchingPlugin, _get_dynamic_micro_batch_sizes +from llamafactory.v1.utils.constants import IGNORE_INDEX +from llamafactory.v1.utils.objects import StatefulBuffer + + +def _make_model_input(length: int, start: int = 0): + input_ids = list(range(start, start + length)) + return { + "input_ids": input_ids, + "attention_mask": [1] * length, + "labels": input_ids.copy(), + "loss_weights": [1.0] * length, + } + + +class _RestartableDataProvider: + def __init__(self, batches): + self.batches = batches + self.num_iters = 0 + + def __iter__(self): + self.num_iters += 1 + return iter(self.batches) + + +def test_padding_free(): + buffer = 
StatefulBuffer() + # Input samples: + # sample 0 input_ids: [0, 1] + # sample 1 input_ids: [10, 11, 12, 13] + buffer.put([_make_model_input(2, 0), _make_model_input(4, 10)]) + batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 3} + + batch = BatchingPlugin("padding_free").generate_batch(buffer, batch_info) + + # Output batch: + # sample 1 is truncated to [10, 11, 12] + # both samples are packed into one sequence: [[0, 1, 10, 11, 12]] + assert batch is not None + assert len(batch) == 1 + assert batch[0]["input_ids"].shape == (1, 5) + assert batch[0]["input_ids"].tolist() == [[0, 1, 10, 11, 12]] + assert batch[0]["attention_mask"].tolist() == [[1, 1, 1, 1, 1]] + assert batch[0]["position_ids"].tolist() == [[0, 1, 0, 1, 2]] + assert batch[0]["labels"].tolist() == [[0, 1, IGNORE_INDEX, 11, 12]] + assert batch[0]["loss_weights"].tolist() == [[1.0, 1.0, 0.0, 1.0, 1.0]] + assert len(buffer) == 0 + + +def test_batching_plugin_data_provider_batch_sizes(): + batch_info = { + "micro_batch_size": 2, + "num_micro_batch": 3, + "cutoff_len": 10, + } + + assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6 + assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1 + + +def test_dynamic_batching(): + # Input samples: + # sample lengths: [3, 4, 6, 2, 8, 9] + # input_ids: + # [0, 1, 2] + # [10, 11, 12, 13] + # [20, 21, 22, 23, 24, 25] + # [30, 31] + # [40, 41, 42, 43, 44, 45, 46, 47] + # [50, 51, 52, 53, 54, 55, 56, 57, 58] + samples = [ + _make_model_input(3, 0), + _make_model_input(4, 10), + _make_model_input(6, 20), + _make_model_input(2, 30), + _make_model_input(8, 40), + _make_model_input(9, 50), + ] + batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10} + + # Dynamic batching output plan: + # dynamic batching reads one sample at a time and uses cutoff_len * micro_batch_size + # as the padded-token budget for one training micro batch. 
+ # [3, 4, 6] fits within budget 20 as shape [3, 6]; adding [2] would exceed it. + assert _get_dynamic_micro_batch_sizes(samples, batch_info) == [3] + + buffer = StatefulBuffer() + buffer.put(samples) + batch = BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info) + + assert batch is not None + assert len(batch) == 1 + assert batch[0]["input_ids"].shape == (3, 6) + assert batch[0]["input_ids"].tolist()[0] == [0, 1, 2, 0, 0, 0] + assert len(buffer) == 3 + + +def test_dynamic_batching_returns_none_when_token_budget_is_incomplete(): + buffer = StatefulBuffer() + # Input buffer: + # only one sample with length [6]. + # cutoff_len * micro_batch_size gives a padded-token budget of 20. + # this buffer has not filled the budget and has no next sample to prove overflow, + # so dynamic batching cannot produce a batch yet. + buffer.put([_make_model_input(6, 0)]) + batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10} + + assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == [] + assert BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info) is None + # Batch generation does not read from the data iterator. It only returns None and keeps + # existing samples in the buffer; BatchGenerator._fill_buffer handles refilling. + assert len(buffer) == 1 + + +def test_dynamic_batching_fill_buffer_restarts_until_micro_batch_is_complete(): + # Input data provider: + # each iterator pass yields one sample with length [6]. + # each yielded item is a list[ModelInput], matching BatchGenerator._next_samples. + # _fill_buffer keeps restarting the iterator until the next appended sample + # proves that the previous dynamic micro batch has reached its budget boundary. 
+ samples = [_make_model_input(6, 0)] + data_provider = _RestartableDataProvider([[sample] for sample in samples]) + + batch_generator = BatchGenerator.__new__(BatchGenerator) + batch_generator.batching_strategy = "dynamic_batching" + batch_generator.micro_batch_size = 2 + batch_generator.num_micro_batch = 1 + batch_generator._buffer = StatefulBuffer() + batch_generator._data_provider = data_provider + batch_generator._data_iter = iter(data_provider) + batch_generator._batch_info = { + "micro_batch_size": 2, + "num_micro_batch": 1, + "cutoff_len": 10, + } + + batch_generator._fill_buffer() + + # Filled buffer after restart: + # existing buffer [6, 6, 6] is kept; the fourth [6] remains for the next batch + # because adding it to the first dynamic micro batch would exceed the budget. + assert data_provider.num_iters == 4 + assert _get_dynamic_micro_batch_sizes(batch_generator._buffer.samples, batch_generator._batch_info) == [3] + + batch = batch_generator._generate_batch() + + # Output batch: + # dynamic batching returns [micro_batch_0] + # micro_batch_0 consumes [6, 6, 6] => 3 samples, padded to shape [3, 6]. 
+ assert batch is not None + assert len(batch) == 1 + assert batch[0]["input_ids"].shape == (3, 6) + assert len(batch_generator._buffer) == 1 def test_normal_batching(): @@ -43,10 +196,3 @@ def test_normal_batching(): batch = next(iter(batch_generator)) assert len(batch) == 2 assert batch[0]["input_ids"].shape == (4, 10) - - -if __name__ == "__main__": - """ - python -m tests_v1.core.utils.test_batching - """ - test_normal_batching() diff --git a/tests_v1/core/utils/test_rendering.py b/tests_v1/core/utils/test_rendering.py index 7e4797805..38ba2cc8f 100644 --- a/tests_v1/core/utils/test_rendering.py +++ b/tests_v1/core/utils/test_rendering.py @@ -227,17 +227,3 @@ def test_process_dpo_samples(): assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs) assert model_inputs[0]["extra_info"] == "test" assert model_inputs[0]["_dataset_name"] == "default" - - -if __name__ == "__main__": - """ - python -m tests_v1.core.utils.test_rendering - """ - test_chatml_rendering() - test_chatml_parse() - test_chatml_rendering_remote(16) - test_qwen3_nothink_rendering() - test_qwen3_nothink_parse() - test_qwen3_nothink_rendering_remote(16) - test_process_sft_samples() - test_process_dpo_samples() diff --git a/tests_v1/plugins/data_plugins/test_converter.py b/tests_v1/plugins/data_plugins/test_converter.py index 1722b4a67..b74c704df 100644 --- a/tests_v1/plugins/data_plugins/test_converter.py +++ b/tests_v1/plugins/data_plugins/test_converter.py @@ -117,12 +117,3 @@ def test_pair_converter(num_samples: int): ], } assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data} - - -if __name__ == "__main__": - """ - python -m tests_v1.plugins.data_plugins.test_converter - """ - test_alpaca_converter(1) - test_sharegpt_converter() - test_pair_converter(1) diff --git a/tests_v1/plugins/model_plugins/test_init_plugin.py b/tests_v1/plugins/model_plugins/test_init_plugin.py index 947f18bd9..ddfb03303 100644 --- 
a/tests_v1/plugins/model_plugins/test_init_plugin.py +++ b/tests_v1/plugins/model_plugins/test_init_plugin.py @@ -52,12 +52,3 @@ def test_init_on_default(): ) model_engine = ModelEngine(model_args=model_args) assert model_engine.model.device == DistributedInterface().current_device - - -if __name__ == "__main__": - """ - python tests_v1/plugins/model_plugins/test_init_plugin.py - """ - test_init_on_meta() - test_init_on_rank0() - test_init_on_default() diff --git a/tests_v1/sampler/test_cli_sampler.py b/tests_v1/sampler/test_cli_sampler.py index 9f858e1f9..68d05aec3 100644 --- a/tests_v1/sampler/test_cli_sampler.py +++ b/tests_v1/sampler/test_cli_sampler.py @@ -35,10 +35,3 @@ def test_sync_sampler(): "role": "assistant", "content": [{"type": "text", "value": "This is a test."}], } - - -if __name__ == "__main__": - """ - python tests_v1/sampler/test_cli_sampler.py - """ - test_sync_sampler()