mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 02:48:54 +08:00
[v1] Implement dynamic padding-free stretrgy for batching (#10507)
Co-authored-by: cxy-thinkbook <xuanyuchen@seu.edu.cn>
This commit is contained in:
@@ -16,7 +16,11 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume
|
||||
from llamafactory.v1.core.data_engine import DataEngine
|
||||
from llamafactory.v1.core.model_engine import ModelEngine
|
||||
from llamafactory.v1.core.utils.batching import BatchGenerator
|
||||
from llamafactory.v1.plugins.trainer_plugins.batching import BatchingPlugin, _get_dynamic_micro_batch_sizes
|
||||
from llamafactory.v1.plugins.trainer_plugins.batching import (
|
||||
BatchingPlugin,
|
||||
_get_dynamic_micro_batch_sizes,
|
||||
_get_dynamic_padding_free_micro_batch_sizes,
|
||||
)
|
||||
from llamafactory.v1.utils.constants import IGNORE_INDEX
|
||||
from llamafactory.v1.utils.objects import StatefulBuffer
|
||||
|
||||
@@ -74,6 +78,7 @@ def test_batching_plugin_data_provider_batch_sizes():
|
||||
|
||||
assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
|
||||
assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
|
||||
assert BatchingPlugin("dynamic_padding_free").get_data_provider_batch_size(batch_info) == 1
|
||||
|
||||
|
||||
def test_dynamic_batching():
|
||||
@@ -196,3 +201,168 @@ def test_normal_batching():
|
||||
batch = next(iter(batch_generator))
|
||||
assert len(batch) == 2
|
||||
assert batch[0]["input_ids"].shape == (4, 10)
|
||||
|
||||
|
||||
def test_dynamic_padding_free():
|
||||
"""Test core logic of dynamic padding free strategy: pack samples by total token budget without padding."""
|
||||
# Construct test samples (lengths: 3, 4, 6, 2, 8, 9)
|
||||
# input_ids breakdown:
|
||||
# sample 0: [0,1,2] (length=3)
|
||||
# sample 1: [10,11,12,13] (length=4)
|
||||
# sample 2: [20,21,22,23,24,25] (length=6)
|
||||
# sample 3: [30,31] (length=2)
|
||||
# sample 4: [40-47] (length=8)
|
||||
# sample 5: [50-58] (length=9)
|
||||
samples = [
|
||||
_make_model_input(3, 0),
|
||||
_make_model_input(4, 10),
|
||||
_make_model_input(6, 20),
|
||||
_make_model_input(2, 30),
|
||||
_make_model_input(8, 40),
|
||||
_make_model_input(9, 50),
|
||||
]
|
||||
# Batch config: micro_batch_size=2 → token budget = cutoff_len * micro_batch_size = 10*2=20
|
||||
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
|
||||
|
||||
# Budget=20: 3+4+6+2=15 ≤20 (adding 8 would exceed) → first 4 samples are selected
|
||||
assert _get_dynamic_padding_free_micro_batch_sizes(samples, batch_info) == [4]
|
||||
|
||||
buffer = StatefulBuffer()
|
||||
buffer.put(samples)
|
||||
batch = BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info)
|
||||
|
||||
assert batch is not None
|
||||
assert len(batch) == 1 # num_micro_batch=1
|
||||
packed_batch = batch[0]
|
||||
|
||||
# Total packed length: 3+4+6+2=15 → input_ids shape = (1,15) (no padding)
|
||||
assert packed_batch["input_ids"].shape == (1, 15)
|
||||
|
||||
# Verify input_ids concatenation (first label of non-initial samples set to IGNORE_INDEX)
|
||||
assert packed_batch["input_ids"].tolist() == [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2, # Sample 0
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
13, # Sample 1
|
||||
20,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25, # Sample 2
|
||||
30,
|
||||
31,
|
||||
] # Sample 3
|
||||
]
|
||||
|
||||
# Verify labels (first token of non-initial samples is IGNORE_INDEX)
|
||||
assert packed_batch["labels"].tolist() == [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2, # Sample 0
|
||||
IGNORE_INDEX,
|
||||
11,
|
||||
12,
|
||||
13, # Sample 1
|
||||
IGNORE_INDEX,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25, # Sample 2
|
||||
IGNORE_INDEX,
|
||||
31,
|
||||
] # Sample 3
|
||||
]
|
||||
|
||||
# Verify attention_mask
|
||||
assert packed_batch["attention_mask"].tolist() == [[1] * 15]
|
||||
|
||||
# Verify position_ids
|
||||
assert packed_batch["position_ids"].tolist() == [
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2, # Sample 0
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3, # Sample 1
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5, # Sample 2
|
||||
0,
|
||||
1,
|
||||
] # Sample 3
|
||||
]
|
||||
|
||||
# Verify remaining samples in buffer: 6-4=2 samples (length 8,9)
|
||||
assert len(buffer) == 2
|
||||
|
||||
|
||||
def test_dynamic_padding_free_returns_none_when_token_budget_is_incomplete():
|
||||
buffer = StatefulBuffer()
|
||||
buffer.put([_make_model_input(6, 0)])
|
||||
batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}
|
||||
|
||||
assert _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) == []
|
||||
assert BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info) is None
|
||||
# Batch generation does not read from the data iterator. It only returns None and keeps
|
||||
# existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
|
||||
assert len(buffer) == 1
|
||||
|
||||
|
||||
def test_dynamic_padding_free_fill_buffer_restarts_until_micro_batch_is_complete():
|
||||
"""Test fill_buffer logic for dynamic_padding_free: restart data iterator until token budget is full.
|
||||
|
||||
Data provider yields one sample of length 6 per iteration.
|
||||
_fill_buffer keeps restarting iterator until next sample exceeds budget.
|
||||
Budget = 2 * 10 = 20 tokens.
|
||||
3 samples (6*3=18) fit; 4th sample (24) exceeds budget.
|
||||
So buffer will have 4 samples after fill_buffer.
|
||||
"""
|
||||
samples = [_make_model_input(6, 0)]
|
||||
data_provider = _RestartableDataProvider([[sample] for sample in samples])
|
||||
|
||||
batch_generator = BatchGenerator.__new__(BatchGenerator)
|
||||
batch_generator.batching_strategy = "dynamic_padding_free"
|
||||
batch_generator.micro_batch_size = 2
|
||||
batch_generator.num_micro_batch = 1
|
||||
batch_generator._buffer = StatefulBuffer()
|
||||
batch_generator._data_provider = data_provider
|
||||
batch_generator._data_iter = iter(data_provider)
|
||||
batch_generator._batch_info = {
|
||||
"micro_batch_size": 2,
|
||||
"num_micro_batch": 1,
|
||||
"cutoff_len": 10,
|
||||
}
|
||||
|
||||
# Execute fill buffer (will restart iterator multiple times to collect enough samples)
|
||||
batch_generator._fill_buffer()
|
||||
|
||||
# Buffer after restarts:
|
||||
# 3 samples can fit (18 tokens)
|
||||
# 4th sample is kept in buffer for next batch
|
||||
# => num_iters = 4
|
||||
assert data_provider.num_iters == 4
|
||||
assert _get_dynamic_padding_free_micro_batch_sizes(
|
||||
batch_generator._buffer.samples, batch_generator._batch_info
|
||||
) == [3]
|
||||
|
||||
batch = batch_generator._generate_batch()
|
||||
|
||||
# Output batch:
|
||||
# dynamic_padding_free returns [micro_batch_0]
|
||||
# 3 samples packed into shape [1, 18]
|
||||
assert batch is not None
|
||||
assert len(batch) == 1
|
||||
assert batch[0]["input_ids"].shape == (1, 18)
|
||||
assert len(batch_generator._buffer) == 1
|
||||
|
||||
Reference in New Issue
Block a user