mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-29 03:18:56 +08:00
[v1] Add FlashAttention selection and implement normal / padding-free / dynamic batching (#10469)
This commit is contained in:
@@ -34,7 +34,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ..accelerator.helper import ReduceOp
|
||||
from ..accelerator.interface import Dim, DistributedInterface
|
||||
from ..config import TrainingArguments
|
||||
from ..config import BatchingStrategy, TrainingArguments
|
||||
from ..utils import logging
|
||||
from ..utils.callbacks import (
|
||||
CallbackHandler,
|
||||
@@ -147,13 +147,19 @@ class BaseTrainer:
|
||||
from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
|
||||
|
||||
if model.config._attn_implementation != "flash_attention_2":
|
||||
logger.warning_rank0(
|
||||
"Sequence parallelism is optimized for flash attention only. Replace the attention implementation to flash_attention_2."
|
||||
raise ValueError(
|
||||
"Sequence parallelism requires flash attention. Please set `flash_attn: flash_attention_2`."
|
||||
)
|
||||
model.config._attn_implementation = "flash_attention_2"
|
||||
|
||||
SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
|
||||
|
||||
def _create_batch_generator(self) -> None:
|
||||
if (
|
||||
self.args.batching_strategy == BatchingStrategy.PADDING_FREE
|
||||
and getattr(self.model.config, "_attn_implementation", None) != "flash_attention_2"
|
||||
):
|
||||
raise ValueError("`padding_free` requires `flash_attn: flash_attention_2`.")
|
||||
|
||||
self.train_batch_generator = BatchGenerator(
|
||||
dataset=self.train_dataset,
|
||||
renderer=self.renderer,
|
||||
@@ -237,6 +243,7 @@ class BaseTrainer:
|
||||
self.train_batch_generator.set_epoch(epoch)
|
||||
self.callback_handler.on_epoch_begin(self.args, self.state)
|
||||
|
||||
# BatchGenerator is an iterator; each loop step calls its __next__ to produce one optimizer step.
|
||||
for micro_batches in self.train_batch_generator:
|
||||
self.global_step += 1
|
||||
|
||||
|
||||
@@ -120,6 +120,7 @@ class ModelEngine:
|
||||
init_device = DistributedInterface().current_device
|
||||
|
||||
init_kwargs = {} if self._deepspeed_zero3_enabled else {"device_map": init_device}
|
||||
logger.info_rank0(f"Using attention implementation: {self.args.flash_attn}.")
|
||||
|
||||
if self.args.quant_config is not None:
|
||||
from ..plugins.model_plugins.quantization import QuantizationPlugin
|
||||
@@ -164,6 +165,7 @@ class ModelEngine:
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
attn_implementation=self.args.flash_attn,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
**init_kwargs,
|
||||
)
|
||||
|
||||
@@ -42,6 +42,8 @@ from .rendering import Renderer
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
__all__ = ["BatchGenerator"]
|
||||
|
||||
|
||||
def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None:
|
||||
micro_batch_size = batch_info["micro_batch_size"]
|
||||
@@ -102,19 +104,18 @@ class BatchGenerator(Iterator):
|
||||
if not self.drop_last:
|
||||
raise ValueError("Drop last must be True.")
|
||||
|
||||
self._batch_info: BatchInfo = {
|
||||
"micro_batch_size": self.micro_batch_size,
|
||||
"num_micro_batch": self.num_micro_batch,
|
||||
"cutoff_len": self.cutoff_len,
|
||||
}
|
||||
|
||||
self._init_data_provider()
|
||||
|
||||
self._is_resuming: bool = False
|
||||
self._data_iter = iter(self._data_provider)
|
||||
self._buffer = StatefulBuffer()
|
||||
|
||||
self._batch_info: BatchInfo = {
|
||||
"micro_batch_size": self.micro_batch_size,
|
||||
"num_micro_batch": self.num_micro_batch,
|
||||
"cutoff_len": self.cutoff_len,
|
||||
"data_iter": self._data_iter,
|
||||
}
|
||||
|
||||
logger.info_rank0(
|
||||
f"Init unified data loader with global batch size {self.global_batch_size}, "
|
||||
f"micro batch size {self.micro_batch_size}, "
|
||||
@@ -137,12 +138,19 @@ class BatchGenerator(Iterator):
|
||||
else:
|
||||
raise NotImplementedError("Iterable dataset is not supported yet.")
|
||||
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
batch_size = self.micro_batch_size * self.num_micro_batch
|
||||
else:
|
||||
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||
|
||||
batch_size = BatchingPlugin(self.batching_strategy).get_data_provider_batch_size(self._batch_info)
|
||||
|
||||
generator_seed = torch.Generator()
|
||||
generator_seed.manual_seed(self.seed)
|
||||
|
||||
self._data_provider = StatefulDataLoader(
|
||||
self.dataset,
|
||||
batch_size=self.micro_batch_size * self.num_micro_batch,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=self.batching_workers,
|
||||
collate_fn=self.renderer.process_samples,
|
||||
@@ -156,8 +164,7 @@ class BatchGenerator(Iterator):
|
||||
else:
|
||||
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||
|
||||
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider)
|
||||
raise NotImplementedError("Batching strategy other than NORMAL is not supported yet.")
|
||||
self._length = BatchingPlugin(self.batching_strategy).compute_length(self._data_provider, self._batch_info)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._length
|
||||
@@ -190,7 +197,7 @@ class BatchGenerator(Iterator):
|
||||
else:
|
||||
from ...plugins.trainer_plugins.batching import BatchingPlugin
|
||||
|
||||
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info)
|
||||
BatchingPlugin(self.batching_strategy).fill_buffer(self._buffer, self._batch_info, self._next_samples)
|
||||
|
||||
def _generate_batch(self) -> list[BatchInput] | None:
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
@@ -200,6 +207,20 @@ class BatchGenerator(Iterator):
|
||||
|
||||
return BatchingPlugin(self.batching_strategy).generate_batch(self._buffer, self._batch_info)
|
||||
|
||||
def _next_samples(self, restart: bool) -> list[ModelInput] | None:
|
||||
try:
|
||||
return next(self._data_iter)
|
||||
except StopIteration:
|
||||
if not restart:
|
||||
return None
|
||||
|
||||
# Dynamic batching may restart the provider to fill one token-budgeted batch.
|
||||
self._data_iter = iter(self._data_provider)
|
||||
try:
|
||||
return next(self._data_iter)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"buffer": self._buffer.state_dict(),
|
||||
|
||||
Reference in New Issue
Block a user