[v1] Add FlashAttention selection and implement normal / padding-free / dynamic batching (#10469)

This commit is contained in:
jiaqiw09
2026-05-21 17:14:19 +08:00
committed by GitHub
parent 7e20db5735
commit bdcb92d035
23 changed files with 507 additions and 105 deletions

View File

@@ -0,0 +1,31 @@
# v1 training example: Qwen3-0.6B with `dist_config.name: fsdp2` and the `normal` batching strategy.
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
    name: auto
    include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
    name: fsdp2
    dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
# `normal`: the data provider batch is micro_batch_size * num_micro_batch samples.
batching_strategy: normal
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
# v1 training example: Qwen3-0.6B with `dist_config.name: fsdp2` and the `dynamic_batching` strategy.
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
    name: auto
    include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
    name: fsdp2
    dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
# `dynamic_batching` fills each micro batch up to a padded-token budget of
# cutoff_len * micro_batch_size instead of using a fixed sample count.
batching_strategy: dynamic_batching
cutoff_len: 2048
learning_rate: 1.0e-4
# Required: `dynamic_batching` is step-driven, so `max_steps` must be set to a positive value
# (and `save_epochs` is rejected in favour of `save_steps`).
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
# v1 training example: Qwen3-0.6B with `dist_config.name: fsdp2` and the `padding_free` batching strategy.
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
    name: auto
    include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
    name: fsdp2
    dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
# `padding_free` packs each micro batch into a single sequence and is rejected by the
# trainer unless the model runs with `flash_attention_2`.
batching_strategy: padding_free
# choice: eager/sdpa/flash_attention_2, default is sdpa.
# The short alias `fa2` is not a supported value and fails argument validation.
flash_attn: flash_attention_2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -20,7 +20,6 @@ import sys
import time import time
from collections import defaultdict from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
@@ -584,7 +583,7 @@ class ModuleProfilerCallback(TrainerCallback):
if matched: if matched:
logger.info_rank0( logger.info_rank0(
f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}" f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
+ (f" ... (+{len(matched)-5} more)" if len(matched) > 5 else "") + (f" ... (+{len(matched) - 5} more)" if len(matched) > 5 else "")
) )
else: else:
logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}") logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")
@@ -616,7 +615,7 @@ class ModuleProfilerCallback(TrainerCallback):
bwd = self._backward_times.get(name, []) bwd = self._backward_times.get(name, [])
fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0 fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0 bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean+bwd_mean:.3f}") lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean + bwd_mean:.3f}")
logger.info_rank0("\n".join(lines)) logger.info_rank0("\n".join(lines))
self._forward_times.clear() self._forward_times.clear()

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ..utils.types import AttentionFunction
from .arg_parser import InputArgument, get_args from .arg_parser import InputArgument, get_args
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments from .data_args import DataArguments
@@ -21,6 +22,7 @@ from .training_args import TrainingArguments
__all__ = [ __all__ = [
"AttentionFunction",
"BatchingStrategy", "BatchingStrategy",
"DataArguments", "DataArguments",
"InputArgument", "InputArgument",

View File

@@ -15,6 +15,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from ..utils.types import AttentionFunction
from .arg_utils import ModelClass, PluginConfig, get_plugin_config from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@@ -32,6 +33,12 @@ class ModelArguments:
default=False, default=False,
metadata={"help": "Trust remote code from Hugging Face."}, metadata={"help": "Trust remote code from Hugging Face."},
) )
flash_attn: AttentionFunction = field(
default=AttentionFunction.SDPA,
metadata={
"help": "Attention implementation to use: eager, sdpa, or flash_attention_2. SDPA is the default implementation for models."
},
)
model_class: ModelClass = field( model_class: ModelClass = field(
default=ModelClass.LLM, default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."}, metadata={"help": "Model class from Hugging Face."},
@@ -54,6 +61,12 @@ class ModelArguments:
) )
def __post_init__(self) -> None: def __post_init__(self) -> None:
supported_flash_attn = [item.value for item in AttentionFunction]
if self.flash_attn not in supported_flash_attn:
raise ValueError(
f"Unsupported `flash_attn`: {self.flash_attn}. Supported values are: {supported_flash_attn}."
)
self.init_config = get_plugin_config(self.init_config) self.init_config = get_plugin_config(self.init_config)
self.peft_config = get_plugin_config(self.peft_config) self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config) self.kernel_config = get_plugin_config(self.kernel_config)

