[v1] Add FlashAttention selection and implement normal / padding-free / dynamic batching (#10469)

This commit is contained in:
jiaqiw09
2026-05-21 17:14:19 +08:00
committed by GitHub
parent 7e20db5735
commit bdcb92d035
23 changed files with 507 additions and 105 deletions

View File

@@ -0,0 +1,31 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: normal
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 2
batching_strategy: dynamic_batching
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -0,0 +1,30 @@
model: Qwen/Qwen3-0.6B
model_class: llm
template: qwen3_nothink
kernel_config:
name: auto
include_kernels: auto # choice: null/true/false/auto/kernel_id1,kernel_id2,kernel_id3, default is null
quant_config: null
dist_config:
name: fsdp2
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
### data
train_dataset: data/v1_sft_demo.yaml
### training
output_dir: outputs/test_fsdp2
micro_batch_size: 4
batching_strategy: padding_free
flash_attn: fa2
cutoff_len: 2048
learning_rate: 1.0e-4
max_steps: 10
### sample
sample_backend: hf
max_new_tokens: 128

View File

@@ -20,7 +20,6 @@ import sys
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Optional
@@ -584,7 +583,7 @@ class ModuleProfilerCallback(TrainerCallback):
if matched:
logger.info_rank0(
f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
+ (f" ... (+{len(matched)-5} more)" if len(matched) > 5 else "")
+ (f" ... (+{len(matched) - 5} more)" if len(matched) > 5 else "")
)
else:
logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")
@@ -616,7 +615,7 @@ class ModuleProfilerCallback(TrainerCallback):
bwd = self._backward_times.get(name, [])
fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean+bwd_mean:.3f}")
lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean + bwd_mean:.3f}")
logger.info_rank0("\n".join(lines))
self._forward_times.clear()

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils.types import AttentionFunction
from .arg_parser import InputArgument, get_args
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
from .data_args import DataArguments
@@ -21,6 +22,7 @@ from .training_args import TrainingArguments
__all__ = [
"AttentionFunction",
"BatchingStrategy",
"DataArguments",
"InputArgument",

View File

@@ -15,6 +15,7 @@
from dataclasses import dataclass, field
from ..utils.types import AttentionFunction
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@@ -32,6 +33,12 @@ class ModelArguments:
default=False,
metadata={"help": "Trust remote code from Hugging Face."},
)
flash_attn: AttentionFunction = field(
default=AttentionFunction.SDPA,
metadata={
"help": "Attention implementation to use: eager, sdpa, or flash_attention_2. SDPA is the default implementation for models."
},
)
model_class: ModelClass = field(
default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."},
@@ -54,6 +61,12 @@ class ModelArguments:
)
def __post_init__(self) -> None:
supported_flash_attn = [item.value for item in AttentionFunction]
if self.flash_attn not in supported_flash_attn:
raise ValueError(
f"Unsupported `flash_attn`: {self.flash_attn}. Supported values are: {supported_flash_attn}."
)
self.init_config = get_plugin_config(self.init_config)
self.peft_config = get_plugin_config(self.peft_config)
self.kernel_config = get_plugin_config(self.kernel_config)

View File

@@ -116,3 +116,9 @@ class TrainingArguments:
self.dist_config = get_plugin_config(self.dist_config)
self.optim_config = get_plugin_config(self.optim_config)
self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
if str(self.batching_strategy) == str(BatchingStrategy.DYNAMIC_BATCHING):
if self.max_steps is None or self.max_steps <= 0:
raise ValueError("`dynamic_batching` requires `max_steps` because it is step-driven.")
if self.save_epochs is not None:
raise ValueError("`save_epochs` is not supported with `dynamic_batching`; use `save_steps` instead.")

View File

