[v1] Add FlashAttention selection and implement normal / padding-free / dynamic batching (#10469)

This commit is contained in:
jiaqiw09
2026-05-21 17:14:19 +08:00
committed by GitHub
parent 7e20db5735
commit bdcb92d035
23 changed files with 507 additions and 105 deletions

View File

@@ -30,10 +30,3 @@ def test_map_dataset(num_samples: int):
for index in indexes:
print(data_engine[index])
assert data_engine[index] == {"_dataset_name": "default", **original_data[index]}
if __name__ == "__main__":
    # Manual entry point for running this test module directly:
    #   python -m tests_v1.core.test_data_engine
    test_map_dataset(1)

View File

@@ -41,11 +41,3 @@ def test_tiny_qwen_with_kernel_plugin():
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
if __name__ == "__main__":
    # Manual entry point for running this test module directly:
    #   python -m tests_v1.core.test_model_loader
    test_tiny_qwen()
    test_tiny_qwen_with_kernel_plugin()

View File

@@ -16,6 +16,159 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.core.utils.batching import BatchGenerator
from llamafactory.v1.plugins.trainer_plugins.batching import BatchingPlugin, _get_dynamic_micro_batch_sizes
from llamafactory.v1.utils.constants import IGNORE_INDEX
from llamafactory.v1.utils.objects import StatefulBuffer
def _make_model_input(length: int, start: int = 0) -> dict:
    """Build a synthetic model-input sample containing ``length`` tokens.

    Args:
        length: Number of tokens in the sample.
        start: Value of the first token id; ids are consecutive from there.

    Returns:
        A dict with ``input_ids`` (``start`` .. ``start + length - 1``),
        an all-ones ``attention_mask``, ``labels`` equal to the input ids,
        and ``loss_weights`` of 1.0 for every position.
    """
    input_ids = list(range(start, start + length))
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * length,
        # A separate list object, so editing labels never touches input_ids.
        "labels": input_ids.copy(),
        "loss_weights": [1.0] * length,
    }
class _RestartableDataProvider:
    """Iterable over a fixed list of batches that counts how often it is (re)started."""

    def __init__(self, batches):
        """Store ``batches`` and reset the iteration counter."""
        self.batches = batches
        # Number of times ``__iter__`` has been called on this provider.
        self.num_iters = 0

    def __iter__(self):
        """Start a fresh pass over the stored batches and record the restart."""
        self.num_iters += 1
        return iter(self.batches)
def test_padding_free():
    """Check that the ``padding_free`` strategy truncates samples and packs them into one sequence."""
    buffer = StatefulBuffer()
    # Input samples:
    #   sample 0 input_ids: [0, 1]
    #   sample 1 input_ids: [10, 11, 12, 13]
    buffer.put([_make_model_input(2, 0), _make_model_input(4, 10)])
    batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 3}

    batch = BatchingPlugin("padding_free").generate_batch(buffer, batch_info)

    # Expected output batch:
    #   sample 1 is truncated to [10, 11, 12]
    #   both samples are packed into one sequence: [[0, 1, 10, 11, 12]]
    assert batch is not None
    assert len(batch) == 1
    assert batch[0]["input_ids"].shape == (1, 5)
    assert batch[0]["input_ids"].tolist() == [[0, 1, 10, 11, 12]]
    assert batch[0]["attention_mask"].tolist() == [[1, 1, 1, 1, 1]]
    # Position ids are expected to restart at the boundary between packed samples.
    assert batch[0]["position_ids"].tolist() == [[0, 1, 0, 1, 2]]
    # The first token of the second packed sample is expected to be excluded
    # from the loss: IGNORE_INDEX label and a zero loss weight.
    assert batch[0]["labels"].tolist() == [[0, 1, IGNORE_INDEX, 11, 12]]
    assert batch[0]["loss_weights"].tolist() == [[1.0, 1.0, 0.0, 1.0, 1.0]]
    # Both samples must have been consumed from the buffer.
    assert len(buffer) == 0
def test_batching_plugin_data_provider_batch_sizes():
    """Check the data-provider batch size reported by each batching strategy."""
    batch_info = {
        "micro_batch_size": 2,
        "num_micro_batch": 3,
        "cutoff_len": 10,
    }
    # Expected value 6 matches micro_batch_size * num_micro_batch for this config.
    assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
    # Dynamic batching is expected to pull samples from the provider one at a time.
    assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
