mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 02:48:54 +08:00
[v1] Add FlashAttention selection and implement normal / padding-free / dynamic batching (#10469)
This commit is contained in:
@@ -20,7 +20,6 @@ import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
@@ -584,7 +583,7 @@ class ModuleProfilerCallback(TrainerCallback):
|
||||
if matched:
|
||||
logger.info_rank0(
|
||||
f"ModuleProfiler: registered hooks on {len(matched)} modules: {matched[:5]}"
|
||||
+ (f" ... (+{len(matched)-5} more)" if len(matched) > 5 else "")
|
||||
+ (f" ... (+{len(matched) - 5} more)" if len(matched) > 5 else "")
|
||||
)
|
||||
else:
|
||||
logger.warning_rank0(f"ModuleProfiler: no modules matched patterns {self.patterns}")
|
||||
@@ -616,7 +615,7 @@ class ModuleProfilerCallback(TrainerCallback):
|
||||
bwd = self._backward_times.get(name, [])
|
||||
fwd_mean = sum(fwd) / len(fwd) if fwd else 0.0
|
||||
bwd_mean = sum(bwd) / len(bwd) if bwd else 0.0
|
||||
lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean+bwd_mean:.3f}")
|
||||
lines.append(f" {name}: fwd={fwd_mean:.3f}, bwd={bwd_mean:.3f}, total={fwd_mean + bwd_mean:.3f}")
|
||||
|
||||
logger.info_rank0("\n".join(lines))
|
||||
self._forward_times.clear()
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils.types import AttentionFunction
|
||||
from .arg_parser import InputArgument, get_args
|
||||
from .arg_utils import BatchingStrategy, ModelClass, SampleBackend
|
||||
from .data_args import DataArguments
|
||||
@@ -21,6 +22,7 @@ from .training_args import TrainingArguments
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AttentionFunction",
|
||||
"BatchingStrategy",
|
||||
"DataArguments",
|
||||
"InputArgument",
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..utils.types import AttentionFunction
|
||||
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
|
||||
|
||||
|
||||
@@ -32,6 +33,12 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Trust remote code from Hugging Face."},
|
||||
)
|
||||
flash_attn: AttentionFunction = field(
|
||||
default=AttentionFunction.SDPA,
|
||||
metadata={
|
||||
"help": "Attention implementation to use: eager, sdpa, or flash_attention_2. SDPA is the default implementation for models."
|
||||
},
|
||||
)
|
||||
model_class: ModelClass = field(
|
||||
default=ModelClass.LLM,
|
||||
metadata={"help": "Model class from Hugging Face."},
|
||||
@@ -54,6 +61,12 @@ class ModelArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
supported_flash_attn = [item.value for item in AttentionFunction]
|
||||
if self.flash_attn not in supported_flash_attn:
|
||||
raise ValueError(
|
||||
f"Unsupported `flash_attn`: {self.flash_attn}. Supported values are: {supported_flash_attn}."
|
||||
)
|
||||
|
||||
self.init_config = get_plugin_config(self.init_config)
|
||||
self.peft_config = get_plugin_config(self.peft_config)
|
||||
self.kernel_config = get_plugin_config(self.kernel_config)
|
||||
|
||||
@@ -116,3 +116,9 @@ class TrainingArguments:
|
||||
self.dist_config = get_plugin_config(self.dist_config)
|
||||
self.optim_config = get_plugin_config(self.optim_config)
|
||||
self.lr_scheduler_config = get_plugin_config(self.lr_scheduler_config)
|
||||
|
||||
if str(self.batching_strategy) == str(BatchingStrategy.DYNAMIC_BATCHING):
|
||||
if self.max_steps is None or self.max_steps <= 0:
|
||||
raise ValueError("`dynamic_batching` requires `max_steps` because it is step-driven.")
|
||||
if self.save_epochs is not None:
|
||||
raise ValueError("`save_epochs` is not supported with `dynamic_batching`; use `save_steps` instead.")
|
||||
|
||||
@@ -34,7 +34,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ..accelerator.helper import ReduceOp
|
||||
from ..accelerator.interface import Dim, DistributedInterface
|
||||
from ..config import TrainingArguments
|
||||
from ..config import BatchingStrategy, TrainingArguments
|
||||
from ..utils import logging
|
||||
from ..utils.callbacks import (
|
||||
CallbackHandler,
|
||||
@@ -147,13 +147,19 @@ class BaseTrainer:
|
||||
from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
|
||||
|
||||
if model.config._attn_implementation != "flash_attention_2":
|
||||
logger.warning_rank0(
|
||||
"Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2."
|
||||
raise ValueError(
|
||||
"Sequence parallelism requires flash attention. Please set `flash_attn: flash_attention_2`."
|
||||
)
|
||||
model.config._attn_implementation = "flash_attention_2"
|
||||
|
||||
SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
|
||||
|
||||
def _create_batch_generator(self) -> None:
|
||||
if (
|
||||
self.args.batching_strategy == BatchingStrategy.PADDING_FREE
|
||||
and getattr(self.model.config, "_attn_implementation", None) != "flash_attention_2"
|
||||
):
|
||||
raise ValueError("`padding_free` requires `flash_attn: flash_attention_2`.")
|
||||
|
||||
self.train_batch_generator = BatchGenerator(
|
||||
dataset=self.train_dataset,
|
||||
renderer=self.renderer,
|
||||
@@ -237,6 +243,7 @@ class BaseTrainer:
|
||||
self.train_batch_generator.set_epoch(epoch)
|
||||
self.callback_handler.on_epoch_begin(self.args, self.state)
|
||||
|
||||
# BatchGenerator is an iterator; each loop step calls its __next__ to produce one optimizer step.
|
||||
for micro_batches in self.train_batch_generator:
|
||||
self.global_step += 1
|
||||
|
||||
|
||||
@@ -120,6 +120,7 @@ class ModelEngine:
|
||||
init_device = DistributedInterface().current_device
|
||||
|
||||
init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device}
|
||||
logger.info_rank0(f"Using attention implementation: {self.args.flash_attn}.")
|
||||
|
||||
if self.args.quant_config is not None:
|
||||
from ..plugins.model_plugins.quantization import QuantizationPlugin
|
||||
@@ -164,6 +165,7 @@ class ModelEngine:
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
attn_implementation=self.args.flash_attn,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
**init_kwargs,
|
||||
)
|
||||
|
||||
@@ -42,6 +42,8 @@ from .rendering import Renderer
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
__all__ = ["BatchGenerator"]
|
||||
|
||||
|
||||
def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
|
||||
micro_batch_size = batch_info["micro_batch_size"]
|
||||
@@ -102,19 +104,18 @@ class BatchGenerator(Iterator):
|
||||
if not self.drop_last:
|
||||
raise ValueError("Drop last must be True.")
|
||||
|
||||
self._batch_info: BatchInfo = {
|
||||
"micro_batch_size": self.micro_batch_size,
|
||||
"num_micro_batch": self.num_micro_batch,
|
||||
"cutoff_len": self.cutoff_len,
|
||||
}
|
||||
|
||||
self._init_data_provider()
|
||||
|
||||
self._is_resuming: bool = False
|
||||
self._data_iter = iter(self._data_provider)
|
||||
self._buffer = StatefulBuffer()
|
||||
|
||||
self._batch_info: BatchInfo = {
|
||||
"micro_batch_size": self.micro_batch_size,
|
||||
"num_micro_batch": self.num_micro_batch,
|
||||
"cutoff_len": self.cutoff_len,
|
||||
"data_iter": self._data_iter,
|
||||
}
|
||||
|
||||
logger.info_rank0(
|
||||
f"Init unified data loader with global batch size {self.global_batch_size}, "
|
||||
f"micro batch size {self.micro_batch_size}, "
|
||||
@@ -137,12 +138,19 @@ class BatchGenerator(Iterator):
|
||||
else:
|
||||
raise NotImplementedError("Iterable dataset is not supported yet.")
|
||||
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
batch_size = self.micro_batch_size * self.num_micro_batch
|
||||
else:
|
||||
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||
|
||||
batch_size = BatchingPlugin(self.batching_strategy).get_data_provider_batch_size(self._batch_info)
|
||||
|
||||
generator_seed = torch.Generator()
|
||||
generator_seed.manual_seed(self.seed)
|
||||
|
||||
self._data_provider = StatefulDataLoader(
|
||||
self.dataset,
|
||||
batch_size=self.micro_batch_size * self.num_micro_batch,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=self.batching_workers,
|
||||
collate_fn=self.renderer.process_samples,
|
||||
@@ -156,8 +164,7 @@ class BatchGenerator(Iterator):
|
||||
else:
|
||||
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||
|
||||
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)
|
||||
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
|
||||
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider, self._batch_info)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._length
|
||||
@@ -190,7 +197,7 @@ class BatchGenerator(Iterator):
|
||||
else:
|
||||
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||
|
||||
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)
|
||||
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info, self._next_samples)
|
||||
|
||||
def _generate_batch(self) -> list[BatchInput] | None:
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
@@ -200,6 +207,20 @@ class BatchGenerator(Iterator):
|
||||
|
||||
return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)
|
||||
|
||||
def _next_samples(self, restart: bool) -> list[ModelInput] | None:
|
||||
try:
|
||||
return next(self._data_iter)
|
||||
except StopIteration:
|
||||
if not restart:
|
||||
return None
|
||||
|
||||
# Dynamic batching may restart the provider to fill one token-budgeted batch.
|
||||
self._data_iter = iter(self._data_provider)
|
||||
try:
|
||||
return next(self._data_iter)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"buffer": self._buffer.state_dict(),
|
||||
|
||||
@@ -12,23 +12,197 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Callable
|
||||
from math import ceil
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.utils.data import default_collate
|
||||
|
||||
from ...utils.constants import IGNORE_INDEX
|
||||
from ...utils.helper import pad_and_truncate
|
||||
from ...utils.objects import StatefulBuffer
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import BatchInfo, BatchInput, DataLoader
|
||||
from ...utils.types import BatchInfo, BatchInput, DataLoader, ModelInput
|
||||
|
||||
|
||||
class BatchingPlugin(BasePlugin):
|
||||
def compute_length(self, data_provider: DataLoader) -> int:
|
||||
def get_data_provider_batch_size(self, batch_info: BatchInfo) -> int:
|
||||
"""Return the raw data provider batch size for this batching strategy."""
|
||||
return self["get_data_provider_batch_size"](batch_info)
|
||||
|
||||
def compute_length(self, data_provider: DataLoader, batch_info: BatchInfo) -> int:
|
||||
"""Compute the length of the batch generator.
|
||||
|
||||
The approximate length is used to calculate the lr schedule.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
return self["compute_length"](data_provider, batch_info)
|
||||
|
||||
def fill_buffer(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> None:
|
||||
def fill_buffer(
|
||||
self,
|
||||
buffer: StatefulBuffer,
|
||||
batch_info: BatchInfo,
|
||||
next_samples: Callable[[bool], list[ModelInput] | None],
|
||||
) -> None:
|
||||
"""Fill the buffer with data."""
|
||||
raise NotImplementedError()
|
||||
return self["fill_buffer"](buffer, batch_info, next_samples)
|
||||
|
||||
def generate_batch(self, buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
|
||||
"""Generate a batch from the buffer."""
|
||||
raise NotImplementedError()
|
||||
return self["generate_batch"](buffer, batch_info)
|
||||
|
||||
|
||||
def _get_dynamic_micro_batch_sizes(samples: list[ModelInput], batch_info: BatchInfo) -> list[int]:
|
||||
"""Return sample counts for micro batches formed by one padded-token budget."""
|
||||
budget = batch_info["cutoff_len"] * batch_info["micro_batch_size"]
|
||||
cutoff_len = batch_info["cutoff_len"]
|
||||
sizes = []
|
||||
index = 0
|
||||
while index < len(samples) and len(sizes) < batch_info["num_micro_batch"]:
|
||||
max_sample_len = 0
|
||||
used = 0
|
||||
is_complete = False
|
||||
while index + used < len(samples):
|
||||
sample_len = min(len(samples[index + used]["input_ids"]), cutoff_len)
|
||||
padded_tokens = max(max_sample_len, sample_len) * (used + 1)
|
||||
if used > 0 and padded_tokens > budget:
|
||||
is_complete = True
|
||||
break
|
||||
|
||||
max_sample_len = max(max_sample_len, sample_len)
|
||||
used += 1
|
||||
if max_sample_len * used >= budget:
|
||||
is_complete = True
|
||||
break
|
||||
|
||||
if used == 0 or not is_complete:
|
||||
break
|
||||
|
||||
sizes.append(used)
|
||||
index += used
|
||||
|
||||
return sizes
|
||||
|
||||
|
||||
def _pack_padding_free_samples(samples: list[ModelInput], cutoff_len: int) -> BatchInput | None:
|
||||
"""Pack fixed samples into one padding-free sequence without a token budget."""
|
||||
packed: dict[str, list[Any]] = {}
|
||||
position_ids: list[int] = []
|
||||
|
||||
for sample_index, sample in enumerate(samples):
|
||||
# Padding-free still truncates each sample by cutoff_len before packing
|
||||
# all samples into one contiguous sequence.
|
||||
sample_len = min(len(sample["input_ids"]), cutoff_len)
|
||||
if sample_len <= 0:
|
||||
continue
|
||||
|
||||
for key, value in sample.items():
|
||||
if key in ("attention_mask", "position_ids") or isinstance(value, str):
|
||||
continue
|
||||
|
||||
if key not in packed:
|
||||
packed[key] = []
|
||||
|
||||
sliced_value = list(value[:sample_len])
|
||||
if sample_index > 0 and sliced_value:
|
||||
if key == "labels":
|
||||
sliced_value[0] = IGNORE_INDEX
|
||||
elif key == "loss_weights":
|
||||
sliced_value[0] = 0.0
|
||||
|
||||
packed[key].extend(sliced_value)
|
||||
|
||||
position_ids.extend(range(sample_len))
|
||||
|
||||
if not position_ids:
|
||||
return None
|
||||
|
||||
packed["position_ids"] = position_ids
|
||||
packed["attention_mask"] = [1] * len(position_ids)
|
||||
return {key: torch.tensor(value).unsqueeze(0) for key, value in packed.items()}
|
||||
|
||||
|
||||
@BatchingPlugin("padding_free").register("get_data_provider_batch_size")
|
||||
def get_padding_free_data_provider_batch_size(batch_info: BatchInfo) -> int:
|
||||
return batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
|
||||
|
||||
|
||||
@BatchingPlugin("padding_free").register("compute_length")
|
||||
def compute_padding_free_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
|
||||
return len(data_provider)
|
||||
|
||||
|
||||
@BatchingPlugin("padding_free").register("fill_buffer")
|
||||
def fill_padding_free_buffer(
|
||||
buffer: StatefulBuffer,
|
||||
batch_info: BatchInfo,
|
||||
next_samples: Callable[[bool], list[ModelInput] | None],
|
||||
) -> None:
|
||||
while len(buffer) < batch_info["micro_batch_size"] * batch_info["num_micro_batch"]:
|
||||
samples = next_samples(False)
|
||||
if samples is None:
|
||||
break
|
||||
|
||||
buffer.put(samples)
|
||||
|
||||
|
||||
@BatchingPlugin("padding_free").register("generate_batch")
|
||||
def generate_padding_free_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
|
||||
micro_batch_size = batch_info["micro_batch_size"]
|
||||
num_micro_batch = batch_info["num_micro_batch"]
|
||||
cutoff_len = batch_info["cutoff_len"]
|
||||
batch_size = micro_batch_size * num_micro_batch
|
||||
if len(buffer) < batch_size:
|
||||
return None
|
||||
|
||||
samples = buffer.get(batch_size)
|
||||
batch = []
|
||||
for i in range(num_micro_batch):
|
||||
micro_batch = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
|
||||
packed_micro_batch = _pack_padding_free_samples(micro_batch, cutoff_len)
|
||||
if packed_micro_batch is None:
|
||||
return None
|
||||
|
||||
batch.append(packed_micro_batch)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_batching").register("get_data_provider_batch_size")
|
||||
def get_dynamic_batching_data_provider_batch_size(batch_info: BatchInfo) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_batching").register("compute_length")
|
||||
def compute_dynamic_batching_length(data_provider: DataLoader, batch_info: BatchInfo) -> int:
|
||||
batch_size = batch_info["micro_batch_size"] * batch_info["num_micro_batch"]
|
||||
return ceil(len(data_provider) / batch_size)
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_batching").register("fill_buffer")
|
||||
def fill_dynamic_batching_buffer(
|
||||
buffer: StatefulBuffer,
|
||||
batch_info: BatchInfo,
|
||||
next_samples: Callable[[bool], list[ModelInput] | None],
|
||||
) -> None:
|
||||
while len(_get_dynamic_micro_batch_sizes(buffer.samples, batch_info)) < batch_info["num_micro_batch"]:
|
||||
samples = next_samples(True)
|
||||
if samples is None:
|
||||
break
|
||||
|
||||
buffer.put(samples)
|
||||
|
||||
|
||||
@BatchingPlugin("dynamic_batching").register("generate_batch")
|
||||
def generate_dynamic_batching_batch(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
|
||||
micro_batch_sample_counts = _get_dynamic_micro_batch_sizes(buffer.samples, batch_info)
|
||||
if len(micro_batch_sample_counts) < batch_info["num_micro_batch"]:
|
||||
return None
|
||||
|
||||
batch = []
|
||||
cutoff_len = batch_info["cutoff_len"]
|
||||
for num_samples in micro_batch_sample_counts:
|
||||
samples = buffer.get(num_samples)
|
||||
batch.append(default_collate(pad_and_truncate(samples, cutoff_len)))
|
||||
|
||||
return batch
|
||||
|
||||
@@ -61,6 +61,9 @@ def load_checkpoint_fsdp2(model: HFModel, optimizer: torch.optim.Optimizer, ckpt
|
||||
|
||||
@DistributedPlugin("deepspeed").register()
|
||||
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
|
||||
if dist_config.get("cp_size", 1) > 1:
|
||||
raise ValueError("CP currently requires `dist_config.name: fsdp2`.")
|
||||
|
||||
from .deepspeed import DeepSpeedEngine
|
||||
|
||||
return DeepSpeedEngine(
|
||||
|
||||
@@ -33,6 +33,10 @@ class StatefulBuffer:
|
||||
def size(self) -> int:
|
||||
return self._buffer_size
|
||||
|
||||
@property
|
||||
def samples(self) -> list[ModelInput]:
|
||||
return self._buffer
|
||||
|
||||
def put(self, samples: list[ModelInput]) -> None:
|
||||
"""Add samples to the buffer."""
|
||||
num_tokens = sum(len(sample["input_ids"]) for sample in samples)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Iterator
|
||||
from enum import StrEnum, unique
|
||||
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, NotRequired, TypedDict, Union
|
||||
|
||||
|
||||
@@ -54,6 +54,13 @@ else:
|
||||
ProcessGroup = None
|
||||
|
||||
|
||||
@unique
|
||||
class AttentionFunction(StrEnum):
|
||||
EAGER = "eager"
|
||||
SDPA = "sdpa"
|
||||
FLASH_ATTENTION_2 = "flash_attention_2"
|
||||
|
||||
|
||||
class DatasetInfo(TypedDict, total=False):
|
||||
path: str
|
||||
"""Local file path."""
|
||||
@@ -171,8 +178,6 @@ class BatchInfo(TypedDict):
|
||||
"""Number of micro batches."""
|
||||
cutoff_len: int
|
||||
"""Cutoff length."""
|
||||
data_iter: Iterator[list[ModelInput]]
|
||||
"""Data iterator."""
|
||||
|
||||
|
||||
class ModelOutput(NamedTuple):
|
||||
|
||||
Reference in New Issue
Block a user