[feat] support HyperParallel PT training and activation optimization (#10370)

This commit is contained in:
Cui-yshoho
2026-06-02 22:39:32 +08:00
committed by GitHub
parent a98a1ef101
commit 053d43c0ac
5 changed files with 326 additions and 44 deletions

View File

@@ -487,7 +487,7 @@ class FinetuningArguments(
metadata={
"help": (
"Whether or not to use HyperParallel distributed training backend (FSDP/TP). "
"Only supported for the 'sft' stage with full fine-tuning."
"Only supported for the 'pt' and 'sft' stages with full fine-tuning."
)
},
)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .workflow import run_sft
from .workflow import run_pt, run_sft
__all__ = ["run_sft"]
__all__ = ["run_pt", "run_sft"]

View File

@@ -0,0 +1,222 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HyperParallel distributed trainer for LlamaFactory."""
import logging
import os
import types
from contextlib import nullcontext
from typing import Any, Optional
import torch
from hyper_parallel.integration.llamafactory import (
HSDPModule,
HyperParallelArguments,
export_to_hf_format,
fsdp2_prepare_model,
hsdp_sync_stream,
load_hsdp_model,
load_hsdp_optimizer_and_scheduler,
save_hsdp_checkpoint,
wrap_optimizer_with_skip_dtensor_dispatch,
)
from hyper_parallel.integration.llamafactory import (
clip_grad_norm_ as hp_clip_grad_norm_,
)
from torch import nn
from ..sft.trainer import CustomSeq2SeqTrainer
logger = logging.getLogger(__name__)
class HyperParallelTrainer(CustomSeq2SeqTrainer):
    """Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard.

    Inherits CustomSeq2SeqTrainer for training algorithm logic (loss, metrics,
    prediction, sampler, etc.) and only overrides HSDP-specific behavior.
    """

    def __init__(
        self,
        hp_args: HyperParallelArguments,
        finetuning_args=None,
        processor=None,
        ref_model: Optional[nn.Module] = None,
        **kwargs,
    ):
        """Initialize the trainer and prepare the optional reference model.

        Args:
            hp_args: HyperParallel configuration; stored and forwarded to every
                ``fsdp2_prepare_model`` call made by this trainer.
            finetuning_args: Forwarded unchanged to ``CustomSeq2SeqTrainer``.
            processor: Forwarded unchanged to ``CustomSeq2SeqTrainer``.
            ref_model: Optional reference model. It is withheld from the parent
                constructor and prepared here with HyperParallel instead.
            **kwargs: Remaining keyword arguments for ``CustomSeq2SeqTrainer``.

        Raises:
            ValueError: If the accelerator does not report ``is_fsdp2``.
        """
        self._hp_args = hp_args

        # Let CustomSeq2SeqTrainer handle everything except ref_model —
        # Custom would prepare it with accelerate's fsdp2_prepare_model,
        # but we need HP's version instead.
        super().__init__(
            finetuning_args=finetuning_args,
            processor=processor,
            ref_model=None,
            **kwargs,
        )

        if not getattr(self.accelerator, "is_fsdp2", False):
            raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.")

        # Prepare ref_model with HP's fsdp2_prepare_model
        self.ref_model = ref_model
        if self.ref_model is not None:
            self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model, self._hp_args)

        # Bookkeeping used by _activate_accelerator_patches / _restore_accelerator_patches.
        self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_
        self._orig_fsdp2_prepare_model = None
        self._accelerator_patches_active = False

    def _activate_accelerator_patches(self) -> None:
        """Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_.

        Replaces the module-level ``accelerate.accelerator.fsdp2_prepare_model``
        and the ``clip_grad_norm_`` attribute of this trainer's accelerator.
        Does nothing when the patches are already active.
        """
        if self._accelerator_patches_active:
            return

        import accelerate.accelerator as acc_module  # pylint: disable=C0415

        hp_args = self._hp_args
        self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model

        def _hp_fsdp2_prepare_model(accelerator, model):
            return fsdp2_prepare_model(accelerator, model, hp_args)

        acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model

        def _hp_clip_grad_norm(accelerator, parameters, max_norm, norm_type=2):
            # ``parameters`` may be a one-shot iterator (e.g. ``model.parameters()``).
            # Materialize it exactly once so that the fallback below does not
            # receive an already exhausted iterator and silently clip nothing.
            if isinstance(parameters, torch.Tensor):
                parameter_list = [parameters]
            else:
                parameter_list = list(parameters)

            if parameter_list and getattr(accelerator, "is_fsdp2", False):
                parameter_ids = {id(param) for param in parameter_list}
                for model in accelerator._models:  # pylint: disable=protected-access
                    if not isinstance(model, HSDPModule):
                        continue
                    model_param_ids = {id(param) for param in model.parameters()}
                    if parameter_ids.issubset(model_param_ids):
                        # Unscale only on the HyperParallel path: the original
                        # clip_grad_norm_ unscales on its own, and unscaling a
                        # second time fails when an fp16 grad scaler is active.
                        accelerator.unscale_gradients()
                        return hp_clip_grad_norm_(parameter_list, max_norm, norm_type=norm_type)

            return self._orig_accelerator_clip_grad_norm(parameter_list, max_norm, norm_type=norm_type)

        self.accelerator.clip_grad_norm_ = types.MethodType(_hp_clip_grad_norm, self.accelerator)
        self._accelerator_patches_active = True

    def _restore_accelerator_patches(self) -> None:
        """Restore original Accelerate methods.

        Does nothing when the patches are not currently active.
        """
        if not self._accelerator_patches_active:
            return

        import accelerate.accelerator as acc_module  # pylint: disable=C0415

        if self._orig_fsdp2_prepare_model is not None:
            acc_module.fsdp2_prepare_model = self._orig_fsdp2_prepare_model
        self.accelerator.clip_grad_norm_ = self._orig_accelerator_clip_grad_norm
        self._accelerator_patches_active = False

    def _wrap_model(self, model: nn.Module, training: bool = True, dataloader=None) -> nn.Module:
        """Let Accelerate own FSDP2/HSDP wrapping so optimizer remapping stays correct.

        Returns ``model`` untouched when it is already an ``HSDPModule`` or when
        training under FSDP2; otherwise defers to the parent implementation.
        The ``dataloader`` argument is accepted for signature compatibility and
        is not used.
        """
        del dataloader
        if isinstance(model, HSDPModule):
            return model
        if training and getattr(self.accelerator, "is_fsdp2", False):
            return model
        return super()._wrap_model(model, training=training)

    def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None):
        """Skip redundant device moves for HSDP-wrapped models.

        Returns ``model`` unchanged when it is an ``HSDPModule`` or when no
        device is given; otherwise returns ``model.to(device)``.
        """
        if isinstance(model, HSDPModule):
            return model
        if device is None:
            return model
        return model.to(device)

    def train(self, *args, **kwargs):
        """Activate HP patches during training and restore afterwards.

        The patches are restored in a ``finally`` clause, so they are removed
        even when the parent ``train`` raises.
        """
        self._activate_accelerator_patches()
        try:
            return super().train(*args, **kwargs)
        finally:
            self._restore_accelerator_patches()

    def training_step(
        self,
        model: nn.Module,
        inputs: dict[str, Any],
        num_items_in_batch: Optional[int] = None,
    ) -> torch.Tensor:
        """Standard training step with HSDP gradient synchronization.

        Args:
            model: Model to train; HSDP-specific hooks are only invoked when it
                is an ``HSDPModule``.
            inputs: Batch passed through ``_prepare_inputs`` and ``compute_loss``.
            num_items_in_batch: Forwarded to ``compute_loss``.

        Returns:
            The detached loss of this step.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        # Tell the HSDP module whether this backward pass ends an accumulation
        # window, as reported by the accelerator.
        sync_gradients = getattr(self.accelerator, "sync_gradients", True)
        if isinstance(model, HSDPModule):
            model.set_is_last_backward(sync_gradients)
            model.set_requires_gradient_sync(sync_gradients)

        compute_loss_context_manager = getattr(self, "compute_loss_context_manager", nullcontext)
        with compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

        if self.args.n_gpu > 1:
            loss = loss.mean()

        # Scale by the accumulation steps only when neither the model nor a
        # custom loss function is flagged as handling loss kwargs itself.
        accepts_loss_kwargs = getattr(self, "model_accepts_loss_kwargs", False)
        if not accepts_loss_kwargs and getattr(self, "compute_loss_func", None) is None:
            loss = loss / self.args.gradient_accumulation_steps

        self.accelerator.backward(loss)

        if isinstance(model, HSDPModule) and sync_gradients:
            hsdp_sync_stream()

        return loss.detach()

    def create_optimizer(self):
        """Create optimizer and wrap step with SkipDTensorDispatch."""
        optimizer = super().create_optimizer()
        wrap_optimizer_with_skip_dtensor_dispatch(optimizer)
        return optimizer

    def _save_optimizer_and_scheduler(self, output_dir: str) -> None:
        """Save model/optimizer shards per-rank and scheduler.

        The scheduler is only requested to be saved when ``args.should_save``
        is set and a scheduler exists.
        """
        save_hsdp_checkpoint(
            model=self.model,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            output_dir=output_dir,
            should_save_scheduler=self.args.should_save and self.lr_scheduler is not None,
        )

    def _load_from_checkpoint(self, resume_from_checkpoint: str, model: Optional[nn.Module] = None) -> None:
        """Load model from HSDP sharded checkpoint.

        Falls back to the parent loader when ``load_hsdp_model`` reports that
        nothing was loaded. On success the checkpoint directory is remembered
        so ``_load_optimizer_and_scheduler`` can read from the same place.
        """
        target = model if model is not None else self.model
        loaded = load_hsdp_model(target, resume_from_checkpoint)
        if not loaded:
            return super()._load_from_checkpoint(resume_from_checkpoint, model=model)
        self._pending_hsdp_checkpoint = resume_from_checkpoint
        return None

    def _load_optimizer_and_scheduler(self, checkpoint: Optional[str] = None) -> None:
        """Load optimizer/scheduler from per-rank checkpoint files.

        Prefers the directory recorded by ``_load_from_checkpoint`` and falls
        back to ``checkpoint``; does nothing when neither is available.
        """
        ckpt_dir = getattr(self, "_pending_hsdp_checkpoint", None) or checkpoint
        if ckpt_dir is None:
            return
        load_hsdp_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, ckpt_dir)

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """Save model weights in HuggingFace-compatible format.

        Writes to ``output_dir`` or, when it is not given, ``args.output_dir``.
        ``_internal_call`` is accepted for signature compatibility and is not
        used here.
        """
        save_dir = output_dir or self.args.output_dir
        os.makedirs(save_dir, exist_ok=True)
        export_to_hf_format(self.model, getattr(self, "processing_class", None), save_dir)

