From 8e68764b65cab50c80b44079bc8430fb7e9ca70b Mon Sep 17 00:00:00 2001 From: cxy <2630747598@qq.com> Date: Mon, 25 May 2026 20:40:21 +0800 Subject: [PATCH] [v1] Implement dynamic padding-free stretrgy for batching (#10507) Co-authored-by: cxy-thinkbook --- .../v1/plugins/trainer_plugins/batching.py | 75 ++++++++ tests_v1/core/utils/test_batching.py | 172 +++++++++++++++++- 2 files changed, 246 insertions(+), 1 deletion(-) diff --git a/src/llamafactory/v1/plugins/trainer_plugins/batching.py b/src/llamafactory/v1/plugins/trainer_plugins/batching.py index 7c945c020..23746a6d4 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/batching.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/batching.py @@ -84,6 +84,37 @@ def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchI return sizes +def _get_dynamic_padding_free_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]: + budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"] + cutoff_len = batch_info["cutoff_len"] + sizes = [] + index = 0 + + while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]: + current_tokens = 0 + used = 0 + is_complete = False + + while index + used < len(samples): + sample = samples[index + used] + sample_len = min(len(sample["input_ids"]), cutoff_len) + + if current_tokens + sample_len > budget: + is_complete = True + break + + current_tokens += sample_len + used += 1 + + if used <= 0 or not is_complete: + break + + sizes.append(used) + index += used + + return sizes + + def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None: """Pack fixed samples into one padding-free sequence without a token budget.""" packed: dict[str, list[Any]] = {} @@ -206,3 +237,47 @@ def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInf batch.append(default_collate(pad_and_truncate(samples, cutoff_len))) return batch + + +@BatchingPlugin("dynamic_padding_free").register("get_data_provider_batch_size") +def get_dynamic_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int: + return 1 + + +@BatchingPlugin("dynamic_padding_free").register("compute_length") +def compute_dynamic_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int: + batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"] + return ceil(len(data_provider) / batch_size) + + +@BatchingPlugin("dynamic_padding_free").register("fill_buffer") +def fill_dynamic_padding_free_buffer( + buffer: StatefulBuffer, + batch_info: BatchInfo, + next_samples: Callable[[bool], list[ModelInput] | None], +) -> None: + while len(_get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]: + samples = next_samples(True) + if samples is None: + break + buffer.put(samples) + + +@BatchingPlugin("dynamic_padding_free").register("generate_batch") +def generate_dynamic_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None: + micro_batch_sample_counts = _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info) + if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]: + return None + + batch = [] + cutoff_len = batch_info["cutoff_len"] + + for num_samples in micro_batch_sample_counts: + samples = buffer.get(num_samples) + packed_batch = _pack_padding_free_samples(samples, cutoff_len) + if packed_batch is None: + return None + + batch.append(packed_batch) + + return batch diff --git a/tests_v1/core/utils/test_batching.py b/tests_v1/core/utils/test_batching.py index f379d0763..ef01e50ee 100644 --- a/tests_v1/core/utils/test_batching.py +++ b/tests_v1/core/utils/test_batching.py @@ -16,7 +16,11 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume from llamafactory.v1.core.data_engine import DataEngine from llamafactory.v1.core.model_engine import ModelEngine from llamafactory.v1.core.utils.batching import BatchGenerator -from llamafactory.v1.plugins.trainer_plugins.batching import BatchingPlugin, _get_dynamic_micro_batch_sizes +from llamafactory.v1.plugins.trainer_plugins.batching import ( + BatchingPlugin, + _get_dynamic_micro_batch_sizes, + _get_dynamic_padding_free_micro_batch_sizes, +) from llamafactory.v1.utils.constants import IGNORE_INDEX from llamafactory.v1.utils.objects import StatefulBuffer @@ -74,6 +78,7 @@ def test_batching_plugin_data_provider_batch_sizes(): assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6 assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1 + assert BatchingPlugin("dynamic_padding_free").get_data_provider_batch_size(batch_info) == 1 def test_dynamic_batching(): @@ -196,3 +201,168 @@ def test_normal_batching(): batch = next(iter(batch_generator)) assert len(batch) == 2 assert batch[0]["input_ids"].shape == (4, 10) + + +def test_dynamic_padding_free(): + """Test core logic of dynamic padding free strategy: pack samples by total token budget without padding.""" + # Construct test samples (lengths: 3, 4, 6, 2, 8, 9) + # input_ids breakdown: + # sample 0: [0,1,2] (length=3) + # sample 1: [10,11,12,13] (length=4) + # sample 2: [20,21,22,23,24,25] (length=6) + # sample 3: [30,31] (length=2) + # sample 4: [40-47] (length=8) + # sample 5: [50-58] (length=9) + samples = [ + _make_model_input(3, 0), + _make_model_input(4, 10), + _make_model_input(6, 20), + _make_model_input(2, 30), + _make_model_input(8, 40), + _make_model_input(9, 50), + ] + # Batch config: micro_batch_size=2 → token budget = cutoff_len * micro_batch_size = 10*2=20 + batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10} + + # Budget=20: 3+4+6+2=15 ≤20 (adding 8 would exceed) → first 4 samples are selected + assert _get_dynamic_padding_free_micro_batch_sizes(samples, batch_info) == [4] + + buffer = StatefulBuffer() + buffer.put(samples) + batch = BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info) + + assert batch is not None + assert len(batch) == 1 # num_micro_batch=1 + packed_batch = batch[0] + + # Total packed length: 3+4+6+2=15 → input_ids shape = (1,15) (no padding) + assert packed_batch["input_ids"].shape == (1, 15) + + # Verify input_ids concatenation (first label of non-initial samples set to IGNORE_INDEX) + assert packed_batch["input_ids"].tolist() == [ + [ + 0, + 1, + 2, # Sample 0 + 10, + 11, + 12, + 13, # Sample 1 + 20, + 21, + 22, + 23, + 24, + 25, # Sample 2 + 30, + 31, + ] # Sample 3 + ] + + # Verify labels (first token of non-initial samples is IGNORE_INDEX) + assert packed_batch["labels"].tolist() == [ + [ + 0, + 1, + 2, # Sample 0 + IGNORE_INDEX, + 11, + 12, + 13, # Sample 1 + IGNORE_INDEX, + 21, + 22, + 23, + 24, + 25, # Sample 2 + IGNORE_INDEX, + 31, + ] # Sample 3 + ] + + # Verify attention_mask + assert packed_batch["attention_mask"].tolist() == [[1] * 15] + + # Verify position_ids + assert packed_batch["position_ids"].tolist() == [ + [ + 0, + 1, + 2, # Sample 0 + 0, + 1, + 2, + 3, # Sample 1 + 0, + 1, + 2, + 3, + 4, + 5, # Sample 2 + 0, + 1, + ] # Sample 3 + ] + + # Verify remaining samples in buffer: 6-4=2 samples (length 8,9) + assert len(buffer) == 2 + + +def test_dynamic_padding_free_returns_none_when_token_budget_is_incomplete(): + buffer = StatefulBuffer() + buffer.put([_make_model_input(6, 0)]) + batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10} + + assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == [] + assert BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info) is None + # Batch generation does not read from the data iterator. It only returns None and keeps + # existing samples in the buffer; BatchGenerator._fill_buffer handles refilling. + assert len(buffer) == 1 + + +def test_dynamic_padding_free_fill_buffer_restarts_until_micro_batch_is_complete(): + """Test fill_buffer logic for dynamic_padding_free: restart data iterator until token budget is full. + + Data provider yields one sample of length 6 per iteration. + _fill_buffer keeps restarting iterator until next sample exceeds budget. + Budget = 2 * 10 = 20 tokens. + 3 samples (6*3=18) fit; 4th sample (24) exceeds budget. + So buffer will have 4 samples after fill_buffer. + """ + samples = [_make_model_input(6, 0)] + data_provider = _RestartableDataProvider([[sample] for sample in samples]) + + batch_generator = BatchGenerator.__new__(BatchGenerator) + batch_generator.batching_strategy = "dynamic_padding_free" + batch_generator.micro_batch_size = 2 + batch_generator.num_micro_batch = 1 + batch_generator._buffer = StatefulBuffer() + batch_generator._data_provider = data_provider + batch_generator._data_iter = iter(data_provider) + batch_generator._batch_info = { + "micro_batch_size": 2, + "num_micro_batch": 1, + "cutoff_len": 10, + } + + # Execute fill buffer (will restart iterator multiple times to collect enough samples) + batch_generator._fill_buffer() + + # Buffer after restarts: + # 3 samples can fit (18 tokens) + # 4th sample is kept in buffer for next batch + # => num_iters = 4 + assert data_provider.num_iters == 4 + assert _get_dynamic_padding_free_micro_batch_sizes( + batch_generator._buffer.samples, batch_generator._batch_info + ) == [3] + + batch = batch_generator._generate_batch() + + # Output batch: + # dynamic_padding_free returns [micro_batch_0] + # 3 samples packed into shape [1, 18] + assert batch is not None + assert len(batch) == 1 + assert batch[0]["input_ids"].shape == (1, 18) + assert len(batch_generator._buffer) == 1