# Copyright 2025 the LlamaFactory team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Callable from math import ceil from typing import Any import torch from torch.utils.data import default_collate from ...utils.constants import IGNORE_INDEX from ...utils.helper import pad_and_truncate from ...utils.objects import StatefulBuffer from ...utils.plugin import BasePlugin from ...utils.types import BatchInfo, BatchInput, DataLoader, ModelInput class BatchingPlugin(BasePlugin): def get_data_provider_batch_size(self, batch_info: BatchInfo) -> int: """Return the raw data provider batch size for this batching strategy.""" return self["get_data_provider_batch_size"](batch_info) def compute_length(self, data_provider: DataLoader, batch_info: BatchInfo) -> int: """Compute the length of the batch generator. The approximate length is used to calculate the lr schedule. """ return self["compute_length"](data_provider, batch_info) def fill_buffer( self, buffer: StatefulBuffer, batch_info: BatchInfo, next_samples: Callable[[bool], list[ModelInput] | None], ) -> None: """Fill the buffer with data.""" return self["fill_buffer"](buffer, batch_info, next_samples) def generate_batch(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None: """Generate a batch from the buffer.""" return self["generate_batch"](buffer, batch_info) def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]: """Return sample counts for micro batches formed by one padded-token budget.""" budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"] cutoff_len = batch_info["cutoff_len"] sizes = [] index = 0 while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]: max_sample_len = 0 used = 0 is_complete = False while index + used < len(samples): sample_len = min(len(samples[index + used]["input_ids"]), cutoff_len) padded_tokens = max(max_sample_len, sample_len) * (used + 1) if used > 0 and padded_tokens > budget: is_complete = True break max_sample_len = max(max_sample_len, sample_len) used += 1 if max_sample_len * used >= budget: is_complete = True break if used == 0 or not is_complete: break sizes.append(used) index += used return sizes def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None: """Pack fixed samples into one padding-free sequence without a token budget.""" packed: dict[str, list[Any]] = {} position_ids: list[int] = [] for sample_index, sample in enumerate(samples): # Padding-free still truncates each sample by cutoff_len before packing # all samples into one contiguous sequence. sample_len = min(len(sample["input_ids"]), cutoff_len) if sample_len <= 0: continue for key, value in sample.items(): if key in ("attention_mask", "position_ids") or isinstance(value, str): continue if key not in packed: packed[key] = [] sliced_value = list(value[:sample_len]) if sample_index > 0 and sliced_value: if key == "labels": sliced_value[0] = IGNORE_INDEX elif key == "loss_weights": sliced_value[0] = 0.0 packed[key].extend(sliced_value) position_ids.extend(range(sample_len)) if not position_ids: return None packed["position_ids"] = position_ids packed["attention_mask"] = [1] * len(position_ids) return {key: torch.tensor(value).unsqueeze(0) for key, value in packed.items()} @BatchingPlugin("padding_free").register("get_data_provider_batch_size") def get_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int: return batch_info["micro_batch_size"] * batch_info["num_micro_batch"] @BatchingPlugin("padding_free").register("compute_length") def compute_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int: return len(data_provider) @BatchingPlugin("padding_free").register("fill_buffer") def fill_padding_free_buffer( buffer: StatefulBuffer, batch_info: BatchInfo, next_samples: Callable[[bool], list[ModelInput] | None], ) -> None: while len(buffer) < batch_info["micro_batch_size"] * batch_info["num_micro_batch"]: samples = next_samples(False) if samples is None: break buffer.put(samples) @BatchingPlugin("padding_free").register("generate_batch") def generate_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None: micro_batch_size = batch_info["micro_batch_size"] num_micro_batch = batch_info["num_micro_batch"] cutoff_len = batch_info["cutoff_len"] batch_size = micro_batch_size * num_micro_batch if len(buffer) < batch_size: return None samples = buffer.get(batch_size) batch = [] for i in range(num_micro_batch): micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size] packed_micro_batch = _pack_padding_free_samples(micro_batch, cutoff_len) if packed_micro_batch is None: return None batch.append(packed_micro_batch) return batch @BatchingPlugin("dynamic_batching").register("get_data_provider_batch_size") def get_dynamic_batching_data_provider_batch_size(batch_info: BatchInfo) -> int: return 1 @BatchingPlugin("dynamic_batching").register("compute_length") def compute_dynamic_batching_length(data_provider: DataLoader, batch_info: BatchInfo) -> int: batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"] return ceil(len(data_provider) / batch_size) @BatchingPlugin("dynamic_batching").register("fill_buffer") def fill_dynamic_batching_buffer( buffer: StatefulBuffer, batch_info: BatchInfo, next_samples: Callable[[bool], list[ModelInput] | None], ) -> None: while len(_get_dynamic_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]: samples = next_samples(True) if samples is None: break buffer.put(samples) @BatchingPlugin("dynamic_batching").register("generate_batch") def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None: micro_batch_sample_counts = _get_dynamic_micro_batch_sizes(buffer.samples, batch_info) if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]: return None batch = [] cutoff_len = batch_info["cutoff_len"] for num_samples in micro_batch_sample_counts: samples = buffer.get(num_samples) batch.append(default_collate(pad_and_truncate(samples, cutoff_len))) return batch