[feat] support HyperParallel Context Parallel feature (#10559)

Co-authored-by: wcrzlh <weichaoran@huawei.com>
This commit is contained in:
Chaoran Wei
2026-06-22 07:40:44 +08:00
committed by GitHub
parent 8792f06161
commit 802bcfe969
3 changed files with 200 additions and 6 deletions

View File

@@ -500,6 +500,10 @@ class FinetuningArguments(
) )
}, },
) )
hyper_parallel_cp_size: int = field(
default=1,
metadata={"help": "Context parallel size used when `use_hyper_parallel=True`."},
)
use_muon: bool = field( use_muon: bool = field(
default=False, default=False,
metadata={"help": "Whether or not to use the Muon optimizer."}, metadata={"help": "Whether or not to use the Muon optimizer."},
@@ -576,6 +580,7 @@ class FinetuningArguments(
assert self.finetuning_type in ["lora", "oft", "freeze", "full"], "Invalid fine-tuning method." assert self.finetuning_type in ["lora", "oft", "freeze", "full"], "Invalid fine-tuning method."
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
assert self.hyper_parallel_cp_size > 0, "`hyper_parallel_cp_size` must be greater than 0."
if self.stage == "ppo" and self.reward_model is None: if self.stage == "ppo" and self.reward_model is None:
raise ValueError("`reward_model` is necessary for PPO training.") raise ValueError("`reward_model` is necessary for PPO training.")

View File

@@ -18,6 +18,7 @@ import logging
import os import os
import types import types
from contextlib import nullcontext from contextlib import nullcontext
from functools import partial
from typing import Any, Optional from typing import Any, Optional
import torch import torch
@@ -35,6 +36,13 @@ from hyper_parallel.integration.llamafactory import (
from hyper_parallel.integration.llamafactory import ( from hyper_parallel.integration.llamafactory import (
clip_grad_norm_ as hp_clip_grad_norm_, clip_grad_norm_ as hp_clip_grad_norm_,
) )
from hyper_parallel.integration.llamafactory.context_parallel import (
cp_prepare_model,
get_cp_rank,
get_dp_rank,
shard_inputs_for_cp,
)
from hyper_parallel.platform import get_platform
from torch import nn from torch import nn
from ..sft.trainer import CustomSeq2SeqTrainer from ..sft.trainer import CustomSeq2SeqTrainer
@@ -43,6 +51,87 @@ from ..sft.trainer import CustomSeq2SeqTrainer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class _CPBatchRepeatedBatchSampler(torch.utils.data.BatchSampler):
    """Repeat logical batches so Accelerate shards CP peers onto the same samples.

    Every batch produced by the parent ``BatchSampler`` (a "logical" batch) is
    yielded ``repeat_factor`` times in a row. Unless ``drop_last`` is set, a
    short final batch is topped up to ``batch_size`` and whole padding batches
    are appended until the number of logical batches is a multiple of
    ``logical_group_size``. Padding indices are recycled, in order, from the
    first ``batch_size * logical_group_size`` indices seen in the epoch.

    Args:
        sampler: Index sampler forwarded to ``BatchSampler``.
        batch_size: Number of indices per logical batch.
        drop_last: Forwarded to ``BatchSampler``; also disables all padding.
        repeat_factor: How many consecutive times each logical batch is yielded.
        logical_group_size: The logical batch count is padded up to a multiple
            of this value when ``drop_last`` is false.

    Raises:
        ValueError: If ``repeat_factor`` or ``logical_group_size`` is below 1.
    """

    def __init__(self, sampler, batch_size: int, drop_last: bool, repeat_factor: int, logical_group_size: int):
        # Fail fast: a zero group size would otherwise only surface as a
        # ZeroDivisionError in the middle of an epoch, and a zero repeat
        # factor would silently produce an empty dataloader.
        if repeat_factor < 1:
            raise ValueError(f"`repeat_factor` must be a positive integer, got {repeat_factor}.")
        if logical_group_size < 1:
            raise ValueError(f"`logical_group_size` must be a positive integer, got {logical_group_size}.")

        super().__init__(sampler, batch_size, drop_last)
        self.repeat_factor = repeat_factor
        self.logical_group_size = logical_group_size

    def __len__(self):
        """Return the number of batches yielded per epoch, repeats and padding included."""
        logical_length = super().__len__()
        if not self.drop_last:
            # Round up to the next multiple of the group size (0 stays 0).
            remainder = logical_length % self.logical_group_size
            if remainder:
                logical_length += self.logical_group_size - remainder
        return logical_length * self.repeat_factor

    def __iter__(self):
        """Yield each logical batch ``repeat_factor`` times, then any padding batches."""
        pad_pool = []  # first indices of the epoch, reused as padding material
        pad_pool_capacity = self.batch_size * self.logical_group_size
        pad_cursor = 0
        logical_count = 0

        def next_padding_index():
            # Walk the pool round-robin so padding is deterministic.
            nonlocal pad_cursor
            index = pad_pool[pad_cursor % len(pad_pool)]
            pad_cursor += 1
            return index

        for batch in super().__iter__():
            batch = list(batch)
            if len(pad_pool) < pad_pool_capacity:
                pad_pool.extend(batch[: pad_pool_capacity - len(pad_pool)])
            if not self.drop_last:
                # Only the final batch can be short; top it up to full size.
                while len(batch) < self.batch_size:
                    batch.append(next_padding_index())
            logical_count += 1
            for _ in range(self.repeat_factor):
                yield list(batch)

        if self.drop_last or logical_count == 0:
            return

        # Append whole padding batches until the logical batch count is a
        # multiple of the group size.
        while logical_count % self.logical_group_size != 0:
            logical_count += 1
            padding_batch = [next_padding_index() for _ in range(self.batch_size)]
            for _ in range(self.repeat_factor):
                yield list(padding_batch)
class _CPDataLoaderLengthProxy:
    """Keep baseline logical dataloader length while yielding CP-repeated batches.

    Iteration and every attribute not defined here are forwarded to the wrapped
    dataloader; only ``len()`` is overridden to report ``logical_length``.

    Args:
        dataloader: The object to iterate and to forward attribute access to.
        logical_length: The value reported by ``len()``.
    """

    def __init__(self, dataloader, logical_length: int):
        self._dataloader = dataloader
        self._logical_length = logical_length

    def __iter__(self):
        return iter(self._dataloader)

    def __len__(self):
        return self._logical_length

    def __getattr__(self, name):
        # ``__getattr__`` only runs when normal lookup fails. If ``_dataloader``
        # itself is missing (e.g. on the empty instance created by ``copy`` or
        # ``pickle`` before state is restored), delegating through
        # ``self._dataloader`` would re-enter this method until RecursionError.
        if name == "_dataloader":
            raise AttributeError(name)
        return getattr(self._dataloader, name)
def _ceil_div(numerator: int, denominator: int) -> int:
    """Return ``ceil(numerator / denominator)`` using integer arithmetic only.

    Negated floor division gives the exact ceiling for any non-zero integer
    denominator, including negative ones, where the ``(n + d - 1) // d`` form
    is off by one.

    Raises:
        ZeroDivisionError: If ``denominator`` is zero.
    """
    return -(-numerator // denominator)
class HyperParallelTrainer(CustomSeq2SeqTrainer): class HyperParallelTrainer(CustomSeq2SeqTrainer):
"""Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard. """Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard.
@@ -73,15 +162,25 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
if not getattr(self.accelerator, "is_fsdp2", False): if not getattr(self.accelerator, "is_fsdp2", False):
raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.") raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.")
# Prepare ref_model with HP's fsdp2_prepare_model self._cp_size = hp_args.cp_size
self._cp_rank = get_cp_rank(hp_args) if self._cp_size > 1 else 0
self._dp_rank = get_dp_rank(hp_args) if self._cp_size > 1 else get_platform().get_rank()
# Prepare ref_model with the same CP + HSDP path as the train model.
self.ref_model = ref_model self.ref_model = ref_model
if self.ref_model is not None: if self.ref_model is not None:
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model, self._hp_args) self.ref_model = self._prepare_model_for_hyper_parallel(self.ref_model)
self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_ self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_
self._orig_fsdp2_prepare_model = None self._orig_fsdp2_prepare_model = None
self._accelerator_patches_active = False self._accelerator_patches_active = False
def _prepare_model_for_hyper_parallel(self, model: nn.Module) -> nn.Module:
    """Run the optional CP preparation step, then HyperParallel FSDP2 preparation.

    When ``self._cp_size`` is greater than 1 the model is first passed through
    ``cp_prepare_model``; the result (or the untouched model otherwise) is then
    handed to ``fsdp2_prepare_model`` and that call's return value is returned.
    """
    cp_enabled = self._cp_size > 1
    prepared = cp_prepare_model(model, self.accelerator, self._hp_args) if cp_enabled else model
    return fsdp2_prepare_model(self.accelerator, prepared, self._hp_args)
def _activate_accelerator_patches(self) -> None: def _activate_accelerator_patches(self) -> None:
"""Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_.""" """Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_."""
if self._accelerator_patches_active: if self._accelerator_patches_active:
@@ -89,12 +188,10 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
import accelerate.accelerator as acc_module # pylint: disable=C0415 import accelerate.accelerator as acc_module # pylint: disable=C0415
hp_args = self._hp_args
self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model
def _hp_fsdp2_prepare_model(accelerator, model): def _hp_fsdp2_prepare_model(accelerator, model):
return fsdp2_prepare_model(accelerator, model, hp_args) return self._prepare_model_for_hyper_parallel(model)
acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model
@@ -135,6 +232,91 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
return model return model
return super()._wrap_model(model, training=training) return super()._wrap_model(model, training=training)
def _get_train_sampler(self, train_dataset=None):
    """Return the index sampler for training, honouring ``disable_shuffling``.

    Falls back to ``self.train_dataset`` when no dataset is given. If
    ``finetuning_args.disable_shuffling`` is truthy a ``SequentialSampler`` is
    returned; otherwise the parent class picks the sampler.
    """
    dataset = self.train_dataset if train_dataset is None else train_dataset
    shuffling_disabled = getattr(self.finetuning_args, "disable_shuffling", False)
    if not shuffling_disabled:
        return super()._get_train_sampler(dataset)
    return torch.utils.data.SequentialSampler(dataset)
def _build_cp_batch_sampler(self, dataset, shuffle: bool, batch_size: int, drop_last: bool):
    """Build the batch sampler that repeats each logical batch ``self._cp_size`` times.

    The index sampler comes from ``_get_train_sampler`` when ``shuffle`` is
    true and is a plain ``SequentialSampler`` otherwise. The padding group size
    is the platform world size divided by the CP size, floored at 1.
    """
    if shuffle:
        index_sampler = self._get_train_sampler(dataset)
    else:
        index_sampler = torch.utils.data.SequentialSampler(dataset)

    group_size = max(1, get_platform().get_world_size() // self._cp_size)
    return _CPBatchRepeatedBatchSampler(
        index_sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        repeat_factor=self._cp_size,
        logical_group_size=group_size,
    )
def _get_cp_dataloader(self, dataset, batch_size: int, shuffle: bool):
    """Build the CP training dataloader and wrap it so ``len()`` reports the logical length.

    Steps, in order: reject iterable datasets; strip unused columns (on the
    dataset itself for ``datasets.Dataset`` inputs, in the collator otherwise);
    build the CP batch sampler; derive the reported length from the sampler
    length, the CP size and the platform world size; create the ``DataLoader``
    and pass it through ``self.accelerator.prepare``.

    Raises:
        NotImplementedError: If ``dataset`` is a ``torch`` ``IterableDataset``.
    """
    if isinstance(dataset, torch.utils.data.IterableDataset):
        raise NotImplementedError(
            "HyperParallel CP training requires a map-style dataset because iterable datasets cannot "
            "repeat logical batches across CP ranks."
        )

    # ``datasets`` is optional; without it the column filtering happens in the collator.
    try:
        import datasets  # pylint: disable=C0415
    except ImportError:  # pragma: no cover
        datasets = None

    is_hf_dataset = datasets is not None and isinstance(dataset, datasets.Dataset)
    if is_hf_dataset:
        dataset = self._remove_unused_columns(dataset, description="Training")
        collate_fn = self.data_collator
    else:
        collate_fn = self._get_collator_with_removed_columns(self.data_collator, description="Training")

    args = self.args
    drop_last = args.dataloader_drop_last
    batch_sampler = self._build_cp_batch_sampler(
        dataset,
        shuffle=shuffle,
        batch_size=batch_size,
        drop_last=drop_last,
    )

    # The sampler length counts every repeat; undo that, then split across DP groups.
    logical_batches = len(batch_sampler) // self._cp_size
    dp_size = max(1, get_platform().get_world_size() // self._cp_size)
    if drop_last:
        logical_length = logical_batches // dp_size
    else:
        logical_length = _ceil_div(logical_batches, dp_size)

    num_workers = args.dataloader_num_workers
    uses_workers = num_workers > 0
    loader_kwargs = {
        "batch_sampler": batch_sampler,
        "collate_fn": collate_fn,
        "num_workers": num_workers,
        "pin_memory": args.dataloader_pin_memory,
        "persistent_workers": args.dataloader_persistent_workers if uses_workers else False,
    }
    if uses_workers:
        # Worker-only options are set only when worker processes are requested.
        loader_kwargs["prefetch_factor"] = args.dataloader_prefetch_factor

        from transformers.trainer import seed_worker  # pylint: disable=C0415

        loader_kwargs["worker_init_fn"] = partial(
            seed_worker,
            num_workers=num_workers,
            rank=args.process_index,
        )

    prepared_loader = self.accelerator.prepare(torch.utils.data.DataLoader(dataset, **loader_kwargs))
    return _CPDataLoaderLengthProxy(prepared_loader, logical_length)
def get_train_dataloader(self):
    """Return the training dataloader, using the CP-specific one when CP is active.

    With ``self._cp_size`` of 1 or less the parent implementation is used
    unchanged. Otherwise the CP dataloader is built, shuffling unless
    ``finetuning_args.disable_shuffling`` is truthy.

    Raises:
        ValueError: If ``self.train_dataset`` is ``None``.
    """
    train_dataset = self.train_dataset
    if train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    if self._cp_size > 1:
        shuffling_disabled = getattr(self.finetuning_args, "disable_shuffling", False)
        return self._get_cp_dataloader(
            dataset=train_dataset,
            batch_size=self._train_batch_size,
            shuffle=not shuffling_disabled,
        )
    return super().get_train_dataloader()
def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None): def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None):
"""Skip redundant device moves for HSDP-wrapped models.""" """Skip redundant device moves for HSDP-wrapped models."""
if isinstance(model, HSDPModule): if isinstance(model, HSDPModule):
@@ -157,10 +339,13 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
inputs: dict[str, Any], inputs: dict[str, Any],
num_items_in_batch: Optional[int] = None, num_items_in_batch: Optional[int] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Standard training step with HSDP gradient synchronization.""" """Standard training step with HSDP sync plus optional CP input sharding."""
model.train() model.train()
inputs = self._prepare_inputs(inputs) inputs = self._prepare_inputs(inputs)
if self._cp_size > 1:
inputs = shard_inputs_for_cp(inputs, self._cp_rank, self._cp_size)
sync_gradients = getattr(self.accelerator, "sync_gradients", True) sync_gradients = getattr(self.accelerator, "sync_gradients", True)
if isinstance(model, HSDPModule): if isinstance(model, HSDPModule):
model.set_is_last_backward(sync_gradients) model.set_is_last_backward(sync_gradients)

View File

@@ -50,6 +50,10 @@ def _prepare_hp_args(finetuning_args: "FinetuningArguments", model_args: "ModelA
from hyper_parallel.integration.llamafactory import HyperParallelArguments # pylint: disable=C0415 from hyper_parallel.integration.llamafactory import HyperParallelArguments # pylint: disable=C0415
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args) hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
if getattr(hp_args, "cp_size", None) != finetuning_args.hyper_parallel_cp_size:
setattr(hp_args, "cp_size", finetuning_args.hyper_parallel_cp_size)
if hp_args.activation_mode != "none": if hp_args.activation_mode != "none":
model_args.disable_gradient_checkpointing = True model_args.disable_gradient_checkpointing = True
return hp_args return hp_args