View File

@@ -116,3 +116,9 @@ class TrainingArguments:
self.dist_config = get_plugin_config(self.dist_config) self.dist_config = get_plugin_config(self.dist_config)
self.optim_config = get_plugin_config(self.optim_config) self.optim_config = get_plugin_config(self.optim_config)
self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config) self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
if str(self.batching_strategy) == str(BatchingStrategy.DYNAMIC_BATCHING):
if self.max_steps is None or self.max_steps <= 0:
raise ValueError("`dynamic_batching` requires `max_steps` because it is step-driven.")
if self.save_epochs is not None:
raise ValueError("`save_epochs` is not supported with `dynamic_batching`; use `save_steps` instead.")

View File

@@ -34,7 +34,7 @@ import torch.nn.functional as F
from ..accelerator.helper import ReduceOp from ..accelerator.helper import ReduceOp
from ..accelerator.interface import Dim, DistributedInterface from ..accelerator.interface import Dim, DistributedInterface
from ..config import TrainingArguments from ..config import BatchingStrategy, TrainingArguments
from ..utils import logging from ..utils import logging
from ..utils.callbacks import ( from ..utils.callbacks import (
CallbackHandler, CallbackHandler,
@@ -147,13 +147,19 @@ class BaseTrainer:
from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
if model.config._attn_implementation != "flash_attention_2": if model.config._attn_implementation != "flash_attention_2":
logger.warning_rank0( raise ValueError(
"Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2." "Sequence parallelism requires flash attention. Please set `flash_attn: flash_attention_2`."
) )
model.config._attn_implementation = "flash_attention_2"
SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config) SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
def _create_batch_generator(self) -> None: def _create_batch_generator(self) -> None:
if (
self.args.batching_strategy == BatchingStrategy.PADDING_FREE
and getattr(self.model.config, "_attn_implementation", None) != "flash_attention_2"
):
raise ValueError("`padding_free` requires `flash_attn: flash_attention_2`.")
self.train_batch_generator = BatchGenerator( self.train_batch_generator = BatchGenerator(
dataset=self.train_dataset, dataset=self.train_dataset,
renderer=self.renderer, renderer=self.renderer,
@@ -237,6 +243,7 @@ class BaseTrainer:
self.train_batch_generator.set_epoch(epoch) self.train_batch_generator.set_epoch(epoch)
self.callback_handler.on_epoch_begin(self.args, self.state) self.callback_handler.on_epoch_begin(self.args, self.state)
# BatchGenerator is an iterator; each loop step calls its __next__ to produce one optimizer step.
for micro_batches in self.train_batch_generator: for micro_batches in self.train_batch_generator:
self.global_step += 1 self.global_step += 1

View File

@@ -120,6 +120,7 @@ class ModelEngine:
init_device = DistributedInterface().current_device init_device = DistributedInterface().current_device
init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device} init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device}
logger.info_rank0(f"Using attention implementation: {self.args.flash_attn}.")
if self.args.quant_config is not None: if self.args.quant_config is not None:
from ..plugins.model_plugins.quantization import QuantizationPlugin from ..plugins.model_plugins.quantization import QuantizationPlugin
@@ -164,6 +165,7 @@ class ModelEngine:
self.args.model, self.args.model,
config=self.model_config, config=self.model_config,
dtype="auto", dtype="auto",
attn_implementation=self.args.flash_attn,
trust_remote_code=self.args.trust_remote_code, trust_remote_code=self.args.trust_remote_code,
**init_kwargs, **init_kwargs,
) )

View File

