[v1] Implement dynamic padding-free stretrgy for batching (#10507)

Co-authored-by: cxy-thinkbook <xuanyuchen@seu.edu.cn>
This commit is contained in:
cxy
2026-05-25 20:40:21 +08:00
committed by GitHub
parent 16ff5a23cb
commit 8e68764b65
2 changed files with 246 additions and 1 deletions

View File

@@ -84,6 +84,37 @@ def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchI
return sizes
def _get_dynamic_padding_free_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
cutoff_len = batch_info["cutoff_len"]
sizes = []
index = 0
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
current_tokens = 0
used = 0
is_complete = False
while index + used < len(samples):
sample = samples[index + used]
sample_len = min(len(sample["input_ids"]), cutoff_len)
if current_tokens + sample_len > budget:
is_complete = True
break
current_tokens += sample_len
used += 1
if used <= 0 or not is_complete:
break
sizes.append(used)
index += used
return sizes
def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
"""Pack fixed samples into one padding-free sequence without a token budget."""
packed: dict[str, list[Any]] = {}
@@ -206,3 +237,47 @@ def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInf
batch.append(default_collate(pad_and_truncate(samples, cutoff_len)))
return batch
@BatchingPlugin("dynamic_padding_free").register("get_data_provider_batch_size")
def get_dynamic_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
return 1
@BatchingPlugin("dynamic_padding_free").register("compute_length")
def compute_dynamic_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
return ceil(len(data_provider) / batch_size)
@BatchingPlugin("dynamic_padding_free").register("fill_buffer")
def fill_dynamic_padding_free_buffer(
buffer: StatefulBuffer,
batch_info: BatchInfo,
next_samples: Callable[[bool], list[ModelInput] | None],
) -> None:
while len(_get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
samples = next_samples(True)
if samples is None:
break
buffer.put(samples)
@BatchingPlugin("dynamic_padding_free").register("generate_batch")
def generate_dynamic_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
micro_batch_sample_counts = _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)
if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
return None
batch = []
cutoff_len = batch_info["cutoff_len"]
for num_samples in micro_batch_sample_counts:
samples = buffer.get(num_samples)
packed_batch = _pack_padding_free_samples(samples, cutoff_len)
if packed_batch is None:
return None
batch.append(packed_batch)
return batch