mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-25 16:48:57 +08:00
[feat] support HyperParallel Context Parallel feature (#10559)
Co-authored-by: wcrzlh <weichaoran@huawei.com>
This commit is contained in:
@@ -500,6 +500,10 @@ class FinetuningArguments(
|
|||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
hyper_parallel_cp_size: int = field(
|
||||||
|
default=1,
|
||||||
|
metadata={"help": "Context parallel size used when `use_hyper_parallel=True`."},
|
||||||
|
)
|
||||||
use_muon: bool = field(
|
use_muon: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to use the Muon optimizer."},
|
metadata={"help": "Whether or not to use the Muon optimizer."},
|
||||||
@@ -576,6 +580,7 @@ class FinetuningArguments(
|
|||||||
assert self.finetuning_type in ["lora", "oft", "freeze", "full"], "Invalid fine-tuning method."
|
assert self.finetuning_type in ["lora", "oft", "freeze", "full"], "Invalid fine-tuning method."
|
||||||
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||||
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||||
|
assert self.hyper_parallel_cp_size > 0, "`hyper_parallel_cp_size` must be greater than 0."
|
||||||
|
|
||||||
if self.stage == "ppo" and self.reward_model is None:
|
if self.stage == "ppo" and self.reward_model is None:
|
||||||
raise ValueError("`reward_model` is necessary for PPO training.")
|
raise ValueError("`reward_model` is necessary for PPO training.")
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import types
|
import types
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
from functools import partial
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -35,6 +36,13 @@ from hyper_parallel.integration.llamafactory import (
|
|||||||
from hyper_parallel.integration.llamafactory import (
|
from hyper_parallel.integration.llamafactory import (
|
||||||
clip_grad_norm_ as hp_clip_grad_norm_,
|
clip_grad_norm_ as hp_clip_grad_norm_,
|
||||||
)
|
)
|
||||||
|
from hyper_parallel.integration.llamafactory.context_parallel import (
|
||||||
|
cp_prepare_model,
|
||||||
|
get_cp_rank,
|
||||||
|
get_dp_rank,
|
||||||
|
shard_inputs_for_cp,
|
||||||
|
)
|
||||||
|
from hyper_parallel.platform import get_platform
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ..sft.trainer import CustomSeq2SeqTrainer
|
from ..sft.trainer import CustomSeq2SeqTrainer
|
||||||
@@ -43,6 +51,87 @@ from ..sft.trainer import CustomSeq2SeqTrainer
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _CPBatchRepeatedBatchSampler(torch.utils.data.BatchSampler):
|
||||||
|
"""Repeat logical batches so Accelerate shards CP peers onto the same samples."""
|
||||||
|
|
||||||
|
def __init__(self, sampler, batch_size: int, drop_last: bool, repeat_factor: int, logical_group_size: int):
|
||||||
|
super().__init__(sampler, batch_size, drop_last)
|
||||||
|
self.repeat_factor = repeat_factor
|
||||||
|
self.logical_group_size = logical_group_size
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
logical_length = super().__len__()
|
||||||
|
if not self.drop_last and logical_length > 0:
|
||||||
|
logical_length = _ceil_div(logical_length, self.logical_group_size) * self.logical_group_size
|
||||||
|
return logical_length * self.repeat_factor
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
initial_data = []
|
||||||
|
logical_count = 0
|
||||||
|
pad_cursor = 0
|
||||||
|
max_initial_data = self.batch_size * self.logical_group_size
|
||||||
|
|
||||||
|
def collect_initial_data(batch):
|
||||||
|
if len(initial_data) < max_initial_data:
|
||||||
|
initial_data.extend(batch[: max_initial_data - len(initial_data)])
|
||||||
|
|
||||||
|
def get_padding_item():
|
||||||
|
nonlocal pad_cursor
|
||||||
|
item = initial_data[pad_cursor % len(initial_data)]
|
||||||
|
pad_cursor += 1
|
||||||
|
return item
|
||||||
|
|
||||||
|
def pad_batch(batch):
|
||||||
|
batch = list(batch)
|
||||||
|
if self.drop_last or len(batch) == self.batch_size:
|
||||||
|
return batch
|
||||||
|
|
||||||
|
while len(batch) < self.batch_size:
|
||||||
|
batch.append(get_padding_item())
|
||||||
|
return batch
|
||||||
|
|
||||||
|
def make_padding_batch():
|
||||||
|
return [get_padding_item() for _ in range(self.batch_size)]
|
||||||
|
|
||||||
|
def repeat_batch(batch):
|
||||||
|
for _ in range(self.repeat_factor):
|
||||||
|
yield list(batch)
|
||||||
|
|
||||||
|
for batch in super().__iter__():
|
||||||
|
collect_initial_data(batch)
|
||||||
|
batch = pad_batch(batch)
|
||||||
|
logical_count += 1
|
||||||
|
yield from repeat_batch(batch)
|
||||||
|
|
||||||
|
if self.drop_last or logical_count == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
while logical_count % self.logical_group_size != 0:
|
||||||
|
logical_count += 1
|
||||||
|
yield from repeat_batch(make_padding_batch())
|
||||||
|
|
||||||
|
|
||||||
|
class _CPDataLoaderLengthProxy:
|
||||||
|
"""Keep baseline logical dataloader length while yielding CP-repeated batches."""
|
||||||
|
|
||||||
|
def __init__(self, dataloader, logical_length: int):
|
||||||
|
self._dataloader = dataloader
|
||||||
|
self._logical_length = logical_length
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._dataloader)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._logical_length
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self._dataloader, name)
|
||||||
|
|
||||||
|
|
||||||
|
def _ceil_div(numerator: int, denominator: int) -> int:
|
||||||
|
return (numerator + denominator - 1) // denominator
|
||||||
|
|
||||||
|
|
||||||
class HyperParallelTrainer(CustomSeq2SeqTrainer):
|
class HyperParallelTrainer(CustomSeq2SeqTrainer):
|
||||||
"""Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard.
|
"""Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard.
|
||||||
|
|
||||||
@@ -73,15 +162,25 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
|
|||||||
if not getattr(self.accelerator, "is_fsdp2", False):
|
if not getattr(self.accelerator, "is_fsdp2", False):
|
||||||
raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.")
|
raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.")
|
||||||
|
|
||||||
# Prepare ref_model with HP's fsdp2_prepare_model
|
self._cp_size = hp_args.cp_size
|
||||||
|
self._cp_rank = get_cp_rank(hp_args) if self._cp_size > 1 else 0
|
||||||
|
self._dp_rank = get_dp_rank(hp_args) if self._cp_size > 1 else get_platform().get_rank()
|
||||||
|
|
||||||
|
# Prepare ref_model with the same CP + HSDP path as the train model.
|
||||||
self.ref_model = ref_model
|
self.ref_model = ref_model
|
||||||
if self.ref_model is not None:
|
if self.ref_model is not None:
|
||||||
self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model, self._hp_args)
|
self.ref_model = self._prepare_model_for_hyper_parallel(self.ref_model)
|
||||||
|
|
||||||
self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_
|
self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_
|
||||||
self._orig_fsdp2_prepare_model = None
|
self._orig_fsdp2_prepare_model = None
|
||||||
self._accelerator_patches_active = False
|
self._accelerator_patches_active = False
|
||||||
|
|
||||||
|
def _prepare_model_for_hyper_parallel(self, model: nn.Module) -> nn.Module:
    """Run the optional CP preparation, then hand the model to HyperParallel ``fsdp2_prepare_model``.

    ``cp_prepare_model`` is applied only when ``self._cp_size`` is greater than 1; its result (or the
    untouched model otherwise) is what ``fsdp2_prepare_model`` receives and whose return value is returned.
    """
    needs_cp = self._cp_size > 1
    accelerator, hp_args = self.accelerator, self._hp_args
    candidate = cp_prepare_model(model, accelerator, hp_args) if needs_cp else model
    return fsdp2_prepare_model(accelerator, candidate, hp_args)
|
||||||
|
|
||||||
def _activate_accelerator_patches(self) -> None:
|
def _activate_accelerator_patches(self) -> None:
|
||||||
"""Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_."""
|
"""Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_."""
|
||||||
if self._accelerator_patches_active:
|
if self._accelerator_patches_active:
|
||||||
@@ -89,12 +188,10 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
|
|||||||
|
|
||||||
import accelerate.accelerator as acc_module # pylint: disable=C0415
|
import accelerate.accelerator as acc_module # pylint: disable=C0415
|
||||||
|
|
||||||
hp_args = self._hp_args
|
|
||||||
|
|
||||||
self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model
|
self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model
|
||||||
|
|
||||||
def _hp_fsdp2_prepare_model(accelerator, model):
|
def _hp_fsdp2_prepare_model(accelerator, model):
|
||||||
return fsdp2_prepare_model(accelerator, model, hp_args)
|
return self._prepare_model_for_hyper_parallel(model)
|
||||||
|
|
||||||
acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model
|
acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model
|
||||||
|
|
||||||
@@ -135,6 +232,91 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
|
|||||||
return model
|
return model
|
||||||
return super()._wrap_model(model, training=training)
|
return super()._wrap_model(model, training=training)
|
||||||
|
|
||||||
|
def _get_train_sampler(self, train_dataset=None):
|
||||||
|
"""Match the no-CP baseline sampler semantics before CP repeats whole logical batches."""
|
||||||
|
if train_dataset is None:
|
||||||
|
train_dataset = self.train_dataset
|
||||||
|
if getattr(self.finetuning_args, "disable_shuffling", False):
|
||||||
|
return torch.utils.data.SequentialSampler(train_dataset)
|
||||||
|
return super()._get_train_sampler(train_dataset)
|
||||||
|
|
||||||
|
def _build_cp_batch_sampler(self, dataset, shuffle: bool, batch_size: int, drop_last: bool):
    """Build the batch sampler that yields each logical batch ``self._cp_size`` times.

    The index sampler comes from ``self._get_train_sampler`` when ``shuffle`` is true and is a plain
    ``SequentialSampler`` otherwise. The logical group size is the world size divided by the CP size,
    floored at 1.
    """
    if shuffle:
        index_sampler = self._get_train_sampler(dataset)
    else:
        index_sampler = torch.utils.data.SequentialSampler(dataset)

    cp_size = self._cp_size
    dp_size = max(1, get_platform().get_world_size() // cp_size)
    return _CPBatchRepeatedBatchSampler(
        index_sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        repeat_factor=cp_size,
        logical_group_size=dp_size,
    )
|
||||||
|
|
||||||
|
def _get_cp_dataloader(self, dataset, batch_size: int, shuffle: bool):
|
||||||
|
"""Create a train dataloader whose logical batches are shared within each CP group."""
|
||||||
|
if isinstance(dataset, torch.utils.data.IterableDataset):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"HyperParallel CP training requires a map-style dataset because iterable datasets cannot "
|
||||||
|
"repeat logical batches across CP ranks."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import datasets # pylint: disable=C0415
|
||||||
|
except ImportError: # pragma: no cover
|
||||||
|
datasets = None
|
||||||
|
|
||||||
|
if datasets is not None and isinstance(dataset, datasets.Dataset):
|
||||||
|
dataset = self._remove_unused_columns(dataset, description="Training")
|
||||||
|
data_collator = self.data_collator
|
||||||
|
else:
|
||||||
|
data_collator = self._get_collator_with_removed_columns(self.data_collator, description="Training")
|
||||||
|
|
||||||
|
batch_sampler = self._build_cp_batch_sampler(
|
||||||
|
dataset,
|
||||||
|
shuffle=shuffle,
|
||||||
|
batch_size=batch_size,
|
||||||
|
drop_last=self.args.dataloader_drop_last,
|
||||||
|
)
|
||||||
|
logical_batches = len(batch_sampler) // self._cp_size
|
||||||
|
dp_size = max(1, get_platform().get_world_size() // self._cp_size)
|
||||||
|
logical_length = logical_batches // dp_size if self.args.dataloader_drop_last else _ceil_div(logical_batches, dp_size)
|
||||||
|
|
||||||
|
dataloader_params = {
|
||||||
|
"batch_sampler": batch_sampler,
|
||||||
|
"collate_fn": data_collator,
|
||||||
|
"num_workers": self.args.dataloader_num_workers,
|
||||||
|
"pin_memory": self.args.dataloader_pin_memory,
|
||||||
|
"persistent_workers": self.args.dataloader_persistent_workers
|
||||||
|
if self.args.dataloader_num_workers > 0
|
||||||
|
else False,
|
||||||
|
}
|
||||||
|
if self.args.dataloader_num_workers > 0:
|
||||||
|
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||||
|
|
||||||
|
from transformers.trainer import seed_worker # pylint: disable=C0415
|
||||||
|
|
||||||
|
dataloader_params["worker_init_fn"] = partial(
|
||||||
|
seed_worker,
|
||||||
|
num_workers=self.args.dataloader_num_workers,
|
||||||
|
rank=self.args.process_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataloader = self.accelerator.prepare(torch.utils.data.DataLoader(dataset, **dataloader_params))
|
||||||
|
return _CPDataLoaderLengthProxy(dataloader, logical_length)
|
||||||
|
|
||||||
|
def get_train_dataloader(self):
    """Return the training dataloader, switching to the CP-specific one when ``self._cp_size > 1``.

    Without CP the parent implementation is used unchanged. With CP, shuffling is enabled unless
    ``finetuning_args.disable_shuffling`` is set, and the batch size is ``self._train_batch_size``.

    Raises:
        ValueError: If ``self.train_dataset`` is ``None``.
    """
    train_dataset = self.train_dataset
    if train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    if self._cp_size > 1:
        shuffling_disabled = getattr(self.finetuning_args, "disable_shuffling", False)
        return self._get_cp_dataloader(
            dataset=train_dataset,
            batch_size=self._train_batch_size,
            shuffle=not shuffling_disabled,
        )
    return super().get_train_dataloader()
|
||||||
|
|
||||||
def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None):
|
def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None):
|
||||||
"""Skip redundant device moves for HSDP-wrapped models."""
|
"""Skip redundant device moves for HSDP-wrapped models."""
|
||||||
if isinstance(model, HSDPModule):
|
if isinstance(model, HSDPModule):
|
||||||
@@ -157,10 +339,13 @@ class HyperParallelTrainer(CustomSeq2SeqTrainer):
|
|||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
num_items_in_batch: Optional[int] = None,
|
num_items_in_batch: Optional[int] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Standard training step with HSDP gradient synchronization."""
|
"""Standard training step with HSDP sync plus optional CP input sharding."""
|
||||||
model.train()
|
model.train()
|
||||||
inputs = self._prepare_inputs(inputs)
|
inputs = self._prepare_inputs(inputs)
|
||||||
|
|
||||||
|
if self._cp_size > 1:
|
||||||
|
inputs = shard_inputs_for_cp(inputs, self._cp_rank, self._cp_size)
|
||||||
|
|
||||||
sync_gradients = getattr(self.accelerator, "sync_gradients", True)
|
sync_gradients = getattr(self.accelerator, "sync_gradients", True)
|
||||||
if isinstance(model, HSDPModule):
|
if isinstance(model, HSDPModule):
|
||||||
model.set_is_last_backward(sync_gradients)
|
model.set_is_last_backward(sync_gradients)
|
||||||
|
|||||||
@@ -50,6 +50,10 @@ def _prepare_hp_args(finetuning_args: "FinetuningArguments", model_args: "ModelA
|
|||||||
from hyper_parallel.integration.llamafactory import HyperParallelArguments # pylint: disable=C0415
|
from hyper_parallel.integration.llamafactory import HyperParallelArguments # pylint: disable=C0415
|
||||||
|
|
||||||
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
|
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
|
||||||
|
|
||||||
|
if getattr(hp_args, "cp_size", None) != finetuning_args.hyper_parallel_cp_size:
|
||||||
|
setattr(hp_args, "cp_size", finetuning_args.hyper_parallel_cp_size)
|
||||||
|
|
||||||
if hp_args.activation_mode != "none":
|
if hp_args.activation_mode != "none":
|
||||||
model_args.disable_gradient_checkpointing = True
|
model_args.disable_gradient_checkpointing = True
|
||||||
return hp_args
|
return hp_args
|
||||||
|
|||||||
Reference in New Issue
Block a user