View File

@@ -12,8 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Optional
from transformers import DataCollatorForLanguageModeling
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
@@ -21,9 +24,9 @@ from ...extras.misc import calculate_tps
from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import SaveProcessorCallback
from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from ..trainer_utils import asft_loss_func, create_modelcard_and_push, create_ref_model, dft_loss_func, eaft_loss_func
from ..trainer_utils import create_modelcard_and_push, create_ref_model
from .trainer import HyperParallelTrainer
if TYPE_CHECKING:
@@ -35,6 +38,90 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def _prepare_hp_args(finetuning_args: "FinetuningArguments", model_args: "ModelArguments"):
    r"""Build the HyperParallel arguments and adjust ``model_args`` to match.

    If HyperParallel's activation optimization is switched on, native gradient
    checkpointing is disabled on ``model_args`` (mutated in place) so that HP
    can install its own via ``setup_activation_optimization``.

    Raises:
        ImportError: If the ``hyper_parallel`` package is not available.
    """
    if not is_hyper_parallel_available():
        raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")

    # Imported lazily: the package is optional and was only just verified above.
    from hyper_parallel.integration.llamafactory import HyperParallelArguments  # pylint: disable=C0415

    parsed_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
    uses_activation_optimization = parsed_args.activation_mode != "none"
    if uses_activation_optimization:
        model_args.disable_gradient_checkpointing = True

    return parsed_args
