mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 12:48:55 +08:00
[v1] Implement dynamic padding-free stretrgy for batching (#10507)
Co-authored-by: cxy-thinkbook <xuanyuchen@seu.edu.cn>
This commit is contained in:
@@ -84,6 +84,37 @@ def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchI
|
||||
return sizes
|
||||
|
||||
|
||||
def _get_dynamic_padding_free_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
|
||||
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
|
||||
cutoff_len = batch_info["cutoff_len"]
|
||||
sizes = []
|
||||
index = 0
|
||||
|
||||
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
|
||||
current_tokens = 0
|
||||
used = 0
|
||||
is_complete = False
|
||||
|
||||
while index + used < len(samples):
|
||||
sample = samples[index + used]
|
||||
sample_len = min(len(sample["input_ids"]), cutoff_len)
|
||||
|
||||
if current_tokens + sample_len > budget:
|
||||
is_complete = True
|
||||
break
|
||||
|
||||
current_tokens += sample_len
|
||||
used += 1
|
||||
|
||||
if used <= 0 or not is_complete:
|
||||
break
|
||||
|
||||
sizes.append(used)
|
||||
index += used
|
||||
|
||||
return sizes
|
||||
|
||||
|
||||
def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
|
||||
"""Pack fixed samples into one padding-free sequence without a token budget."""
|
||||
packed: dict[str, list[Any]] = {}
|
||||
@@ -206,3 +237,47 @@ def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInf
|
||||
batch.append(default_collate(pad_and_truncate(samples, cutoff_len)))
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_padding_free").register("get_data_provider_batch_size")
|
||||
def get_dynamic_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_padding_free").register("compute_length")
|
||||
def compute_dynamic_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
|
||||
batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
|
||||
return ceil(len(data_provider) / batch_size)
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_padding_free").register("fill_buffer")
|
||||
def fill_dynamic_padding_free_buffer(
|
||||
buffer: StatefulBuffer,
|
||||
batch_info: BatchInfo,
|
||||
next_samples: Callable[[bool], list[ModelInput] | None],
|
||||
) -> None:
|
||||
while len(_get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
|
||||
samples = next_samples(True)
|
||||
if samples is None:
|
||||
break
|
||||
buffer.put(samples)
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_padding_free").register("generate_batch")
|
||||
def generate_dynamic_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
|
||||
micro_batch_sample_counts = _get_dynamic_padding_free_micro_batch_sizes(buffer.samples, batch_info)
|
||||
if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
|
||||
return None
|
||||
|
||||
batch = []
|
||||
cutoff_len = batch_info["cutoff_len"]
|
||||
|
||||
for num_samples in micro_batch_sample_counts:
|
||||
samples = buffer.get(num_samples)
|
||||
packed_batch = _pack_padding_free_samples(samples, cutoff_len)
|
||||
if packed_batch is None:
|
||||
return None
|
||||
|
||||
batch.append(packed_batch)
|
||||
|
||||
return batch
|
||||
|
||||
Reference in New Issue
Block a user