From 802bcfe9697eff8d99e20660ae37aa09b1678c4a Mon Sep 17 00:00:00 2001 From: Chaoran Wei <77485245+wcrzlh@users.noreply.github.com> Date: Mon, 22 Jun 2026 07:40:44 +0800 Subject: [PATCH] [feat] support HyperParallel Context Parallel feature (#10559) Co-authored-by: wcrzlh --- src/llamafactory/hparams/finetuning_args.py | 5 + .../train/hyper_parallel/trainer.py | 197 +++++++++++++++++- .../train/hyper_parallel/workflow.py | 4 + 3 files changed, 200 insertions(+), 6 deletions(-) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index b2547c510..2fa199695 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -500,6 +500,10 @@ class FinetuningArguments( ) }, ) + hyper_parallel_cp_size: int = field( + default=1, + metadata={"help": "Context parallel size used when `use_hyper_parallel=True`."}, + ) use_muon: bool = field( default=False, metadata={"help": "Whether or not to use the Muon optimizer."}, @@ -576,6 +580,7 @@ class FinetuningArguments( assert self.finetuning_type in ["lora", "oft", "freeze", "full"], "Invalid fine-tuning method." assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." + assert self.hyper_parallel_cp_size > 0, "`hyper_parallel_cp_size` must be greater than 0." if self.stage == "ppo" and self.reward_model is None: raise ValueError("`reward_model` is necessary for PPO training.") diff --git a/src/llamafactory/train/hyper_parallel/trainer.py b/src/llamafactory/train/hyper_parallel/trainer.py index 8ca131143..47d752f5b 100644 --- a/src/llamafactory/train/hyper_parallel/trainer.py +++ b/src/llamafactory/train/hyper_parallel/trainer.py @@ -18,6 +18,7 @@ import logging import os import types from contextlib import nullcontext +from functools import partial from typing import Any, Optional import torch @@ -35,6 +36,13 @@ from hyper_parallel.integration.llamafactory import ( from hyper_parallel.integration.llamafactory import ( clip_grad_norm_ as hp_clip_grad_norm_, ) +from hyper_parallel.integration.llamafactory.context_parallel import ( + cp_prepare_model, + get_cp_rank, + get_dp_rank, + shard_inputs_for_cp, +) +from hyper_parallel.platform import get_platform from torch import nn from ..sft.trainer import CustomSeq2SeqTrainer @@ -43,6 +51,87 @@ from ..sft.trainer import CustomSeq2SeqTrainer logger = logging.getLogger(__name__) +class _CPBatchRepeatedBatchSampler(torch.utils.data.BatchSampler): + """Repeat logical batches so Accelerate shards CP peers onto the same samples.""" + + def __init__(self, sampler, batch_size: int, drop_last: bool, repeat_factor: int, logical_group_size: int): + super().__init__(sampler, batch_size, drop_last) + self.repeat_factor = repeat_factor + self.logical_group_size = logical_group_size + + def __len__(self): + logical_length = super().__len__() + if not self.drop_last and logical_length > 0: + logical_length = _ceil_div(logical_length, self.logical_group_size) * self.logical_group_size + return logical_length * self.repeat_factor + + def __iter__(self): + initial_data = [] + logical_count = 0 + pad_cursor = 0 + max_initial_data = self.batch_size * self.logical_group_size + + def collect_initial_data(batch): + if len(initial_data) < max_initial_data: + initial_data.extend(batch[: max_initial_data - len(initial_data)]) + + def get_padding_item(): + nonlocal pad_cursor + item = initial_data[pad_cursor % len(initial_data)] + pad_cursor += 1 + return item + + def pad_batch(batch): + batch = list(batch) + if self.drop_last or len(batch) == self.batch_size: + return batch + + while len(batch) < self.batch_size: + batch.append(get_padding_item()) + return batch + + def make_padding_batch(): + return [get_padding_item() for _ in range(self.batch_size)] + + def repeat_batch(batch): + for _ in range(self.repeat_factor): + yield list(batch) + + for batch in super().__iter__(): + collect_initial_data(batch) + batch = pad_batch(batch) + logical_count += 1 + yield from repeat_batch(batch) + + if self.drop_last or logical_count == 0: + return + + while logical_count % self.logical_group_size != 0: + logical_count += 1 + yield from repeat_batch(make_padding_batch()) + + +class _CPDataLoaderLengthProxy: + """Keep baseline logical dataloader length while yielding CP-repeated batches.""" + + def __init__(self, dataloader, logical_length: int): + self._dataloader = dataloader + self._logical_length = logical_length + + def __iter__(self): + return iter(self._dataloader) + + def __len__(self): + return self._logical_length + + def __getattr__(self, name): + return getattr(self._dataloader, name) + + +def _ceil_div(numerator: int, denominator: int) -> int: + return (numerator + denominator - 1) // denominator + + class HyperParallelTrainer(CustomSeq2SeqTrainer): """Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard. @@ -73,15 +162,25 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer): if not getattr(self.accelerator, "is_fsdp2", False): raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.") - # Prepare ref_model with HP's fsdp2_prepare_model + self._cp_size = hp_args.cp_size + self._cp_rank = get_cp_rank(hp_args) if self._cp_size > 1 else 0 + self._dp_rank = get_dp_rank(hp_args) if self._cp_size > 1 else get_platform().get_rank() + + # Prepare ref_model with the same CP + HSDP path as the train model. self.ref_model = ref_model if self.ref_model is not None: - self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model, self._hp_args) + self.ref_model = self._prepare_model_for_hyper_parallel(self.ref_model) self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_ self._orig_fsdp2_prepare_model = None self._accelerator_patches_active = False + def _prepare_model_for_hyper_parallel(self, model: nn.Module) -> nn.Module: + """Apply CP runtime hooks before delegating to HyperParallel FSDP2 preparation.""" + if self._cp_size > 1: + model = cp_prepare_model(model, self.accelerator, self._hp_args) + return fsdp2_prepare_model(self.accelerator, model, self._hp_args) + def _activate_accelerator_patches(self) -> None: """Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_.""" if self._accelerator_patches_active: @@ -89,12 +188,10 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer): import accelerate.accelerator as acc_module # pylint: disable=C0415 - hp_args = self._hp_args - self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model def _hp_fsdp2_prepare_model(accelerator, model): - return fsdp2_prepare_model(accelerator, model, hp_args) + return self._prepare_model_for_hyper_parallel(model) acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model @@ -135,6 +232,91 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer): return model return super()._wrap_model(model, training=training) + def _get_train_sampler(self, train_dataset=None): + """Match the no-CP baseline sampler semantics before CP repeats whole logical batches.""" + if train_dataset is None: + train_dataset = self.train_dataset + if getattr(self.finetuning_args, "disable_shuffling", False): + return torch.utils.data.SequentialSampler(train_dataset) + return super()._get_train_sampler(train_dataset) + + def _build_cp_batch_sampler(self, dataset, shuffle: bool, batch_size: int, drop_last: bool): + """Repeat complete logical batches so CP groups consume the same baseline batch.""" + sampler = self._get_train_sampler(dataset) if shuffle else torch.utils.data.SequentialSampler(dataset) + return _CPBatchRepeatedBatchSampler( + sampler, + batch_size=batch_size, + drop_last=drop_last, + repeat_factor=self._cp_size, + logical_group_size=max(1, get_platform().get_world_size() // self._cp_size), + ) + + def _get_cp_dataloader(self, dataset, batch_size: int, shuffle: bool): + """Create a train dataloader whose logical batches are shared within each CP group.""" + if isinstance(dataset, torch.utils.data.IterableDataset): + raise NotImplementedError( + "HyperParallel CP training requires a map-style dataset because iterable datasets cannot " + "repeat logical batches across CP ranks." + ) + + try: + import datasets # pylint: disable=C0415 + except ImportError: # pragma: no cover + datasets = None + + if datasets is not None and isinstance(dataset, datasets.Dataset): + dataset = self._remove_unused_columns(dataset, description="Training") + data_collator = self.data_collator + else: + data_collator = self._get_collator_with_removed_columns(self.data_collator, description="Training") + + batch_sampler = self._build_cp_batch_sampler( + dataset, + shuffle=shuffle, + batch_size=batch_size, + drop_last=self.args.dataloader_drop_last, + ) + logical_batches = len(batch_sampler) // self._cp_size + dp_size = max(1, get_platform().get_world_size() // self._cp_size) + logical_length = logical_batches // dp_size if self.args.dataloader_drop_last else _ceil_div(logical_batches, dp_size) + + dataloader_params = { + "batch_sampler": batch_sampler, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers + if self.args.dataloader_num_workers > 0 + else False, + } + if self.args.dataloader_num_workers > 0: + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + from transformers.trainer import seed_worker # pylint: disable=C0415 + + dataloader_params["worker_init_fn"] = partial( + seed_worker, + num_workers=self.args.dataloader_num_workers, + rank=self.args.process_index, + ) + + dataloader = self.accelerator.prepare(torch.utils.data.DataLoader(dataset, **dataloader_params)) + return _CPDataLoaderLengthProxy(dataloader, logical_length) + + def get_train_dataloader(self): + """Keep the no-CP logical batch stream, then repeat each whole batch across CP peers.""" + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + if self._cp_size <= 1: + return super().get_train_dataloader() + + shuffle = not getattr(self.finetuning_args, "disable_shuffling", False) + return self._get_cp_dataloader( + dataset=self.train_dataset, + batch_size=self._train_batch_size, + shuffle=shuffle, + ) + def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None): """Skip redundant device moves for HSDP-wrapped models.""" if isinstance(model, HSDPModule): @@ -157,10 +339,13 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer): inputs: dict[str, Any], num_items_in_batch: Optional[int] = None, ) -> torch.Tensor: - """Standard training step with HSDP gradient synchronization.""" + """Standard training step with HSDP sync plus optional CP input sharding.""" model.train() inputs = self._prepare_inputs(inputs) + if self._cp_size > 1: + inputs = shard_inputs_for_cp(inputs, self._cp_rank, self._cp_size) + sync_gradients = getattr(self.accelerator, "sync_gradients", True) if isinstance(model, HSDPModule): model.set_is_last_backward(sync_gradients) diff --git a/src/llamafactory/train/hyper_parallel/workflow.py b/src/llamafactory/train/hyper_parallel/workflow.py index 360a8ccb8..4eb4cc99b 100644 --- a/src/llamafactory/train/hyper_parallel/workflow.py +++ b/src/llamafactory/train/hyper_parallel/workflow.py @@ -50,6 +50,10 @@ def _prepare_hp_args(finetuning_args: "FinetuningArguments", model_args: "ModelA from hyper_parallel.integration.llamafactory import HyperParallelArguments # pylint: disable=C0415 hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args) + + if getattr(hp_args, "cp_size", None) != finetuning_args.hyper_parallel_cp_size: + setattr(hp_args, "cp_size", finetuning_args.hyper_parallel_cp_size) + if hp_args.activation_mode != "none": model_args.disable_gradient_checkpointing = True return hp_args