[v1] Implement dynamic padding-free stretrgy for batching (#10507)

Co-authored-by: cxy-thinkbook <xuanyuchen@seu.edu.cn>
This commit is contained in:
cxy
2026-05-25 20:40:21 +08:00
committed by GitHub
parent 16ff5a23cb
commit 8e68764b65
2 changed files with 246 additions and 1 deletions

View File

@@ -84,6 +84,37 @@ def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchI
return sizes return sizes
def _get_dynamic_padding_free_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
cutoff_len = batch_info["cutoff_len"]
sizes = []
index = 0
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
current_tokens = 0
used = 0
is_complete = False
while index + used < len(samples):
sample = samples[index + used]
sample_len = min(len(sample["input_ids"]), cutoff_len)
if current_tokens + sample_len > budget:
is_complete = True
break
current_tokens += sample_len
used += 1
if used <= 0 or not is_complete:
break
sizes.append(used)
index += used
return sizes
def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None: def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
"""Pack fixed samples into one padding-free sequence without a token budget.""" """Pack fixed samples into one padding-free sequence without a token budget."""
packed: dict[str, list[Any]] = {} packed: dict[str, list[Any]] = {}
@@ -206,3 +237,47 @@ def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInf
batch.append(default_collate(pad_and_truncate(samples, cutoff_len))) batch.append(default_collate(pad_and_truncate(samples, cutoff_len)))
return batch return batch
@BatchingPlugin("dynamic_padding_free").register("get_data_provider_batch_size")
def get_dynamic_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
return 1
@BatchingPlugin("dynamic_padding_free").register("compute_length")
def compute_dynamic_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
return ceil(len(data_provider) / batch_size)
@BatchingPlugin("dynamic_padding_free").register("fill_buffer")
def fill_dynamic_padding_free_buffer(
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
while len(_get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
samples = next_samples(True)
if samples is None:
break
buffer.put(samples)
@BatchingPlugin("dynamic_padding_free").register("generate_batch")
def generate_dynamic_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_sample_counts = _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)
if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
return None
batch = []
cutoff_len = batch_info["cutoff_len"]
for num_samples in micro_batch_sample_counts:
samples = buffer.get(num_samples)
packed_batch = _pack_padding_free_samples(samples, cutoff_len)
if packed_batch is None:
return None
batch.append(packed_batch)
return batch

View File

@@ -16,7 +16,11 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume
from llamafactory.v1.core.data_engine import DataEngine from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.model_engine import ModelEngine from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.core.utils.batching import BatchGenerator from llamafactory.v1.core.utils.batching import BatchGenerator
from llamafactory.v1.plugins.trainer_plugins.batching import BatchingPlugin, _get_dynamic_micro_batch_sizes from llamafactory.v1.plugins.trainer_plugins.batching import (
BatchingPlugin,
_get_dynamic_micro_batch_sizes,
_get_dynamic_padding_free_micro_batch_sizes,
)
from llamafactory.v1.utils.constants import IGNORE_INDEX from llamafactory.v1.utils.constants import IGNORE_INDEX
from llamafactory.v1.utils.objects import StatefulBuffer from llamafactory.v1.utils.objects import StatefulBuffer
@@ -74,6 +78,7 @@ def test_batching_plugin_data_provider_batch_sizes():
assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6 assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1 assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
assert BatchingPlugin("dynamic_padding_free").get_data_provider_batch_size(batch_info) == 1
def test_dynamic_batching(): def test_dynamic_batching():
@@ -196,3 +201,168 @@ def test_normal_batching():
batch = next(iter(batch_generator)) batch = next(iter(batch_generator))
assert len(batch) == 2 assert len(batch) == 2
assert batch[0]["input_ids"].shape == (4, 10) assert batch[0]["input_ids"].shape == (4, 10)
def test_dynamic_padding_free():
"""Test core logic of dynamic padding free strategy: pack samples by total token budget without padding."""
# Construct test samples (lengths: 3, 4, 6, 2, 8, 9)
# input_ids breakdown:
# sample 0: [0,1,2] (length=3)
# sample 1: [10,11,12,13] (length=4)
# sample 2: [20,21,22,23,24,25] (length=6)
# sample 3: [30,31] (length=2)
# sample 4: [40-47] (length=8)
# sample 5: [50-58] (length=9)
samples = [
_make_model_input(3, 0),
_make_model_input(4, 10),
_make_model_input(6, 20),
_make_model_input(2, 30),
_make_model_input(8, 40),
_make_model_input(9, 50),
]
# Batch config: micro_batch_size=2 → token budget = cutoff_len * micro_batch_size = 10*2=20
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
# Budget=20: 3+4+6+2=15 ≤20 (adding 8 would exceed) → first 4 samples are selected
assert _get_dynamic_padding_free_micro_batch_sizes(samples, batch_info) == [4]
buffer = StatefulBuffer()
buffer.put(samples)
batch = BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info)
assert batch is not None
assert len(batch) == 1 # num_micro_batch=1
packed_batch = batch[0]
# Total packed length: 3+4+6+2=15 → input_ids shape = (1,15) (no padding)
assert packed_batch["input_ids"].shape == (1, 15)
# Verify input_ids concatenation (first label of non-initial samples set to IGNORE_INDEX)
assert packed_batch["input_ids"].tolist() == [
[
0,
1,
2, # Sample 0
10,
11,
12,
13, # Sample 1
20,
21,
22,
23,
24,
25, # Sample 2
30,
31,
] # Sample 3
]
# Verify labels (first token of non-initial samples is IGNORE_INDEX)
assert packed_batch["labels"].tolist() == [
[
0,
1,
2, # Sample 0
IGNORE_INDEX,
11,
12,
13, # Sample 1
IGNORE_INDEX,
21,
22,
23,
24,
25, # Sample 2
IGNORE_INDEX,
31,
] # Sample 3
]
# Verify attention_mask
assert packed_batch["attention_mask"].tolist() == [[1] * 15]
# Verify position_ids
assert packed_batch["position_ids"].tolist() == [
[
0,
1,
2, # Sample 0
0,
1,
2,
3, # Sample 1
0,
1,
2,
3,
4,
5, # Sample 2
0,
1,
] # Sample 3
]
# Verify remaining samples in buffer: 6-4=2 samples (length 8,9)
assert len(buffer) == 2
def test_dynamic_padding_free_returns_none_when_token_budget_is_incomplete():
buffer = StatefulBuffer()
buffer.put([_make_model_input(6, 0)])
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
assert BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info) is None
# Batch generation does not read from the data iterator. It only returns None and keeps
# existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
assert len(buffer) == 1
def test_dynamic_padding_free_fill_buffer_restarts_until_micro_batch_is_complete():
"""Test fill_buffer logic for dynamic_padding_free: restart data iterator until token budget is full.
Data provider yields one sample of length 6 per iteration.
_fill_buffer keeps restarting iterator until next sample exceeds budget.
Budget = 2 * 10 = 20 tokens.
3 samples (6*3=18) fit; 4th sample (24) exceeds budget.
So buffer will have 4 samples after fill_buffer.
"""
samples = [_make_model_input(6, 0)]
data_provider = _RestartableDataProvider([[sample] for sample in samples])
batch_generator = BatchGenerator.__new__(BatchGenerator)
batch_generator.batching_strategy = "dynamic_padding_free"
batch_generator.micro_batch_size = 2
batch_generator.num_micro_batch = 1
batch_generator._buffer = StatefulBuffer()
batch_generator._data_provider = data_provider
batch_generator._data_iter = iter(data_provider)
batch_generator._batch_info = {
"micro_batch_size": 2,
"num_micro_batch": 1,
"cutoff_len": 10,
}
# Execute fill buffer (will restart iterator multiple times to collect enough samples)
batch_generator._fill_buffer()
# Buffer after restarts:
# 3 samples can fit (18 tokens)
# 4th sample is kept in buffer for next batch
# => num_iters = 4
assert data_provider.num_iters == 4
assert _get_dynamic_padding_free_micro_batch_sizes(
batch_generator._buffer.samples, batch_generator._batch_info
) == [3]
batch = batch_generator._generate_batch()
# Output batch:
# dynamic_padding_free returns [micro_batch_0]
# 3 samples packed into shape [1, 18]
assert batch is not None
assert len(batch) == 1
assert batch[0]["input_ids"].shape == (1, 18)
assert len(batch_generator._buffer) == 1