def run_pt(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    r"""Run the pre-training (``pt``) stage with the HyperParallel trainer.

    Loads the tokenizer, dataset and model, then trains and/or evaluates
    depending on ``training_args.do_train`` / ``training_args.do_eval``.
    Evaluation metrics are extended with a perplexity entry per eval loss.
    Finally the model card is created (and pushed, if configured) through
    ``create_modelcard_and_push``.
    """

    def _loss_to_perplexity(loss_value):
        # exp() of a very large loss overflows; report infinity in that case.
        try:
            return math.exp(loss_value)
        except OverflowError:
            return float("inf")

    hp_args = _prepare_hp_args(finetuning_args, model_args)

    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    # The eval split is either a single dataset or a dict of named datasets;
    # the metric key names below depend on which one it is.
    eval_dataset = dataset_module.get("eval_dataset")
    has_named_eval_sets = isinstance(eval_dataset, dict)

    trainer = HyperParallelTrainer(
        hp_args=hp_args,
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
        callbacks=callbacks,
        **dataset_module,
        **tokenizer_module,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            curve_keys = ["loss"]
            if has_named_eval_sets:
                curve_keys.extend(f"eval_{name}_loss" for name in eval_dataset)
            else:
                curve_keys.append("eval_loss")
            plot_loss(training_args.output_dir, keys=curve_keys)

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        if has_named_eval_sets:
            for name in eval_dataset:
                metrics[f"eval_{name}_perplexity"] = _loss_to_perplexity(metrics[f"eval_{name}_loss"])
        else:
            metrics["eval_perplexity"] = _loss_to_perplexity(metrics["eval_loss"])

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
def run_sft(
model_args: "ModelArguments",
data_args: "DataArguments",
@@ -43,13 +130,7 @@ def run_sft(
generating_args: "GeneratingArguments",
callbacks: Optional[list["TrainerCallback"]] = None,
):
if not is_hyper_parallel_available():
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
from hyper_parallel.integration.llamafactory import ( # pylint: disable=C0415
HyperParallelArguments,
HyperParallelTrainer,
)
hp_args = _prepare_hp_args(finetuning_args, model_args)
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
@@ -94,25 +175,6 @@ def run_sft(
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
callbacks = list(callbacks or [])
processor = tokenizer_module.get("processor")
if processor is not None:
callbacks.append(SaveProcessorCallback(processor))
compute_loss_func = None
if finetuning_args.use_dft_loss:
compute_loss_func = dft_loss_func
elif finetuning_args.use_eaft_loss:
compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func( # noqa: E731
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
)
elif finetuning_args.use_asft_loss:
from functools import partial
compute_loss_func = partial(asft_loss_func, asft_alpha=finetuning_args.asft_alpha)
trainer = HyperParallelTrainer(
hp_args=hp_args,
model=model,
@@ -122,20 +184,11 @@ def run_sft(
callbacks=callbacks,
gen_kwargs=gen_kwargs,
ref_model=ref_model,
compute_loss_func=compute_loss_func,
**dataset_module,
**tokenizer_module,
**metric_module,
)
if finetuning_args.use_badam:
from types import MethodType
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import]
trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
trainer.add_callback(BAdamCallback)
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)

View File

@@ -88,12 +88,19 @@ def _training_function(config: dict[str, Any]) -> None:
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
if finetuning_args.stage in ["pt", "sft"] and finetuning_args.use_hyper_parallel:
if not is_hyper_parallel_available():
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
from .hyper_parallel import run_sft as run_sft_hp
raise ImportError(
"hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
)
if finetuning_args.stage == "pt":
from .hyper_parallel import run_pt as run_pt_hp
run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
run_pt_hp(model_args, data_args, training_args, finetuning_args, callbacks)
else:
from .hyper_parallel import run_sft as run_sft_hp
run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
elif finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
if not is_mcore_adapter_available():