mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 02:48:54 +08:00
[v1] Add FlashAttention selection and implement normal / padding-free / dynamic batching (#10469)
This commit is contained in:
@@ -30,10 +30,3 @@ def test_map_dataset(num_samples: int):
|
||||
for index in indexes:
|
||||
print(data_engine[index])
|
||||
assert data_engine[index] == {"_dataset_name": "default", **original_data[index]}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m tests_v1.core.test_data_engine
|
||||
"""
|
||||
test_map_dataset(1)
|
||||
|
||||
@@ -41,11 +41,3 @@ def test_tiny_qwen_with_kernel_plugin():
|
||||
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
|
||||
|
||||
assert "Qwen3ForCausalLM" in model_engine.model.__class__.__name__
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m tests_v1.core.test_model_loader
|
||||
"""
|
||||
test_tiny_qwen()
|
||||
test_tiny_qwen_with_kernel_plugin()
|
||||
|
||||
@@ -16,6 +16,159 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume
|
||||
from llamafactory.v1.core.data_engine import DataEngine
|
||||
from llamafactory.v1.core.model_engine import ModelEngine
|
||||
from llamafactory.v1.core.utils.batching import BatchGenerator
|
||||
from llamafactory.v1.plugins.trainer_plugins.batching import BatchingPlugin, _get_dynamic_micro_batch_sizes
|
||||
from llamafactory.v1.utils.constants import IGNORE_INDEX
|
||||
from llamafactory.v1.utils.objects import StatefulBuffer
|
||||
|
||||
|
||||
def _make_model_input(length: int, start: int = 0):
|
||||
input_ids = list(range(start, start + length))
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": [1] * length,
|
||||
"labels": input_ids.copy(),
|
||||
"loss_weights": [1.0] * length,
|
||||
}
|
||||
|
||||
|
||||
class _RestartableDataProvider:
|
||||
def __init__(self, batches):
|
||||
self.batches = batches
|
||||
self.num_iters = 0
|
||||
|
||||
def __iter__(self):
|
||||
self.num_iters += 1
|
||||
return iter(self.batches)
|
||||
|
||||
|
||||
def test_padding_free():
    """Check that the ``padding_free`` plugin truncates to ``cutoff_len`` and packs samples into one row."""
    buffer = StatefulBuffer()
    # Input samples:
    #   sample 0 input_ids: [0, 1]
    #   sample 1 input_ids: [10, 11, 12, 13]
    buffer.put([_make_model_input(2, 0), _make_model_input(4, 10)])
    batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 3}

    batch = BatchingPlugin("padding_free").generate_batch(buffer, batch_info)

    # Output batch:
    #   sample 1 is truncated to [10, 11, 12]
    #   both samples are packed into one sequence: [[0, 1, 10, 11, 12]]
    assert batch is not None
    assert len(batch) == 1
    assert batch[0]["input_ids"].shape == (1, 5)
    assert batch[0]["input_ids"].tolist() == [[0, 1, 10, 11, 12]]
    assert batch[0]["attention_mask"].tolist() == [[1, 1, 1, 1, 1]]
    # Position ids are expected to restart at 0 where the second sample begins.
    assert batch[0]["position_ids"].tolist() == [[0, 1, 0, 1, 2]]
    # The first token of the second packed sample is expected to be masked out:
    # its label is IGNORE_INDEX and its loss weight is 0.0.
    assert batch[0]["labels"].tolist() == [[0, 1, IGNORE_INDEX, 11, 12]]
    assert batch[0]["loss_weights"].tolist() == [[1.0, 1.0, 0.0, 1.0, 1.0]]
    # Both samples must have been consumed from the buffer.
    assert len(buffer) == 0
|
||||
|
||||
|
||||
def test_batching_plugin_data_provider_batch_sizes():
    """Check the data-provider batch size each batching strategy reports for the same ``batch_info``."""
    batch_info = {
        "micro_batch_size": 2,
        "num_micro_batch": 3,
        "cutoff_len": 10,
    }

    # Expected: 6 for padding_free (equals micro_batch_size * num_micro_batch here),
    # and 1 for dynamic_batching.
    assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
    assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
|
||||
|
||||
|
||||
def test_dynamic_batching():
    """Check that dynamic batching groups samples under the padded-token budget and leaves the rest buffered."""
    # Input samples:
    #   sample lengths: [3, 4, 6, 2, 8, 9]
    #   input_ids:
    #     [0, 1, 2]
    #     [10, 11, 12, 13]
    #     [20, 21, 22, 23, 24, 25]
    #     [30, 31]
    #     [40, 41, 42, 43, 44, 45, 46, 47]
    #     [50, 51, 52, 53, 54, 55, 56, 57, 58]
    samples = [
        _make_model_input(3, 0),
        _make_model_input(4, 10),
        _make_model_input(6, 20),
        _make_model_input(2, 30),
        _make_model_input(8, 40),
        _make_model_input(9, 50),
    ]
    batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}

    # Dynamic batching output plan:
    #   dynamic batching reads one sample at a time and uses cutoff_len * micro_batch_size
    #   as the padded-token budget for one training micro batch.
    #   [3, 4, 6] fits within budget 20 as shape [3, 6]; adding [2] would exceed it.
    assert _get_dynamic_micro_batch_sizes(samples, batch_info) == [3]

    buffer = StatefulBuffer()
    buffer.put(samples)
    batch = BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info)

    assert batch is not None
    assert len(batch) == 1
    assert batch[0]["input_ids"].shape == (3, 6)
    # The shortest sample is expected to be right-padded with zeros to the row width of 6.
    assert batch[0]["input_ids"].tolist()[0] == [0, 1, 2, 0, 0, 0]
    # Three of the six samples were consumed; the other three stay in the buffer.
    assert len(buffer) == 3