@@ -42,6 +42,8 @@ from .rendering import Renderer
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
__all__ = ["BatchGenerator"]
def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None: def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_size = batch_info["micro_batch_size"] micro_batch_size = batch_info["micro_batch_size"]
@@ -102,19 +104,18 @@ class BatchGenerator(Iterator):
if not self.drop_last: if not self.drop_last:
raise ValueError("Drop last must be True.") raise ValueError("Drop last must be True.")
self._batch_info: BatchInfo = {
"micro_batch_size": self.micro_batch_size,
"num_micro_batch": self.num_micro_batch,
"cutoff_len": self.cutoff_len,
}
self._init_data_provider() self._init_data_provider()
self._is_resuming: bool = False self._is_resuming: bool = False
self._data_iter = iter(self._data_provider) self._data_iter = iter(self._data_provider)
self._buffer = StatefulBuffer() self._buffer = StatefulBuffer()
self._batch_info: BatchInfo = {
"micro_batch_size": self.micro_batch_size,
"num_micro_batch": self.num_micro_batch,
"cutoff_len": self.cutoff_len,
"data_iter": self._data_iter,
}
logger.info_rank0( logger.info_rank0(
f"Init unified data loader with global batch size {self.global_batch_size}, " f"Init unified data loader with global batch size {self.global_batch_size}, "
f"micro batch size {self.micro_batch_size}, " f"micro batch size {self.micro_batch_size}, "
@@ -137,12 +138,19 @@ class BatchGenerator(Iterator):
else: else:
raise NotImplementedError("Iterable dataset is not supported yet.") raise NotImplementedError("Iterable dataset is not supported yet.")
if self.batching_strategy == BatchingStrategy.NORMAL:
batch_size = self.micro_batch_size * self.num_micro_batch
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
batch_size = BatchingPlugin(self.batching_strategy).get_data_provider_batch_size(self._batch_info)
generator_seed = torch.Generator() generator_seed = torch.Generator()
generator_seed.manual_seed(self.seed) generator_seed.manual_seed(self.seed)
self._data_provider = StatefulDataLoader( self._data_provider = StatefulDataLoader(
self.dataset, self.dataset,
batch_size=self.micro_batch_size * self.num_micro_batch, batch_size=batch_size,
sampler=sampler, sampler=sampler,
num_workers=self.batching_workers, num_workers=self.batching_workers,
collate_fn=self.renderer.process_samples, collate_fn=self.renderer.process_samples,
@@ -156,8 +164,7 @@ class BatchGenerator(Iterator):
else: else:
from ...plugins.trainer_plugins.batching import BatchingPlugin from ...plugins.trainer_plugins.batching import BatchingPlugin
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider) self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider, self._batch_info)
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
def __len__(self) -> int: def __len__(self) -> int:
return self._length return self._length
@@ -190,7 +197,7 @@ class BatchGenerator(Iterator):
else: else:
from ...plugins.trainer_plugins.batching import BatchingPlugin from ...plugins.trainer_plugins.batching import BatchingPlugin
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info) BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info, self._next_samples)
def _generate_batch(self) -> list[BatchInput] | None: def _generate_batch(self) -> list[BatchInput] | None:
if self.batching_strategy == BatchingStrategy.NORMAL: if self.batching_strategy == BatchingStrategy.NORMAL:
@@ -200,6 +207,20 @@ class BatchGenerator(Iterator):
return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info) return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)
def _next_samples(self, restart: bool) -> list[ModelInput] | None:
    """Return the next item from the data provider iterator.

    Args:
        restart: When the current iterator is exhausted, create a fresh
            iterator over the data provider and try once more.

    Returns:
        The next provider item, or ``None`` when the iterator is exhausted and
        either ``restart`` is false or the fresh iterator is empty as well.
    """
    exhausted = object()  # private sentinel, so a provider item is never mistaken for exhaustion

    item = next(self._data_iter, exhausted)
    if item is not exhausted:
        return item

    if not restart:
        return None

    # Dynamic batching may restart the provider to fill one token-budgeted batch.
    self._data_iter = iter(self._data_provider)
    item = next(self._data_iter, exhausted)
    return None if item is exhausted else item