def test_dynamic_batching():
    """Check that dynamic batching groups samples under the padded-token budget."""
    # Input samples:
    #   sample lengths: [3, 4, 6, 2, 8, 9]
    #   input_ids:
    #     [0, 1, 2]
    #     [10, 11, 12, 13]
    #     [20, 21, 22, 23, 24, 25]
    #     [30, 31]
    #     [40, 41, 42, 43, 44, 45, 46, 47]
    #     [50, 51, 52, 53, 54, 55, 56, 57, 58]
    samples = [
        _make_model_input(3, 0),
        _make_model_input(4, 10),
        _make_model_input(6, 20),
        _make_model_input(2, 30),
        _make_model_input(8, 40),
        _make_model_input(9, 50),
    ]
    batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}

    # Dynamic batching output plan:
    #   dynamic batching reads one sample at a time and uses cutoff_len * micro_batch_size
    #   as the padded-token budget for one training micro batch.
    #   [3, 4, 6] fits within budget 20 as shape [3, 6]; adding [2] would exceed it.
    assert _get_dynamic_micro_batch_sizes(samples, batch_info) == [3]

    buffer = StatefulBuffer()
    buffer.put(samples)
    batch = BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info)

    assert batch is not None
    assert len(batch) == 1
    assert batch[0]["input_ids"].shape == (3, 6)
    # The shortest sample is expected to be right-padded with zeros to length 6.
    assert batch[0]["input_ids"].tolist()[0] == [0, 1, 2, 0, 0, 0]
    # Three of the six samples were consumed; the other three stay buffered.
    assert len(buffer) == 3
def test_dynamic_batching_returns_none_when_token_budget_is_incomplete():
    """Check that dynamic batching yields no batch while the token budget is still unfilled."""
    buffer = StatefulBuffer()
    # Input buffer:
    #   only one sample with length [6].
    #   cutoff_len * micro_batch_size gives a padded-token budget of 20.
    #   This buffer has not filled the budget and has no next sample to prove overflow,
    #   so dynamic batching cannot produce a batch yet.
    buffer.put([_make_model_input(6, 0)])
    batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}

    assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
    assert BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info) is None

    # Batch generation does not read from the data iterator. It only returns None and keeps
    # existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
    assert len(buffer) == 1
def test_dynamic_batching_fill_buffer_restarts_until_micro_batch_is_complete():
    """Check that ``_fill_buffer`` restarts the data iterator until a dynamic micro batch is complete."""
    # Input data provider:
    #   each iterator pass yields one sample with length [6].
    #   each yielded item is a list[ModelInput], matching BatchGenerator._next_samples.
    #   _fill_buffer keeps restarting the iterator until the next appended sample
    #   proves that the previous dynamic micro batch has reached its budget boundary.
    samples = [_make_model_input(6, 0)]
    data_provider = _RestartableDataProvider([[sample] for sample in samples])

    # Build the generator without running __init__, then set only the
    # attributes this test assigns explicitly.
    batch_generator = BatchGenerator.__new__(BatchGenerator)
    batch_generator.batching_strategy = "dynamic_batching"
    batch_generator.micro_batch_size = 2
    batch_generator.num_micro_batch = 1
    batch_generator._buffer = StatefulBuffer()
    batch_generator._data_provider = data_provider
    # This first iter() call already counts as one pass in ``num_iters``.
    batch_generator._data_iter = iter(data_provider)
    batch_generator._batch_info = {
        "micro_batch_size": 2,
        "num_micro_batch": 1,
        "cutoff_len": 10,
    }

    batch_generator._fill_buffer()

    # Filled buffer after restart:
    #   existing buffer [6, 6, 6] is kept; the fourth [6] remains for the next batch
    #   because adding it to the first dynamic micro batch would exceed the budget.
    assert data_provider.num_iters == 4
    assert _get_dynamic_micro_batch_sizes(batch_generator._buffer.samples, batch_generator._batch_info) == [3]

    batch = batch_generator._generate_batch()

    # Output batch:
    #   dynamic batching returns [micro_batch_0]
    #   micro_batch_0 consumes [6, 6, 6] => 3 samples, padded to shape [3, 6].
    assert batch is not None
    assert len(batch) == 1
    assert batch[0]["input_ids"].shape == (3, 6)
    # The fourth sample stays buffered for the next batch.
    assert len(batch_generator._buffer) == 1
def test_normal_batching():
@@ -43,10 +196,3 @@ def test_normal_batching():
batch = next(iter(batch_generator))
assert len(batch) == 2
assert batch[0]["input_ids"].shape == (4, 10)
if __name__ == "__main__":
    # Manual entry point for running this test module directly:
    #   python -m tests_v1.core.utils.test_batching
    test_normal_batching()

View File

@@ -227,17 +227,3 @@ def test_process_dpo_samples():
assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
assert model_inputs[0]["extra_info"] == "test"
assert model_inputs[0]["_dataset_name"] == "default"
if __name__ == "__main__":
    # Manual entry point for running this test module directly:
    #   python -m tests_v1.core.utils.test_rendering
    test_chatml_rendering()
    test_chatml_parse()
    test_chatml_rendering_remote(16)
    test_qwen3_nothink_rendering()
    test_qwen3_nothink_parse()
    test_qwen3_nothink_rendering_remote(16)
    test_process_sft_samples()
    test_process_dpo_samples()