[feat] support HyperParallel Context Parallel feature (#10559)

Co-authored-by: wcrzlh <weichaoran@huawei.com>
This commit is contained in:
Chaoran Wei
2026-06-22 07:40:44 +08:00
committed by GitHub
parent 8792f06161
commit 802bcfe969
3 changed files with 200 additions and 6 deletions

View File

@@ -500,6 +500,10 @@ class FinetuningArguments(
)
},
)
hyper_parallel_cp_size: int = field(
default=1,
metadata={"help": "Context parallel size used when `use_hyper_parallel=True`."},
)
use_muon: bool = field(
default=False,
metadata={"help": "Whether or not to use the Muon optimizer."},
@@ -576,6 +580,7 @@ class FinetuningArguments(
assert self.finetuning_type in ["lora", "oft", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.hyper_parallel_cp_size > 0, "`hyper_parallel_cp_size` must be greater than 0."
if self.stage == "ppo" and self.reward_model is None:
raise ValueError("`reward_model` is necessary for PPO training.")

View File

@@ -18,6 +18,7 @@ import logging
import os
import types
from contextlib import nullcontext
from functools import partial
from typing import Any, Optional
import torch
@@ -35,6 +36,13 @@ from hyper_parallel.integration.llamafactory import (
from hyper_parallel.integration.llamafactory import (
clip_grad_norm_ as hp_clip_grad_norm_,
)
from hyper_parallel.integration.llamafactory.context_parallel import (
cp_prepare_model,
get_cp_rank,
get_dp_rank,
shard_inputs_for_cp,
)
from hyper_parallel.platform import get_platform
from torch import nn
from ..sft.trainer import CustomSeq2SeqTrainer
@@ -43,6 +51,87 @@ from ..sft.trainer import CustomSeq2SeqTrainer
logger = logging.getLogger(__name__)
class _CPBatchRepeatedBatchSampler(torch.utils.data.BatchSampler):
"""Repeat logical batches so Accelerate shards CP peers onto the same samples."""
def __init__(self, sampler, batch_size: int, drop_last: bool, repeat_factor: int, logical_group_size: int):
super().__init__(sampler, batch_size, drop_last)
self.repeat_factor = repeat_factor
self.logical_group_size = logical_group_size
def __len__(self):
logical_length = super().__len__()
if not self.drop_last and logical_length > 0:
logical_length = _ceil_div(logical_length, self.logical_group_size) * self.logical_group_size
return logical_length * self.repeat_factor
def __iter__(self):
initial_data = []
logical_count = 0
pad_cursor = 0
max_initial_data = self.batch_size * self.logical_group_size
def collect_initial_data(batch):
if len(initial_data) < max_initial_data:
initial_data.extend(batch[: max_initial_data - len(initial_data)])
def get_padding_item():
nonlocal pad_cursor
item = initial_data[pad_cursor % len(initial_data)]
pad_cursor += 1
return item
def pad_batch(batch):
batch = list(batch)
if self.drop_last or len(batch) == self.batch_size:
return batch
while len(batch) < self.batch_size:
batch.append(get_padding_item())
return batch
def make_padding_batch():
return [get_padding_item() for _ in range(self.batch_size)]
def repeat_batch(batch):
for _ in range(self.repeat_factor):
yield list(batch)
for batch in super().__iter__():
collect_initial_data(batch)
batch = pad_batch(batch)
logical_count += 1
yield from repeat_batch(batch)
if self.drop_last or logical_count == 0:
return
while logical_count % self.logical_group_size != 0:
logical_count += 1
yield from repeat_batch(make_padding_batch())
class _CPDataLoaderLengthProxy:
"""Keep baseline logical dataloader length while yielding CP-repeated batches."""
def __init__(self, dataloader, logical_length: int):
self._dataloader = dataloader
self._logical_length = logical_length
def __iter__(self):
return iter(self._dataloader)
def __len__(self):
return self._logical_length
def __getattr__(self, name):
return getattr(self._dataloader, name)
def _ceil_div(numerator: int, denominator: int) -> int:
return (numerator + denominator - 1) // denominator
class HyperParallelTrainer(CustomSeq2SeqTrainer):
"""Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard.
@@ -73,15 +162,25 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
if not getattr(self.accelerator, "is_fsdp2", False):
raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.")
# Prepare ref_model with HP's fsdp2_prepare_model
self._cp_size = hp_args.cp_size
self._cp_rank = get_cp_rank(hp_args) if self._cp_size > 1 else 0
self._dp_rank = get_dp_rank(hp_args) if self._cp_size > 1 else get_platform().get_rank()
# Prepare ref_model with the same CP + HSDP path as the train model.
self.ref_model = ref_model
if self.ref_model is not None:
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model, self._hp_args)
self.ref_model = self._prepare_model_for_hyper_parallel(self.ref_model)
self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_
self._orig_fsdp2_prepare_model = None
self._accelerator_patches_active = False
def _prepare_model_for_hyper_parallel(self, model: nn.Module) -> nn.Module:
"""Apply CP runtime hooks before delegating to HyperParallel FSDP2 preparation."""
if self._cp_size > 1:
model = cp_prepare_model(model, self.accelerator, self._hp_args)
return fsdp2_prepare_model(self.accelerator, model, self._hp_args)
def _activate_accelerator_patches(self) -> None:
"""Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_."""
if self._accelerator_patches_active:
@@ -89,12 +188,10 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
import accelerate.accelerator as acc_module # pylint: disable=C0415
hp_args = self._hp_args
self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model
def _hp_fsdp2_prepare_model(accelerator, model):
return fsdp2_prepare_model(accelerator, model, hp_args)
return self._prepare_model_for_hyper_parallel(model)
acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model
@@ -135,6 +232,91 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
return model
return super()._wrap_model(model, training=training)
def _get_train_sampler(self, train_dataset=None):
"""Match the no-CP baseline sampler semantics before CP repeats whole logical batches."""
if train_dataset is None:
train_dataset = self.train_dataset
if getattr(self.finetuning_args, "disable_shuffling", False):
return torch.utils.data.SequentialSampler(train_dataset)
return super()._get_train_sampler(train_dataset)
def _build_cp_batch_sampler(self, dataset, shuffle: bool, batch_size: int, drop_last: bool):
"""Repeat complete logical batches so CP groups consume the same baseline batch."""
sampler = self._get_train_sampler(dataset) if shuffle else torch.utils.data.SequentialSampler(dataset)
return _CPBatchRepeatedBatchSampler(
sampler,
batch_size=batch_size,
drop_last=drop_last,
repeat_factor=self._cp_size,
logical_group_size=max(1, get_platform().get_world_size() // self._cp_size),
)
def _get_cp_dataloader(self, dataset, batch_size: int, shuffle: bool):
"""Create a train dataloader whose logical batches are shared within each CP group."""
if isinstance(dataset, torch.utils.data.IterableDataset):
raise NotImplementedError(
"HyperParallel CP training requires a map-style dataset because iterable datasets cannot "
"repeat logical batches across CP ranks."
)
try:
import datasets # pylint: disable=C0415
except ImportError: # pragma: no cover
datasets = None
if datasets is not None and isinstance(dataset, datasets.Dataset):
dataset = self._remove_unused_columns(dataset, description="Training")
data_collator = self.data_collator
else:
data_collator = self._get_collator_with_removed_columns(self.data_collator, description="Training")
batch_sampler = self._build_cp_batch_sampler(
dataset,
shuffle=shuffle,
batch_size=batch_size,
drop_last=self.args.dataloader_drop_last,
)
logical_batches = len(batch_sampler) // self._cp_size
dp_size = max(1, get_platform().get_world_size() // self._cp_size)
logical_length = logical_batches // dp_size if self.args.dataloader_drop_last else _ceil_div(logical_batches, dp_size)
dataloader_params = {
"batch_sampler": batch_sampler,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers
if self.args.dataloader_num_workers > 0
else False,
}
if self.args.dataloader_num_workers > 0:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
from transformers.trainer import seed_worker # pylint: disable=C0415
dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
dataloader = self.accelerator.prepare(torch.utils.data.DataLoader(dataset, **dataloader_params))
return _CPDataLoaderLengthProxy(dataloader, logical_length)
def get_train_dataloader(self):
"""Keep the no-CP logical batch stream, then repeat each whole batch across CP peers."""
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
if self._cp_size <= 1:
return super().get_train_dataloader()
shuffle = not getattr(self.finetuning_args, "disable_shuffling", False)
return self._get_cp_dataloader(
dataset=self.train_dataset,
batch_size=self._train_batch_size,
shuffle=shuffle,
)
def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None):
"""Skip redundant device moves for HSDP-wrapped models."""
if isinstance(model, HSDPModule):
@@ -157,10 +339,13 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
inputs: dict[str, Any],
num_items_in_batch: Optional[int] = None,
) -> torch.Tensor:
"""Standard training step with HSDP gradient synchronization."""
"""Standard training step with HSDP sync plus optional CP input sharding."""
model.train()
inputs = self._prepare_inputs(inputs)
if self._cp_size > 1:
inputs = shard_inputs_for_cp(inputs, self._cp_rank, self._cp_size)
sync_gradients = getattr(self.accelerator, "sync_gradients", True)
if isinstance(model, HSDPModule):
model.set_is_last_backward(sync_gradients)

View File

@@ -50,6 +50,10 @@ def _prepare_hp_args(finetuning_args: "FinetuningArguments", model_args: "ModelA
from hyper_parallel.integration.llamafactory import HyperParallelArguments # pylint: disable=C0415
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
if getattr(hp_args, "cp_size", None) != finetuning_args.hyper_parallel_cp_size:
setattr(hp_args, "cp_size", finetuning_args.hyper_parallel_cp_size)
if hp_args.activation_mode != "none":
model_args.disable_gradient_checkpointing = True
return hp_args