mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 02:48:54 +08:00
[v1] Implement dynamic padding-free strategy for batching (#10507)
Co-authored-by: cxy-thinkbook <xuanyuchen@seu.edu.cn>
This commit is contained in:
@@ -84,6 +84,37 @@ def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchI
|
|||||||
return sizes
|
return sizes
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dynamic_padding_free_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
|
||||||
|
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
|
||||||
|
cutoff_len = batch_info["cutoff_len"]
|
||||||
|
sizes = []
|
||||||
|
index = 0
|
||||||
|
|
||||||
|
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
|
||||||
|
current_tokens = 0
|
||||||
|
used = 0
|
||||||
|
is_complete = False
|
||||||
|
|
||||||
|
while index + used < len(samples):
|
||||||
|
sample = samples[index + used]
|
||||||
|
sample_len = min(len(sample["input_ids"]), cutoff_len)
|
||||||
|
|
||||||
|
if current_tokens + sample_len > budget:
|
||||||
|
is_complete = True
|
||||||
|
break
|
||||||
|
|
||||||
|
current_tokens += sample_len
|
||||||
|
used += 1
|
||||||
|
|
||||||
|
if used <= 0 or not is_complete:
|
||||||
|
break
|
||||||
|
|
||||||
|
sizes.append(used)
|
||||||
|
index += used
|
||||||
|
|
||||||
|
return sizes
|
||||||
|
|
||||||
|
|
||||||
def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
|
def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
|
||||||
"""Pack fixed samples into one padding-free sequence without a token budget."""
|
"""Pack fixed samples into one padding-free sequence without a token budget."""
|
||||||
packed: dict[str, list[Any]] = {}
|
packed: dict[str, list[Any]] = {}
|
||||||
@@ -206,3 +237,47 @@ def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInf
|
|||||||
batch.append(default_collate(pad_and_truncate(samples, cutoff_len)))
|
batch.append(default_collate(pad_and_truncate(samples, cutoff_len)))
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
@BatchingPlugin("dynamic_padding_free").register("get_data_provider_batch_size")
def get_dynamic_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
    """Return the data provider batch size for the ``dynamic_padding_free`` strategy.

    The value is always 1; ``batch_info`` is accepted but not consulted.
    """
    data_provider_batch_size = 1
    return data_provider_batch_size
|
||||||
|
|
||||||
|
|
||||||
|
@BatchingPlugin("dynamic_padding_free").register("compute_length")
def compute_dynamic_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
    """Compute the length reported for the ``dynamic_padding_free`` strategy.

    The result is ``len(data_provider)`` divided by
    ``micro_batch_size * num_micro_batch``, rounded up.
    """
    samples_per_step = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
    num_items = len(data_provider)
    return ceil(num_items / samples_per_step)
|
||||||
|
|
||||||
|
|
||||||
|
@BatchingPlugin("dynamic_padding_free").register("fill_buffer")
def fill_dynamic_padding_free_buffer(
    buffer: StatefulBuffer,
    batch_info: BatchInfo,
    next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
    """Pull samples into ``buffer`` until it holds ``num_micro_batch`` complete micro-batches.

    Filling stops early when ``next_samples`` returns ``None``.

    Args:
        buffer: Buffer whose ``samples`` are inspected and which receives new samples via ``put``.
        batch_info: Mapping with the batching configuration.
        next_samples: Callback producing the next samples. It is always called with ``True``
            (presumably allowing the data iterator to restart — confirm against the caller).
    """
    while True:
        ready = _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)
        if len(ready) >= batch_info["num_micro_batch"]:
            return

        fetched = next_samples(True)
        if fetched is None:
            return

        buffer.put(fetched)
|
||||||
|
|
||||||
|
|
||||||
|
@BatchingPlugin("dynamic_padding_free").register("generate_batch")
def generate_dynamic_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
    """Build one batch of packed micro-batches from the buffered samples.

    Returns:
        A list with one packed micro-batch per entry, or ``None`` when the buffer does not
        yet hold ``num_micro_batch`` complete micro-batches or when packing yields ``None``.
    """
    counts = _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)
    if len(counts) < batch_info["num_micro_batch"]:
        # Not enough complete micro-batches yet; the buffer is left untouched.
        return None

    max_len = batch_info["cutoff_len"]
    micro_batches = []
    for count in counts:
        packed = _pack_padding_free_samples(buffer.get(count), max_len)
        if packed is None:
            # NOTE(review): samples already taken from the buffer are not put back here.
            return None

        micro_batches.append(packed)

    return micro_batches