@@ -34,7 +34,7 @@ import torch.nn.functional as F
from ..accelerator.helper import ReduceOp
from ..accelerator.interface import Dim, DistributedInterface
from ..config import TrainingArguments
from ..config import BatchingStrategy, TrainingArguments
from ..utils import logging
from ..utils.callbacks import (
CallbackHandler,
@@ -147,13 +147,19 @@ class BaseTrainer:
from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
if model.config._attn_implementation != "flash_attention_2":
logger.warning_rank0(
"Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2."
raise ValueError(
"Sequence parallelism requires flash attention. Please set `flash_attn: flash_attention_2`."
)
model.config._attn_implementation = "flash_attention_2"
SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
def _create_batch_generator(self) -> None:
if (
self.args.batching_strategy == BatchingStrategy.PADDING_FREE
and getattr(self.model.config, "_attn_implementation", None) != "flash_attention_2"
):
raise ValueError("`padding_free` requires `flash_attn: flash_attention_2`.")
self.train_batch_generator = BatchGenerator(
dataset=self.train_dataset,
renderer=self.renderer,
@@ -237,6 +243,7 @@ class BaseTrainer:
self.train_batch_generator.set_epoch(epoch)
self.callback_handler.on_epoch_begin(self.args, self.state)
# BatchGenerator is an iterator; each loop step calls its __next__ to produce one optimizer step.
for micro_batches in self.train_batch_generator:
self.global_step += 1

View File

@@ -120,6 +120,7 @@ class ModelEngine:
init_device = DistributedInterface().current_device
init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device}
logger.info_rank0(f"Using attention implementation: {self.args.flash_attn}.")
if self.args.quant_config is not None:
from ..plugins.model_plugins.quantization import QuantizationPlugin
@@ -164,6 +165,7 @@ class ModelEngine:
self.args.model,
config=self.model_config,
dtype="auto",
attn_implementation=self.args.flash_attn,
trust_remote_code=self.args.trust_remote_code,
**init_kwargs,
)

View File

@@ -42,6 +42,8 @@ from .rendering import Renderer
logger = logging.get_logger(__name__)
__all__ = ["BatchGenerator"]
def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_size = batch_info["micro_batch_size"]
@@ -102,19 +104,18 @@ class BatchGenerator(Iterator):
if not self.drop_last:
raise ValueError("Drop last must be True.")
self._batch_info: BatchInfo = {
"micro_batch_size": self.micro_batch_size,
"num_micro_batch": self.num_micro_batch,
"cutoff_len": self.cutoff_len,
}
self._init_data_provider()
self._is_resuming: bool = False
self._data_iter = iter(self._data_provider)
self._buffer = StatefulBuffer()
self._batch_info: BatchInfo = {
"micro_batch_size": self.micro_batch_size,
"num_micro_batch": self.num_micro_batch,
"cutoff_len": self.cutoff_len,
"data_iter": self._data_iter,
}
logger.info_rank0(
f"Init unified data loader with global batch size {self.global_batch_size}, "
f"micro batch size {self.micro_batch_size}, "
@@ -137,12 +138,19 @@ class BatchGenerator(Iterator):
else:
raise NotImplementedError("Iterable dataset is not supported yet.")
if self.batching_strategy == BatchingStrategy.NORMAL:
batch_size = self.micro_batch_size * self.num_micro_batch
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
batch_size = BatchingPlugin(self.batching_strategy).get_data_provider_batch_size(self._batch_info)
generator_seed = torch.Generator()
generator_seed.manual_seed(self.seed)
self._data_provider = StatefulDataLoader(
self.dataset,
batch_size=self.micro_batch_size * self.num_micro_batch,
batch_size=batch_size,
sampler=sampler,
num_workers=self.batching_workers,
collate_fn=self.renderer.process_samples,
@@ -156,8 +164,7 @@ class BatchGenerator(Iterator):
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider, self._batch_info)
def __len__(self) -> int:
return self._length
@@ -190,7 +197,7 @@ class BatchGenerator(Iterator):
else:
from ...plugins.trainer_plugins.batching import BatchingPlugin
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info, self._next_samples)
def _generate_batch(self) -> list[BatchInput] | None:
if self.batching_strategy == BatchingStrategy.NORMAL:
@@ -200,6 +207,20 @@ class BatchGenerator(Iterator):
return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)
def _next_samples(self, restart: bool) -> list[ModelInput] | None:
try:
return next(self._data_iter)
except StopIteration:
if not restart:
return None
# Dynamic batching may restart the provider to fill one token-budgeted batch.
self._data_iter = iter(self._data_provider)
try:
return next(self._data_iter)
except StopIteration:
return None
def state_dict(self) -> dict[str, Any]:
return {
"buffer": self._buffer.state_dict(),

View File

@@ -12,23 +12,197 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
from math import ceil
from typing import Any
import torch
from torch.utils.data import default_collate
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import pad_and_truncate
from ...utils.objects import StatefulBuffer
from ...utils.plugin import BasePlugin
from ...utils.types import BatchInfo, BatchInput, DataLoader
from ...utils.types import BatchInfo, BatchInput, DataLoader, ModelInput
class BatchingPlugin(BasePlugin):
def compute_length(self, data_provider: DataLoader) -> int:
def get_data_provider_batch_size(self, batch_info: BatchInfo) -> int:
"""Return the raw data provider batch size for this batching strategy."""
return self["get_data_provider_batch_size"](batch_info)
def compute_length(self, data_provider: DataLoader, batch_info: BatchInfo) -> int:
"""Compute the length of the batch generator.
The approximate length is used to calculate the lr schedule.
"""
raise NotImplementedError()
return self["compute_length"](data_provider, batch_info)
def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None:
def fill_buffer(
self,
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
"""Fill the buffer with data."""
raise NotImplementedError()
return self["fill_buffer"](buffer, batch_info, next_samples)
def generate_batch(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
"""Generate a batch from the buffer."""
raise NotImplementedError()
return self["generate_batch"](buffer, batch_info)
def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
"""Return sample counts for micro batches formed by one padded-token budget."""
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
cutoff_len = batch_info["cutoff_len"]
sizes = []
index = 0
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
max_sample_len = 0
used = 0
is_complete = False
while index + used < len(samples):
sample_len = min(len(samples[index + used]["input_ids"]), cutoff_len)
padded_tokens = max(max_sample_len, sample_len) * (used + 1)
if used > 0 and padded_tokens > budget:
is_complete = True
break
max_sample_len = max(max_sample_len, sample_len)
used += 1
if max_sample_len * used >= budget:
is_complete = True
break
if used == 0 or not is_complete:
break
sizes.append(used)
index += used
return sizes
def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
"""Pack fixed samples into one padding-free sequence without a token budget."""
packed: dict[str, list[Any]] = {}
position_ids: list[int] = []
for sample_index, sample in enumerate(samples):
# Padding-free still truncates each sample by cutoff_len before packing
# all samples into one contiguous sequence.
sample_len = min(len(sample["input_ids"]), cutoff_len)
if sample_len <= 0:
continue
for key, value in sample.items():
if key in ("attention_mask", "position_ids") or isinstance(value, str):
continue
if key not in packed:
packed[key] = []
sliced_value = list(value[:sample_len])
if sample_index > 0 and sliced_value:
if key == "labels":
sliced_value[0] = IGNORE_INDEX
elif key == "loss_weights":
sliced_value[0] = 0.0
packed[key].extend(sliced_value)
position_ids.extend(range(sample_len))
if not position_ids:
return None
packed["position_ids"] = position_ids
packed["attention_mask"] = [1] * len(position_ids)
return {key: torch.tensor(value).unsqueeze(0) for key, value in packed.items()}
@BatchingPlugin("padding_free").register("get_data_provider_batch_size")
def get_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
return batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
@BatchingPlugin("padding_free").register("compute_length")
def compute_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
return len(data_provider)
@BatchingPlugin("padding_free").register("fill_buffer")
def fill_padding_free_buffer(
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
while len(buffer) < batch_info["micro_batch_size"] * batch_info["num_micro_batch"]:
samples = next_samples(False)
if samples is None:
break
buffer.put(samples)
@BatchingPlugin("padding_free").register("generate_batch")
def generate_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_size = batch_info["micro_batch_size"]
num_micro_batch = batch_info["num_micro_batch"]
cutoff_len = batch_info["cutoff_len"]
batch_size = micro_batch_size * num_micro_batch
if len(buffer) < batch_size:
return None
samples = buffer.get(batch_size)
batch = []
for i in range(num_micro_batch):
micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
packed_micro_batch = _pack_padding_free_samples(micro_batch, cutoff_len)
if packed_micro_batch is None:
return None
batch.append(packed_micro_batch)
return batch
@BatchingPlugin("dynamic_batching").register("get_data_provider_batch_size")
def get_dynamic_batching_data_provider_batch_size(batch_info: BatchInfo) -> int:
return 1
@BatchingPlugin("dynamic_batching").register("compute_length")
def compute_dynamic_batching_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
return ceil(len(data_provider) / batch_size)
@BatchingPlugin("dynamic_batching").register("fill_buffer")
def fill_dynamic_batching_buffer(
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
while len(_get_dynamic_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
samples = next_samples(True)
if samples is None:
break
buffer.put(samples)
@BatchingPlugin("dynamic_batching").register("generate_batch")
def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_sample_counts = _get_dynamic_micro_batch_sizes(buffer.samples, batch_info)
if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
return None
batch = []
cutoff_len = batch_info["cutoff_len"]
for num_samples in micro_batch_sample_counts:
samples = buffer.get(num_samples)
batch.append(default_collate(pad_and_truncate(samples, cutoff_len)))
return batch

View File

@@ -61,6 +61,9 @@ def load_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt
@DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
if dist_config.get("cp_size", 1) > 1:
raise ValueError("CP currently requires `dist_config.name: fsdp2`.")
from .deepspeed import DeepSpeedEngine
return DeepSpeedEngine(

View File

@@ -33,6 +33,10 @@ class StatefulBuffer:
def size(self) -> int:
return self._buffer_size
@property
def samples(self) -> list[ModelInput]:
return self._buffer
def put(self, samples: list[ModelInput]) -> None:
"""Add samples to the buffer."""
num_tokens = sum(len(sample["input_ids"]) for sample in samples)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterator
from enum import StrEnum, unique
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, NotRequired, TypedDict, Union
@@ -54,6 +54,13 @@ else:
ProcessGroup = None
@unique
class AttentionFunction(StrEnum):
EAGER = "eager"
SDPA = "sdpa"
FLASH_ATTENTION_2 = "flash_attention_2"
class DatasetInfo(TypedDict, total=False):
path: str
"""Local file path."""
@@ -171,8 +178,6 @@ class BatchInfo(TypedDict):
"""Number of micro batches."""
cutoff_len: int
"""Cutoff length."""
data_iter: Iterator[list[ModelInput]]
"""Data iterator."""
class ModelOutput(NamedTuple):

View File

@@ -58,10 +58,3 @@ def test_multi_device():
master_port = find_available_port()
world_size = 2
mp.spawn(_all_reduce_tests, args=(world_size, master_port), nprocs=world_size)
if __name__ == "__main__":
"""
python tests_v1/accelerator/test_interface.py
"""
test_all_device()

View File

@@ -70,13 +70,3 @@ def test_get_args_from_yaml(tmp_path: Path):
assert training_args.bf16 is False
assert training_args.dist_config is None
assert sample_args.sample_backend == "hf"
if __name__ == "__main__":
"""
python -m tests_v1.config.test_args_parser
"""
import tempfile
with tempfile.TemporaryDirectory() as tmp_dir:
test_get_args_from_yaml(tmp_path=Path(tmp_dir))

View File

@@ -30,10 +30,3 @@ def test_map_dataset(num_samples: int):
for index in indexes:
print(data_engine[index])
assert data_engine[index] == {"_dataset_name": "default", **original_data[index]}
if __name__ == "__main__":
"""
python -m tests_v1.core.test_data_engine
"""
test_map_dataset(1)

View File

@@ -41,11 +41,3 @@ def test_tiny_qwen_with_kernel_plugin():
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
if __name__ == "__main__":
"""
python -m tests_v1.core.test_model_loader
"""
test_tiny_qwen()
test_tiny_qwen_with_kernel_plugin()

View File

@@ -16,6 +16,159 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.core.utils.batching import BatchGenerator
from llamafactory.v1.plugins.trainer_plugins.batching import BatchingPlugin, _get_dynamic_micro_batch_sizes
from llamafactory.v1.utils.constants import IGNORE_INDEX
from llamafactory.v1.utils.objects import StatefulBuffer
def _make_model_input(length: int, start: int = 0):
input_ids = list(range(start, start + length))
return {
"input_ids": input_ids,
"attention_mask": [1] * length,
"labels": input_ids.copy(),
"loss_weights": [1.0] * length,
}
class _RestartableDataProvider:
def __init__(self, batches):
self.batches = batches
self.num_iters = 0
def __iter__(self):
self.num_iters += 1
return iter(self.batches)
def test_padding_free():
buffer = StatefulBuffer()
# Input samples:
# sample 0 input_ids: [0, 1]
# sample 1 input_ids: [10, 11, 12, 13]
buffer.put([_make_model_input(2, 0), _make_model_input(4, 10)])
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 3}
batch = BatchingPlugin("padding_free").generate_batch(buffer, batch_info)
# Output batch:
# sample 1 is truncated to [10, 11, 12]
# both samples are packed into one sequence: [[0, 1, 10, 11, 12]]
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (1, 5)
assert batch[0]["input_ids"].tolist() == [[0, 1, 10, 11, 12]]
assert batch[0]["attention_mask"].tolist() == [[1, 1, 1, 1, 1]]
assert batch[0]["position_ids"].tolist() == [[0, 1, 0, 1, 2]]
assert batch[0]["labels"].tolist() == [[0, 1, IGNORE_INDEX, 11, 12]]
assert batch[0]["loss_weights"].tolist() == [[1.0, 1.0, 0.0, 1.0, 1.0]]
assert len(buffer) == 0
def test_batching_plugin_data_provider_batch_sizes():
batch_info = {
"micro_batch_size": 2,
"num_micro_batch": 3,
"cutoff_len": 10,
}
assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
def test_dynamic_batching():
# Input samples:
# sample lengths: [3, 4, 6, 2, 8, 9]
# input_ids:
# [0, 1, 2]
# [10, 11, 12, 13]
# [20, 21, 22, 23, 24, 25]
# [30, 31]
# [40, 41, 42, 43, 44, 45, 46, 47]
# [50, 51, 52, 53, 54, 55, 56, 57, 58]
samples = [
_make_model_input(3, 0),
_make_model_input(4, 10),
_make_model_input(6, 20),
_make_model_input(2, 30),
_make_model_input(8, 40),
_make_model_input(9, 50),
]
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
# Dynamic batching output plan:
# dynamic batching reads one sample at a time and uses cutoff_len * micro_batch_size
# as the padded-token budget for one training micro batch.
# [3, 4, 6] fits within budget 20 as shape [3, 6]; adding [2] would exceed it.
assert _get_dynamic_micro_batch_sizes(samples, batch_info) == [3]
buffer = StatefulBuffer()
buffer.put(samples)
batch = BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info)
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (3, 6)
assert batch[0]["input_ids"].tolist()[0] == [0, 1, 2, 0, 0, 0]
assert len(buffer) == 3
def test_dynamic_batching_returns_none_when_token_budget_is_incomplete():
buffer = StatefulBuffer()
# Input buffer:
# only one sample with length [6].
# cutoff_len * micro_batch_size gives a padded-token budget of 20.
# this buffer has not filled the budget and has no next sample to prove overflow,
# so dynamic batching cannot produce a batch yet.
buffer.put([_make_model_input(6, 0)])
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
assert BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info) is None
# Batch generation does not read from the data iterator. It only returns None and keeps
# existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
assert len(buffer) == 1
def test_dynamic_batching_fill_buffer_restarts_until_micro_batch_is_complete():
# Input data provider:
# each iterator pass yields one sample with length [6].
# each yielded item is a list[ModelInput], matching BatchGenerator._next_samples.
# _fill_buffer keeps restarting the iterator until the next appended sample
# proves that the previous dynamic micro batch has reached its budget boundary.
samples = [_make_model_input(6, 0)]
data_provider = _RestartableDataProvider([[sample] for sample in samples])
batch_generator = BatchGenerator.__new__(BatchGenerator)
batch_generator.batching_strategy = "dynamic_batching"
batch_generator.micro_batch_size = 2
batch_generator.num_micro_batch = 1
batch_generator._buffer = StatefulBuffer()
batch_generator._data_provider = data_provider
batch_generator._data_iter = iter(data_provider)
batch_generator._batch_info = {
"micro_batch_size": 2,
"num_micro_batch": 1,
"cutoff_len": 10,
}
batch_generator._fill_buffer()
# Filled buffer after restart:
# existing buffer [6, 6, 6] is kept; the fourth [6] remains for the next batch
# because adding it to the first dynamic micro batch would exceed the budget.
assert data_provider.num_iters == 4
assert _get_dynamic_micro_batch_sizes(batch_generator._buffer.samples, batch_generator._batch_info) == [3]
batch = batch_generator._generate_batch()
# Output batch:
# dynamic batching returns [micro_batch_0]
# micro_batch_0 consumes [6, 6, 6] => 3 samples, padded to shape [3, 6].
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (3, 6)
assert len(batch_generator._buffer) == 1
def test_normal_batching():
@@ -43,10 +196,3 @@ def test_normal_batching():
batch = next(iter(batch_generator))
assert len(batch) == 2
assert batch[0]["input_ids"].shape == (4, 10)
if __name__ == "__main__":
"""
python -m tests_v1.core.utils.test_batching
"""
test_normal_batching()

View File

@@ -227,17 +227,3 @@ def test_process_dpo_samples():
assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
assert model_inputs[0]["extra_info"] == "test"
assert model_inputs[0]["_dataset_name"] == "default"
if __name__ == "__main__":
"""
python -m tests_v1.core.utils.test_rendering
"""
test_chatml_rendering()
test_chatml_parse()
test_chatml_rendering_remote(16)
test_qwen3_nothink_rendering()
test_qwen3_nothink_parse()
test_qwen3_nothink_rendering_remote(16)
test_process_sft_samples()
test_process_dpo_samples()

View File

@@ -117,12 +117,3 @@ def test_pair_converter(num_samples: int):
],
}
assert data_engine[index] == {"_dataset_name": "tiny_dataset", **expected_data}
if __name__ == "__main__":
"""
python -m tests_v1.plugins.data_plugins.test_converter
"""
test_alpaca_converter(1)
test_sharegpt_converter()
test_pair_converter(1)

View File

@@ -52,12 +52,3 @@ def test_init_on_default():
)
model_engine = ModelEngine(model_args=model_args)
assert model_engine.model.device == DistributedInterface().current_device
if __name__ == "__main__":
"""
python tests_v1/plugins/model_plugins/test_init_plugin.py
"""
test_init_on_meta()
test_init_on_rank0()
test_init_on_default()

View File

@@ -35,10 +35,3 @@ def test_sync_sampler():
"role": "assistant",
"content": [{"type": "text", "value": "This is a test."}],
}
if __name__ == "__main__":
"""
python tests_v1/sampler/test_cli_sampler.py
"""
test_sync_sampler()