|
||||
|
||||
|
||||
def test_dynamic_batching_returns_none_when_token_budget_is_incomplete():
    """Check that dynamic batching yields no batch, and keeps the buffer, when the budget is not yet filled."""
    buffer = StatefulBuffer()
    # Input buffer:
    #   only one sample with length [6].
    #   cutoff_len * micro_batch_size gives a padded-token budget of 20.
    #   this buffer has not filled the budget and has no next sample to prove overflow,
    #   so dynamic batching cannot produce a batch yet.
    buffer.put([_make_model_input(6, 0)])
    batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}

    assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
    assert BatchingPlugin("dynamic_batching").generate_batch(buffer, batch_info) is None
    # Batch generation does not read from the data iterator. It only returns None and keeps
    # existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
    assert len(buffer) == 1
|
||||
|
||||
|
||||
def test_dynamic_batching_fill_buffer_restarts_until_micro_batch_is_complete():
    """Check that ``_fill_buffer`` restarts an exhausted data provider until a dynamic micro batch is complete."""
    # Input data provider:
    #   each iterator pass yields one sample with length [6].
    #   each yielded item is a list[ModelInput], matching BatchGenerator._next_samples.
    #   _fill_buffer keeps restarting the iterator until the next appended sample
    #   proves that the previous dynamic micro batch has reached its budget boundary.
    samples = [_make_model_input(6, 0)]
    data_provider = _RestartableDataProvider([[sample] for sample in samples])

    # Bypass BatchGenerator.__init__ and set by hand the attributes used below.
    batch_generator = BatchGenerator.__new__(BatchGenerator)
    batch_generator.batching_strategy = "dynamic_batching"
    batch_generator.micro_batch_size = 2
    batch_generator.num_micro_batch = 1
    batch_generator._buffer = StatefulBuffer()
    batch_generator._data_provider = data_provider
    # This iter() call already counts as the first pass in data_provider.num_iters.
    batch_generator._data_iter = iter(data_provider)
    batch_generator._batch_info = {
        "micro_batch_size": 2,
        "num_micro_batch": 1,
        "cutoff_len": 10,
    }

    batch_generator._fill_buffer()

    # Filled buffer after restart:
    #   existing buffer [6, 6, 6] is kept; the fourth [6] remains for the next batch
    #   because adding it to the first dynamic micro batch would exceed the budget.
    assert data_provider.num_iters == 4
    assert _get_dynamic_micro_batch_sizes(batch_generator._buffer.samples, batch_generator._batch_info) == [3]

    batch = batch_generator._generate_batch()

    # Output batch:
    #   dynamic batching returns [micro_batch_0]
    #   micro_batch_0 consumes [6, 6, 6] => 3 samples, padded to shape [3, 6].
    assert batch is not None
    assert len(batch) == 1
    assert batch[0]["input_ids"].shape == (3, 6)
    # The fourth sample is left over for the next batch.
    assert len(batch_generator._buffer) == 1
|
||||
|
||||
|
||||
def test_normal_batching():
|
||||
@@ -43,10 +196,3 @@ def test_normal_batching():
|
||||
batch = next(iter(batch_generator))
|
||||
assert len(batch) == 2
|
||||
assert batch[0]["input_ids"].shape == (4, 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m tests_v1.core.utils.test_batching
|
||||
"""
|
||||
test_normal_batching()
|
||||
|
||||
@@ -227,17 +227,3 @@ def test_process_dpo_samples():
|
||||
assert model_inputs[0]["token_type_ids"] == [1] * len(hf_inputs) + [2] * len(hf_inputs)
|
||||
assert model_inputs[0]["extra_info"] == "test"
|
||||
assert model_inputs[0]["_dataset_name"] == "default"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m tests_v1.core.utils.test_rendering
|
||||
"""
|
||||
test_chatml_rendering()
|
||||
test_chatml_parse()
|
||||
test_chatml_rendering_remote(16)
|
||||
test_qwen3_nothink_rendering()
|
||||
test_qwen3_nothink_parse()
|
||||
test_qwen3_nothink_rendering_remote(16)
|
||||
test_process_sft_samples()
|
||||
test_process_dpo_samples()
|
||||
|
||||
Reference in New Issue
Block a user