|
||||||
|
|||||||
@@ -16,7 +16,11 @@ from llamafactory.v1.config import DataArguments, ModelArguments, TrainingArgume
|
|||||||
from llamafactory.v1.core.data_engine import DataEngine
|
from llamafactory.v1.core.data_engine import DataEngine
|
||||||
from llamafactory.v1.core.model_engine import ModelEngine
|
from llamafactory.v1.core.model_engine import ModelEngine
|
||||||
from llamafactory.v1.core.utils.batching import BatchGenerator
|
from llamafactory.v1.core.utils.batching import BatchGenerator
|
||||||
from llamafactory.v1.plugins.trainer_plugins.batching import BatchingPlugin, _get_dynamic_micro_batch_sizes
|
from llamafactory.v1.plugins.trainer_plugins.batching import (
|
||||||
|
BatchingPlugin,
|
||||||
|
_get_dynamic_micro_batch_sizes,
|
||||||
|
_get_dynamic_padding_free_micro_batch_sizes,
|
||||||
|
)
|
||||||
from llamafactory.v1.utils.constants import IGNORE_INDEX
|
from llamafactory.v1.utils.constants import IGNORE_INDEX
|
||||||
from llamafactory.v1.utils.objects import StatefulBuffer
|
from llamafactory.v1.utils.objects import StatefulBuffer
|
||||||
|
|
||||||
@@ -74,6 +78,7 @@ def test_batching_plugin_data_provider_batch_sizes():
|
|||||||
|
|
||||||
assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
|
assert BatchingPlugin("padding_free").get_data_provider_batch_size(batch_info) == 6
|
||||||
assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
|
assert BatchingPlugin("dynamic_batching").get_data_provider_batch_size(batch_info) == 1
|
||||||
|
assert BatchingPlugin("dynamic_padding_free").get_data_provider_batch_size(batch_info) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_dynamic_batching():
|
def test_dynamic_batching():
|
||||||
@@ -196,3 +201,168 @@ def test_normal_batching():
|
|||||||
batch = next(iter(batch_generator))
|
batch = next(iter(batch_generator))
|
||||||
assert len(batch) == 2
|
assert len(batch) == 2
|
||||||
assert batch[0]["input_ids"].shape == (4, 10)
|
assert batch[0]["input_ids"].shape == (4, 10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamic_padding_free():
    """Check that dynamic padding-free batching packs samples by token budget without padding."""
    # Sample lengths are 3, 4, 6, 2, 8, 9 with token ids starting at 0, 10, 20, 30, 40, 50:
    #   sample 0: [0, 1, 2]
    #   sample 1: [10, 11, 12, 13]
    #   sample 2: [20, ..., 25]
    #   sample 3: [30, 31]
    #   sample 4: [40, ..., 47]
    #   sample 5: [50, ..., 58]
    samples = [
        _make_model_input(3, 0),
        _make_model_input(4, 10),
        _make_model_input(6, 20),
        _make_model_input(2, 30),
        _make_model_input(8, 40),
        _make_model_input(9, 50),
    ]
    # Token budget = cutoff_len * micro_batch_size = 10 * 2 = 20.
    batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}

    # 3 + 4 + 6 + 2 = 15 <= 20 and adding 8 would exceed the budget, so four samples are taken.
    assert _get_dynamic_padding_free_micro_batch_sizes(samples, batch_info) == [4]

    buffer = StatefulBuffer()
    buffer.put(samples)
    batch = BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info)

    assert batch is not None
    assert len(batch) == 1  # num_micro_batch=1
    packed_batch = batch[0]

    # Packed length is 3 + 4 + 6 + 2 = 15, with no padding.
    assert packed_batch["input_ids"].shape == (1, 15)

    # input_ids are the four samples concatenated in order.
    expected_input_ids = [0, 1, 2] + [10, 11, 12, 13] + [20, 21, 22, 23, 24, 25] + [30, 31]
    assert packed_batch["input_ids"].tolist() == [expected_input_ids]

    # The first label of every non-initial sample is IGNORE_INDEX.
    expected_labels = (
        [0, 1, 2]
        + [IGNORE_INDEX, 11, 12, 13]
        + [IGNORE_INDEX, 21, 22, 23, 24, 25]
        + [IGNORE_INDEX, 31]
    )
    assert packed_batch["labels"].tolist() == [expected_labels]

    # Every packed token is attended to.
    assert packed_batch["attention_mask"].tolist() == [[1] * 15]

    # position_ids restart from 0 at each sample boundary.
    expected_position_ids = [0, 1, 2] + [0, 1, 2, 3] + [0, 1, 2, 3, 4, 5] + [0, 1]
    assert packed_batch["position_ids"].tolist() == [expected_position_ids]

    # 6 - 4 = 2 samples (lengths 8 and 9) remain in the buffer.
    assert len(buffer) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamic_padding_free_returns_none_when_token_budget_is_incomplete():
    """A lone sample that cannot close the token budget must not produce a batch."""
    buffer = StatefulBuffer()
    buffer.put([_make_model_input(6, 0)])
    batch_info = {"micro_batch_size": 2, "num_micro_batch": 1, "cutoff_len": 10}

    # Budget = 10 * 2 = 20 tokens. A single 6-token sample with no following sample to
    # overflow the budget leaves the micro-batch incomplete. Use the padding-free helper
    # here: it is the one the "dynamic_padding_free" strategy under test relies on.
    assert _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info) == []
    assert BatchingPlugin("dynamic_padding_free").generate_batch(buffer, batch_info) is None
    # Batch generation does not read from the data iterator. It only returns None and keeps
    # existing samples in the buffer; BatchGenerator._fill_buffer handles refilling.
    assert len(buffer) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamic_padding_free_fill_buffer_restarts_until_micro_batch_is_complete():
    """Check that fill_buffer keeps restarting the data iterator until the token budget is closed.

    The data provider yields a single 6-token sample per pass, and the budget is
    2 * 10 = 20 tokens. Three samples fit (18 tokens); the fourth (24 tokens) would
    exceed the budget and thereby completes the micro-batch, so the buffer holds
    four samples after filling.
    """
    data_provider = _RestartableDataProvider([[_make_model_input(6, 0)]])
    batch_info = {
        "micro_batch_size": 2,
        "num_micro_batch": 1,
        "cutoff_len": 10,
    }

    # Build a bare generator and set only the attributes the test assigns explicitly.
    batch_generator = BatchGenerator.__new__(BatchGenerator)
    batch_generator.batching_strategy = "dynamic_padding_free"
    batch_generator.micro_batch_size = batch_info["micro_batch_size"]
    batch_generator.num_micro_batch = batch_info["num_micro_batch"]
    batch_generator._buffer = StatefulBuffer()
    batch_generator._data_provider = data_provider
    batch_generator._data_iter = iter(data_provider)
    batch_generator._batch_info = batch_info

    # Filling restarts the iterator several times to collect enough samples.
    batch_generator._fill_buffer()

    # Three samples fit (18 tokens) and the fourth stays buffered for the next batch,
    # so the provider was iterated four times.
    assert data_provider.num_iters == 4
    micro_batch_sizes = _get_dynamic_padding_free_micro_batch_sizes(
        batch_generator._buffer.samples, batch_generator._batch_info
    )
    assert micro_batch_sizes == [3]

    batch = batch_generator._generate_batch()

    # One micro-batch is returned: three samples packed into shape (1, 18).
    assert batch is not None
    assert len(batch) == 1
    assert batch[0]["input_ids"].shape == (1, 18)
    assert len(batch_generator._buffer) == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user