def state_dict(self) -> dict[str, Any]: def state_dict(self) -> dict[str, Any]:
return { return {
"buffer": self._buffer.state_dict(), "buffer": self._buffer.state_dict(),

View File

@@ -12,23 +12,197 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections.abc import Callable
from math import ceil
from typing import Any
import torch
from torch.utils.data import default_collate
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import pad_and_truncate
from ...utils.objects import StatefulBuffer from ...utils.objects import StatefulBuffer
from ...utils.plugin import BasePlugin from ...utils.plugin import BasePlugin
from ...utils.types import BatchInfo, BatchInput, DataLoader from ...utils.types import BatchInfo, BatchInput, DataLoader, ModelInput
class BatchingPlugin(BasePlugin): class BatchingPlugin(BasePlugin):
def compute_length(self, data_provider: DataLoader) -> int: def get_data_provider_batch_size(self, batch_info: BatchInfo) -> int:
"""Return the raw data provider batch size for this batching strategy."""
return self["get_data_provider_batch_size"](batch_info)
def compute_length(self, data_provider: DataLoader, batch_info: BatchInfo) -> int:
"""Compute the length of the batch generator. """Compute the length of the batch generator.
The approximate length is used to calculate the lr schedule. The approximate length is used to calculate the lr schedule.
""" """
raise NotImplementedError() return self["compute_length"](data_provider, batch_info)
def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None: def fill_buffer(
self,
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
"""Fill the buffer with data.""" """Fill the buffer with data."""
raise NotImplementedError() return self["fill_buffer"](buffer, batch_info, next_samples)
def generate_batch(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None: def generate_batch(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
"""Generate a batch from the buffer.""" """Generate a batch from the buffer."""
raise NotImplementedError() return self["generate_batch"](buffer, batch_info)
def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
    """Plan how many leading samples go into each dynamic micro batch.

    Each micro batch is charged ``longest_sample * sample_count`` padded tokens
    (sample lengths are capped at ``cutoff_len``) against a budget of
    ``cutoff_len * micro_batch_size``. A micro batch only counts once it is
    *complete*: either it reaches the budget exactly, or the next sample is seen
    and would push it over the budget.

    Args:
        samples: Buffered samples, consumed in order; only ``input_ids`` is read.
        batch_info: Provides ``cutoff_len``, ``micro_batch_size`` and ``num_micro_batch``.

    Returns:
        Sample counts of the complete micro batches, at most ``num_micro_batch``
        entries. The list is shorter when ``samples`` runs out first.
    """
    length_cap = batch_info["cutoff_len"]
    token_budget = length_cap * batch_info["micro_batch_size"]
    total = len(samples)

    counts: list[int] = []
    start = 0
    while start < total and len(counts) < batch_info["num_micro_batch"]:
        longest = 0
        taken = 0
        closed = False
        for position in range(start, total):
            length = min(len(samples[position]["input_ids"]), length_cap)
            widest = max(longest, length)
            # The first sample is always accepted, even if it alone exceeds the budget.
            if taken and widest * (taken + 1) > token_budget:
                closed = True
                break

            longest = widest
            taken += 1
            if longest * taken >= token_budget:
                closed = True
                break

        # Running out of samples leaves the micro batch open: more data could still fit.
        if not (taken and closed):
            break

        counts.append(taken)
        start += taken

    return counts
def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
    """Concatenate a fixed group of samples into one padding-free sequence.

    Every sample is truncated to ``cutoff_len`` tokens first; there is no
    overall token budget. Incoming ``attention_mask`` / ``position_ids`` and
    string-valued fields are dropped; position ids restart at 0 for each sample
    and the attention mask is all ones.

    Args:
        samples: Samples of one micro batch; each must provide ``input_ids``.
        cutoff_len: Per-sample truncation length.

    Returns:
        A dict of tensors with a leading batch dimension of 1, or ``None`` when
        no sample contributes any token.
    """
    regenerated_keys = ("attention_mask", "position_ids")
    merged: dict[str, list[Any]] = {}
    positions: list[int] = []

    for index, sample in enumerate(samples):
        length = min(len(sample["input_ids"]), cutoff_len)
        if length <= 0:
            continue

        follows_another = index > 0
        for name, field in sample.items():
            if name in regenerated_keys or isinstance(field, str):
                continue

            chunk = list(field[:length])
            # Neutralise the first label / loss weight of every sample after the first
            # so the sample boundary does not contribute to the loss.
            if follows_another and chunk:
                if name == "labels":
                    chunk[0] = IGNORE_INDEX
                elif name == "loss_weights":
                    chunk[0] = 0.0

            merged.setdefault(name, []).extend(chunk)

        positions.extend(range(length))

    if not positions:
        return None

    merged["position_ids"] = positions
    merged["attention_mask"] = [1] * len(positions)
    return {name: torch.tensor(field).unsqueeze(0) for name, field in merged.items()}
@BatchingPlugin("padding_free").register("get_data_provider_batch_size")
def get_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
    """Return the provider batch size for padding-free batching: one full optimizer step of samples."""
    samples_per_micro_batch = batch_info["micro_batch_size"]
    micro_batches_per_step = batch_info["num_micro_batch"]
    return samples_per_micro_batch * micro_batches_per_step
@BatchingPlugin("padding_free").register("compute_length")
def compute_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
    """Return the number of provider batches; ``batch_info`` is accepted for interface parity and unused."""
    num_provider_batches = len(data_provider)
    return num_provider_batches
@BatchingPlugin("padding_free").register("fill_buffer")
def fill_padding_free_buffer(
    buffer: StatefulBuffer,
    batch_info: BatchInfo,
    next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
    """Top up the buffer until it holds one optimizer step of samples.

    Args:
        buffer: Buffer to fill in place.
        batch_info: Provides ``micro_batch_size`` and ``num_micro_batch``.
        next_samples: Callable returning the next provider batch or ``None``; it is
            called with ``False``, so an exhausted provider is not restarted.
    """
    target = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
    while len(buffer) < target:
        incoming = next_samples(False)
        if incoming is None:
            return
        buffer.put(incoming)
@BatchingPlugin("padding_free").register("generate_batch")
def generate_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
    """Take one optimizer step of samples from the buffer and pack each micro batch.

    Returns:
        ``num_micro_batch`` packed micro batches, or ``None`` when the buffer holds
        fewer than ``micro_batch_size * num_micro_batch`` samples or when any micro
        batch packs to nothing.
    """
    per_micro = batch_info["micro_batch_size"]
    micro_count = batch_info["num_micro_batch"]
    length_cap = batch_info["cutoff_len"]

    needed = per_micro * micro_count
    if len(buffer) < needed:
        return None

    drawn = buffer.get(needed)
    packed_batches = []
    for micro_index in range(micro_count):
        lower = micro_index * per_micro
        packed = _pack_padding_free_samples(drawn[lower : lower + per_micro], length_cap)
        if packed is None:
            # NOTE(review): the drawn samples have already left the buffer at this point,
            # so they are dropped together with this batch; confirm this is intended.
            return None
        packed_batches.append(packed)

    return packed_batches
@BatchingPlugin("dynamic_batching").register("get_data_provider_batch_size")
def get_dynamic_batching_data_provider_batch_size(batch_info: BatchInfo) -> int:
    """Return the provider batch size for dynamic batching: always one sample per read."""
    single_sample = 1
    return single_sample
@BatchingPlugin("dynamic_batching").register("compute_length")
def compute_dynamic_batching_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
    """Estimate the number of steps as provider length over ``micro_batch_size * num_micro_batch``, rounded up.

    The real number of samples per step depends on sample lengths, so this is
    only an approximation.
    """
    nominal_samples_per_step = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
    estimated_steps = len(data_provider) / nominal_samples_per_step
    return ceil(estimated_steps)
@BatchingPlugin("dynamic_batching").register("fill_buffer")
def fill_dynamic_batching_buffer(
    buffer: StatefulBuffer,
    batch_info: BatchInfo,
    next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
    """Read samples into the buffer until it can yield ``num_micro_batch`` complete micro batches.

    Args:
        buffer: Buffer to fill in place.
        batch_info: Provides ``cutoff_len``, ``micro_batch_size`` and ``num_micro_batch``.
        next_samples: Callable returning the next provider batch or ``None``; it is
            called with ``True``, so an exhausted provider is restarted. Filling
            stops early only when it still returns ``None``.
    """
    while True:
        planned = _get_dynamic_micro_batch_sizes(buffer.samples, batch_info)
        if len(planned) >= batch_info["num_micro_batch"]:
            return

        incoming = next_samples(True)
        if incoming is None:
            return
        buffer.put(incoming)
@BatchingPlugin("dynamic_batching").register("generate_batch")
def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
    """Build one optimizer step of token-budgeted micro batches from the buffer.

    Returns:
        One collated micro batch per planned sample count, or ``None`` (leaving the
        buffer untouched) when fewer than ``num_micro_batch`` complete micro batches
        can be formed from the buffered samples.
    """
    planned_counts = _get_dynamic_micro_batch_sizes(buffer.samples, batch_info)
    if len(planned_counts) < batch_info["num_micro_batch"]:
        return None

    length_cap = batch_info["cutoff_len"]
    # Samples are taken from the buffer in plan order, one micro batch at a time.
    return [default_collate(pad_and_truncate(buffer.get(count), length_cap)) for count in planned_counts]

View File

@@ -61,6 +61,9 @@ def load_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt
@DistributedPlugin("deepspeed").register() @DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel: def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
if dist_config.get("cp_size", 1) > 1:
raise ValueError("CP currently requires `dist_config.name: fsdp2`.")
from .deepspeed import DeepSpeedEngine from .deepspeed import DeepSpeedEngine
return DeepSpeedEngine( return DeepSpeedEngine(

View File

@@ -33,6 +33,10 @@ class StatefulBuffer:
def size(self) -> int: def size(self) -> int:
return self._buffer_size return self._buffer_size
@property
def samples(self) -> list[ModelInput]:
return self._buffer
def put(self, samples: list[ModelInput]) -> None: def put(self, samples: list[ModelInput]) -> None:
"""Add samples to the buffer.""" """Add samples to the buffer."""
num_tokens = sum(len(sample["input_ids"]) for sample in samples) num_tokens = sum(len(sample["input_ids"]) for sample in samples)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections.abc import Iterator from enum import StrEnum, unique
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, NotRequired, TypedDict, Union from typing import TYPE_CHECKING, Any, Literal, NamedTuple, NotRequired, TypedDict, Union
@@ -54,6 +54,13 @@ else:
ProcessGroup = None ProcessGroup = None
@unique
class AttentionFunction(StrEnum):
    """Attention implementations accepted by the `flash_attn` model argument."""

    EAGER = "eager"
    # Default value of `ModelArguments.flash_attn`.
    SDPA = "sdpa"
    # The trainer requires this value for `padding_free` batching and for sequence parallelism.
    FLASH_ATTENTION_2 = "flash_attention_2"
class DatasetInfo(TypedDict, total=False): class DatasetInfo(TypedDict, total=False):
path: str path: str
"""Local file path.""" """Local file path."""
@@ -171,8 +178,6 @@ class BatchInfo(TypedDict):
"""Number of micro batches.""" """Number of micro batches."""
cutoff_len: int cutoff_len: int
"""Cutoff length.""" """Cutoff length."""
data_iter: Iterator[list[ModelInput]]
"""Data iterator."""
class ModelOutput(NamedTuple): class ModelOutput(NamedTuple):

View File

@@ -58,10 +58,3 @@ def test_multi_device():
master_port = find_available_port() master_port = find_available_port()
world_size = 2 world_size = 2
mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size) mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size)
if __name__ == "__main__":
"""
python tests_v1/accelerator/test_interface.py
"""
test_all_device()

View File

@@ -70,13 +70,3 @@ def test_get_args_from_yaml(tmp_path: Path):
assert training_args.bf16 is False assert training_args.bf16 is False
assert training_args.dist_config is None assert training_args.dist_config is None
assert sample_args.sample_backend == "hf" assert sample_args.sample_backend == "hf"
if __name__ == "__main__":
"""
python -m tests_v1.config.test_args_parser
"""
import tempfile
with tempfile.TemporaryDirectory() as tmp_dir:
test_get_args_from_yaml(tmp_path=Path(tmp_dir))

View File

@@ -30,10 +30,3 @@ def test_map_dataset(num_samples: int):
for index in indexes: for index in indexes:
print(data_engine[index]) print(data_engine[index])
assert data_engine[index] == {"_dataset_name": "default", **original_data[index]} assert data_engine[index] == {"_dataset_name": "default", **original_data[index]}
if __name__ == "__main__":
"""
python -m tests_v1.core.test_data_engine
"""
test_map_dataset(1)

View File

@@ -41,11 +41,3 @@ def test_tiny_qwen_with_kernel_plugin():
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__ assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__ assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
if __name__ == "__main__":
"""
python -m tests_v1.core.test_model_loader
"""
test_tiny_qwen()
test_tiny_qwen_with_kernel_plugin()

View File

@@ -16,6 +16,159 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume
from llamafactory.v1.core.data_engine import DataEngine from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.model_engine import ModelEngine from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.core.utils.batching import BatchGenerator from llamafactory.v1.core.utils.batching import BatchGenerator
from llamafactory.v1.plugins.trainer_plugins.batching import BatchingPlugin, _get_dynamic_micro_batch_sizes
from llamafactory.v1.utils.constants import IGNORE_INDEX
from llamafactory.v1.utils.objects import StatefulBuffer
def _make_model_input(length: int, start: int = 0):
    """Build a toy sample of `length` consecutive token ids beginning at `start`.

    Labels are an independent copy of the token ids; the attention mask is all
    ones and every loss weight is 1.0.
    """
    token_ids = [start + offset for offset in range(length)]
    sample = {"input_ids": token_ids}
    sample["attention_mask"] = [1 for _ in token_ids]
    sample["labels"] = list(token_ids)
    sample["loss_weights"] = [1.0 for _ in token_ids]
    return sample
class _RestartableDataProvider:
    """Re-iterable stand-in for a data provider that counts how often iteration is started."""

    def __init__(self, batches):
        self.num_iters = 0
        self.batches = batches

    def __iter__(self):
        # Count eagerly at `iter()` time, not lazily on the first `next()`.
        self.num_iters = self.num_iters + 1
        fresh_pass = iter(self.batches)
        return fresh_pass
def test_padding_free():
    """Padding-free batching truncates each sample to `cutoff_len` and packs one micro batch into a single row."""
    buffer = StatefulBuffer()
    # Input samples:
    # sample 0 input_ids: [0, 1]
    # sample 1 input_ids: [10, 11, 12, 13]
    buffer.put([_make_model_input(2, 0), _make_model_input(4, 10)])
    batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 3}
    batch = BatchingPlugin("padding_free").generate_batch(buffer, batch_info)
    # Output batch:
    # sample 1 is truncated to [10, 11, 12]
    # both samples are packed into one sequence: [[0, 1, 10, 11, 12]]
    assert batch is not None
    assert len(batch) == 1
    assert batch[0]["input_ids"].shape == (1, 5)
    assert batch[0]["input_ids"].tolist() == [[0, 1, 10, 11, 12]]
    assert batch[0]["attention_mask"].tolist() == [[1, 1, 1, 1, 1]]
    # Position ids restart at 0 at the sample boundary.
    assert batch[0]["position_ids"].tolist() == [[0, 1, 0, 1, 2]]
    # The first label / loss weight of the second sample is masked out.
    assert batch[0]["labels"].tolist() == [[0, 1, IGNORE_INDEX, 11, 12]]
    assert batch[0]["loss_weights"].tolist() == [[1.0, 1.0, 0.0, 1.0, 1.0]]
    # Both samples were consumed from the buffer.
    assert len(buffer) == 0
def test_batching_plugin_data_provider_batch_sizes():
    """Each batching strategy reports the raw batch size its data provider should use."""
    batch_info = {
        "micro_batch_size": 2,
        "num_micro_batch": 3,
        "cutoff_len": 10,
    }
    # padding_free reads a full step at once: micro_batch_size * num_micro_batch = 6.
    assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
    # dynamic_batching reads one sample at a time.
    assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
def test_dynamic_batching():
    """Dynamic batching closes a micro batch once the next sample would exceed the padded-token budget."""
    # Input samples:
    # sample lengths: [3, 4, 6, 2, 8, 9]
    # input_ids:
    # [0, 1, 2]
    # [10, 11, 12, 13]
    # [20, 21, 22, 23, 24, 25]
    # [30, 31]
    # [40, 41, 42, 43, 44, 45, 46, 47]
    # [50, 51, 52, 53, 54, 55, 56, 57, 58]
    samples = [
        _make_model_input(3, 0),
        _make_model_input(4, 10),
        _make_model_input(6, 20),
        _make_model_input(2, 30),
        _make_model_input(8, 40),
        _make_model_input(9, 50),
    ]
    batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
    # Dynamic batching output plan:
    # dynamic batching reads one sample at a time and uses cutoff_len * micro_batch_size
    # as the padded-token budget for one training micro batch.
    # [3, 4, 6] fits within budget 20 as shape [3, 6]; adding [2] would exceed it.
    assert _get_dynamic_micro_batch_sizes(samples, batch_info) == [3]
    buffer = StatefulBuffer()
    buffer.put(samples)
    batch = BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info)
    assert batch is not None
    assert len(batch) == 1
    assert batch[0]["input_ids"].shape == (3, 6)
    # The shortest sample is padded up to the longest sample in the micro batch.
    assert batch[0]["input_ids"].tolist()[0] == [0, 1, 2, 0, 0, 0]
    # The three samples that were not planned stay in the buffer.
    assert len(buffer) == 3
def test_dynamic_batching_returns_none_when_token_budget_is_incomplete():
    """An incomplete dynamic micro batch yields no batch and leaves the buffer untouched."""
    buffer = StatefulBuffer()
    # Input buffer:
    # only one sample with length [6].
    # cutoff_len * micro_batch_size gives a padded-token budget of 20.
    # this buffer has not filled the budget and has no next sample to prove overflow,
    # so dynamic batching cannot produce a batch yet.
    buffer.put([_make_model_input(6, 0)])
    batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
    assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
    assert BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info) is None
    # Batch generation does not read from the data iterator. It only returns None and keeps
    # existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
    assert len(buffer) == 1
def test_dynamic_batching_fill_buffer_restarts_until_micro_batch_is_complete():
    """Filling for dynamic batching restarts an exhausted provider until one micro batch is complete."""
    # Input data provider:
    # each iterator pass yields one sample with length [6].
    # each yielded item is a list[ModelInput], matching BatchGenerator._next_samples.
    # _fill_buffer keeps restarting the iterator until the next appended sample
    # proves that the previous dynamic micro batch has reached its budget boundary.
    samples = [_make_model_input(6, 0)]
    data_provider = _RestartableDataProvider([[sample] for sample in samples])
    # Bypass __init__ and set only the attributes this test assigns below.
    batch_generator = BatchGenerator.__new__(BatchGenerator)
    batch_generator.batching_strategy = "dynamic_batching"
    batch_generator.micro_batch_size = 2
    batch_generator.num_micro_batch = 1
    batch_generator._buffer = StatefulBuffer()
    batch_generator._data_provider = data_provider
    batch_generator._data_iter = iter(data_provider)
    batch_generator._batch_info = {
        "micro_batch_size": 2,
        "num_micro_batch": 1,
        "cutoff_len": 10,
    }
    batch_generator._fill_buffer()
    # Filled buffer after restart:
    # existing buffer [6, 6, 6] is kept; the fourth [6] remains for the next batch
    # because adding it to the first dynamic micro batch would exceed the budget.
    # One initial iter() plus three restarts gives four iterator passes.
    assert data_provider.num_iters == 4
    assert _get_dynamic_micro_batch_sizes(batch_generator._buffer.samples, batch_generator._batch_info) == [3]
    batch = batch_generator._generate_batch()
    # Output batch:
    # dynamic batching returns [micro_batch_0]
    # micro_batch_0 consumes [6, 6, 6] => 3 samples, padded to shape [3, 6].
    assert batch is not None
    assert len(batch) == 1
    assert batch[0]["input_ids"].shape == (3, 6)
    assert len(batch_generator._buffer) == 1
def test_normal_batching(): def test_normal_batching():
@@ -43,10 +196,3 @@ def test_normal_batching():
batch = next(iter(batch_generator)) batch = next(iter(batch_generator))
assert len(batch) == 2 assert len(batch) == 2
assert batch[0]["input_ids"].shape == (4, 10) assert batch[0]["input_ids"].shape == (4, 10)
if __name__ == "__main__":
"""
python -m tests_v1.core.utils.test_batching
"""
test_normal_batching()

View File

@@ -227,17 +227,3 @@ def test_process_dpo_samples():
assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs) assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
assert model_inputs[0]["extra_info"] == "test" assert model_inputs[0]["extra_info"] == "test"
assert model_inputs[0]["_dataset_name"] == "default" assert model_inputs[0]["_dataset_name"] == "default"
if __name__ == "__main__":
"""
python -m tests_v1.core.utils.test_rendering
"""
test_chatml_rendering()
test_chatml_parse()
test_chatml_rendering_remote(16)
test_qwen3_nothink_rendering()
test_qwen3_nothink_parse()
test_qwen3_nothink_rendering_remote(16)
test_process_sft_samples()
test_process_dpo_samples()

View File

@@ -117,12 +117,3 @@ def test_pair_converter(num_samples: int):
], ],
} }
assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data} assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
if __name__ == "__main__":
"""
python -m tests_v1.plugins.data_plugins.test_converter
"""
test_alpaca_converter(1)
test_sharegpt_converter()
test_pair_converter(1)

View File

@@ -52,12 +52,3 @@ def test_init_on_default():
) )
model_engine = ModelEngine(model_args=model_args) model_engine = ModelEngine(model_args=model_args)
assert model_engine.model.device == DistributedInterface().current_device assert model_engine.model.device == DistributedInterface().current_device
if __name__ == "__main__":
"""
python tests_v1/plugins/model_plugins/test_init_plugin.py
"""
test_init_on_meta()
test_init_on_rank0()
test_init_on_default()

View File

@@ -35,10 +35,3 @@ def test_sync_sampler():
"role": "assistant", "role": "assistant",
"content": [{"type": "text", "value": "This is a test."}], "content": [{"type": "text", "value": "This is a test."}],
} }
if __name__ == "__main__":
"""
python tests_v1/sampler/test_cli_sampler.py
"""
test_sync_sampler()