mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-25 08:38:55 +08:00
[feat] support HyperParallel Context Parallel feature (#10559)
Co-authored-by: wcrzlh <weichaoran@huawei.com>
This commit is contained in:
@@ -500,6 +500,10 @@ class FinetuningArguments(
|
||||
)
|
||||
},
|
||||
)
|
||||
hyper_parallel_cp_size: int = field(
|
||||
default=1,
|
||||
metadata={"help": "Context parallel size used when `use_hyper_parallel=True`."},
|
||||
)
|
||||
use_muon: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use the Muon optimizer."},
|
||||
@@ -576,6 +580,7 @@ class FinetuningArguments(
|
||||
assert self.finetuning_type in ["lora", "oft", "freeze", "full"], "Invalid fine-tuning method."
|
||||
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.hyper_parallel_cp_size > 0, "`hyper_parallel_cp_size` must be greater than 0."
|
||||
|
||||
if self.stage == "ppo" and self.reward_model is None:
|
||||
raise ValueError("`reward_model` is necessary for PPO training.")
|
||||
|
||||
@@ -18,6 +18,7 @@ import logging
|
||||
import os
|
||||
import types
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
@@ -35,6 +36,13 @@ from hyper_parallel.integration.llamafactory import (
|
||||
from hyper_parallel.integration.llamafactory import (
|
||||
clip_grad_norm_ as hp_clip_grad_norm_,
|
||||
)
|
||||
from hyper_parallel.integration.llamafactory.context_parallel import (
|
||||
cp_prepare_model,
|
||||
get_cp_rank,
|
||||
get_dp_rank,
|
||||
shard_inputs_for_cp,
|
||||
)
|
||||
from hyper_parallel.platform import get_platform
|
||||
from torch import nn
|
||||
|
||||
from ..sft.trainer import CustomSeq2SeqTrainer
|
||||
@@ -43,6 +51,87 @@ from ..sft.trainer import CustomSeq2SeqTrainer
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _CPBatchRepeatedBatchSampler(torch.utils.data.BatchSampler):
|
||||
"""Repeat logical batches so Accelerate shards CP peers onto the same samples."""
|
||||
|
||||
def __init__(self, sampler, batch_size: int, drop_last: bool, repeat_factor: int, logical_group_size: int):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
self.repeat_factor = repeat_factor
|
||||
self.logical_group_size = logical_group_size
|
||||
|
||||
def __len__(self):
|
||||
logical_length = super().__len__()
|
||||
if not self.drop_last and logical_length > 0:
|
||||
logical_length = _ceil_div(logical_length, self.logical_group_size) * self.logical_group_size
|
||||
return logical_length * self.repeat_factor
|
||||
|
||||
def __iter__(self):
|
||||
initial_data = []
|
||||
logical_count = 0
|
||||
pad_cursor = 0
|
||||
max_initial_data = self.batch_size * self.logical_group_size
|
||||
|
||||
def collect_initial_data(batch):
|
||||
if len(initial_data) < max_initial_data:
|
||||
initial_data.extend(batch[: max_initial_data - len(initial_data)])
|
||||
|
||||
def get_padding_item():
|
||||
nonlocal pad_cursor
|
||||
item = initial_data[pad_cursor % len(initial_data)]
|
||||
pad_cursor += 1
|
||||
return item
|
||||
|
||||
def pad_batch(batch):
|
||||
batch = list(batch)
|
||||
if self.drop_last or len(batch) == self.batch_size:
|
||||
return batch
|
||||
|
||||
while len(batch) < self.batch_size:
|
||||
batch.append(get_padding_item())
|
||||
return batch
|
||||
|
||||
def make_padding_batch():
|
||||
return [get_padding_item() for _ in range(self.batch_size)]
|
||||
|
||||
def repeat_batch(batch):
|
||||
for _ in range(self.repeat_factor):
|
||||
yield list(batch)
|
||||
|
||||
for batch in super().__iter__():
|
||||
collect_initial_data(batch)
|
||||
batch = pad_batch(batch)
|
||||
logical_count += 1
|
||||
yield from repeat_batch(batch)
|
||||
|
||||
if self.drop_last or logical_count == 0:
|
||||
return
|
||||
|
||||
while logical_count % self.logical_group_size != 0:
|
||||
logical_count += 1
|
||||
yield from repeat_batch(make_padding_batch())
|
||||
|
||||
|
||||
class _CPDataLoaderLengthProxy:
|
||||
"""Keep baseline logical dataloader length while yielding CP-repeated batches."""
|
||||
|
||||
def __init__(self, dataloader, logical_length: int):
|
||||
self._dataloader = dataloader
|
||||
self._logical_length = logical_length
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._dataloader)
|
||||
|
||||
def __len__(self):
|
||||
return self._logical_length
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._dataloader, name)
|
||||
|
||||
|
||||
def _ceil_div(numerator: int, denominator: int) -> int:
|
||||
return (numerator + denominator - 1) // denominator
|
||||
|
||||
|
||||
class HyperParallelTrainer(CustomSeq2SeqTrainer):
|
||||
"""Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard.
|
||||
|
||||
@@ -73,15 +162,25 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
|
||||
if not getattr(self.accelerator, "is_fsdp2", False):
|
||||
raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.")
|
||||
|
||||
# Prepare ref_model with HP's fsdp2_prepare_model
|
||||
self._cp_size = hp_args.cp_size
|
||||
self._cp_rank = get_cp_rank(hp_args) if self._cp_size > 1 else 0
|
||||
self._dp_rank = get_dp_rank(hp_args) if self._cp_size > 1 else get_platform().get_rank()
|
||||
|
||||
# Prepare ref_model with the same CP + HSDP path as the train model.
|
||||
self.ref_model = ref_model
|
||||
if self.ref_model is not None:
|
||||
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model, self._hp_args)
|
||||
self.ref_model = self._prepare_model_for_hyper_parallel(self.ref_model)
|
||||
|
||||
self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_
|
||||
self._orig_fsdp2_prepare_model = None
|
||||
self._accelerator_patches_active = False
|
||||
|
||||
def _prepare_model_for_hyper_parallel(self, model: nn.Module) -> nn.Module:
|
||||
"""Apply CP runtime hooks before delegating to HyperParallel FSDP2 preparation."""
|
||||
if self._cp_size > 1:
|
||||
model = cp_prepare_model(model, self.accelerator, self._hp_args)
|
||||
return fsdp2_prepare_model(self.accelerator, model, self._hp_args)
|
||||
|
||||
def _activate_accelerator_patches(self) -> None:
|
||||
"""Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_."""
|
||||
if self._accelerator_patches_active:
|
||||
@@ -89,12 +188,10 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
|
||||
|
||||
import accelerate.accelerator as acc_module # pylint: disable=C0415
|
||||
|
||||
hp_args = self._hp_args
|
||||
|
||||
self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model
|
||||
|
||||
def _hp_fsdp2_prepare_model(accelerator, model):
|
||||
return fsdp2_prepare_model(accelerator, model, hp_args)
|
||||
return self._prepare_model_for_hyper_parallel(model)
|
||||
|
||||
acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model
|
||||
|
||||
@@ -135,6 +232,91 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
|
||||
return model
|
||||
return super()._wrap_model(model, training=training)
|
||||
|
||||
def _get_train_sampler(self, train_dataset=None):
|
||||
"""Match the no-CP baseline sampler semantics before CP repeats whole logical batches."""
|
||||
if train_dataset is None:
|
||||
train_dataset = self.train_dataset
|
||||
if getattr(self.finetuning_args, "disable_shuffling", False):
|
||||
return torch.utils.data.SequentialSampler(train_dataset)
|
||||
return super()._get_train_sampler(train_dataset)
|
||||
|
||||
def _build_cp_batch_sampler(self, dataset, shuffle: bool, batch_size: int, drop_last: bool):
|
||||
"""Repeat complete logical batches so CP groups consume the same baseline batch."""
|
||||
sampler = self._get_train_sampler(dataset) if shuffle else torch.utils.data.SequentialSampler(dataset)
|
||||
return _CPBatchRepeatedBatchSampler(
|
||||
sampler,
|
||||
batch_size=batch_size,
|
||||
drop_last=drop_last,
|
||||
repeat_factor=self._cp_size,
|
||||
logical_group_size=max(1, get_platform().get_world_size() // self._cp_size),
|
||||
)
|
||||
|
||||
def _get_cp_dataloader(self, dataset, batch_size: int, shuffle: bool):
|
||||
"""Create a train dataloader whose logical batches are shared within each CP group."""
|
||||
if isinstance(dataset, torch.utils.data.IterableDataset):
|
||||
raise NotImplementedError(
|
||||
"HyperParallel CP training requires a map-style dataset because iterable datasets cannot "
|
||||
"repeat logical batches across CP ranks."
|
||||
)
|
||||
|
||||
try:
|
||||
import datasets # pylint: disable=C0415
|
||||
except ImportError: # pragma: no cover
|
||||
datasets = None
|
||||
|
||||
if datasets is not None and isinstance(dataset, datasets.Dataset):
|
||||
dataset = self._remove_unused_columns(dataset, description="Training")
|
||||
data_collator = self.data_collator
|
||||
else:
|
||||
data_collator = self._get_collator_with_removed_columns(self.data_collator, description="Training")
|
||||
|
||||
batch_sampler = self._build_cp_batch_sampler(
|
||||
dataset,
|
||||
shuffle=shuffle,
|
||||
batch_size=batch_size,
|
||||
drop_last=self.args.dataloader_drop_last,
|
||||
)
|
||||
logical_batches = len(batch_sampler) // self._cp_size
|
||||
dp_size = max(1, get_platform().get_world_size() // self._cp_size)
|
||||
logical_length = logical_batches // dp_size if self.args.dataloader_drop_last else _ceil_div(logical_batches, dp_size)
|
||||
|
||||
dataloader_params = {
|
||||
"batch_sampler": batch_sampler,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
"persistent_workers": self.args.dataloader_persistent_workers
|
||||
if self.args.dataloader_num_workers > 0
|
||||
else False,
|
||||
}
|
||||
if self.args.dataloader_num_workers > 0:
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
from transformers.trainer import seed_worker # pylint: disable=C0415
|
||||
|
||||
dataloader_params["worker_init_fn"] = partial(
|
||||
seed_worker,
|
||||
num_workers=self.args.dataloader_num_workers,
|
||||
rank=self.args.process_index,
|
||||
)
|
||||
|
||||
dataloader = self.accelerator.prepare(torch.utils.data.DataLoader(dataset, **dataloader_params))
|
||||
return _CPDataLoaderLengthProxy(dataloader, logical_length)
|
||||
|
||||
def get_train_dataloader(self):
|
||||
"""Keep the no-CP logical batch stream, then repeat each whole batch across CP peers."""
|
||||
if self.train_dataset is None:
|
||||
raise ValueError("Trainer: training requires a train_dataset.")
|
||||
if self._cp_size <= 1:
|
||||
return super().get_train_dataloader()
|
||||
|
||||
shuffle = not getattr(self.finetuning_args, "disable_shuffling", False)
|
||||
return self._get_cp_dataloader(
|
||||
dataset=self.train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
shuffle=shuffle,
|
||||
)
|
||||
|
||||
def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None):
|
||||
"""Skip redundant device moves for HSDP-wrapped models."""
|
||||
if isinstance(model, HSDPModule):
|
||||
@@ -157,10 +339,13 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
|
||||
inputs: dict[str, Any],
|
||||
num_items_in_batch: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Standard training step with HSDP gradient synchronization."""
|
||||
"""Standard training step with HSDP sync plus optional CP input sharding."""
|
||||
model.train()
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
if self._cp_size > 1:
|
||||
inputs = shard_inputs_for_cp(inputs, self._cp_rank, self._cp_size)
|
||||
|
||||
sync_gradients = getattr(self.accelerator, "sync_gradients", True)
|
||||
if isinstance(model, HSDPModule):
|
||||
model.set_is_last_backward(sync_gradients)
|
||||
|
||||
@@ -50,6 +50,10 @@ def _prepare_hp_args(finetuning_args: "FinetuningArguments", model_args: "ModelA
|
||||
from hyper_parallel.integration.llamafactory import HyperParallelArguments # pylint: disable=C0415
|
||||
|
||||
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
|
||||
|
||||
if getattr(hp_args, "cp_size", None) != finetuning_args.hyper_parallel_cp_size:
|
||||
setattr(hp_args, "cp_size", finetuning_args.hyper_parallel_cp_size)
|
||||
|
||||
if hp_args.activation_mode != "none":
|
||||
model_args.disable_gradient_checkpointing = True
|
||||
return hp_args
|
||||
|
||||
Reference in New Issue
Block a user