mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 04:38:53 +08:00
[feat] support HyperParallel PT training and activation optimization (#10370)
This commit is contained in:
@@ -487,7 +487,7 @@ class FinetuningArguments(
|
||||
metadata={
|
||||
"help": (
|
||||
"Whether or not to use HyperParallel distributed training backend (FSDP/TP). "
|
||||
"Only supported for the 'sft' stage with full fine-tuning."
|
||||
"Only supported for the 'pt' and 'sft' stages with full fine-tuning."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .workflow import run_sft
|
||||
from .workflow import run_pt, run_sft
|
||||
|
||||
|
||||
__all__ = ["run_sft"]
|
||||
__all__ = ["run_pt", "run_sft"]
|
||||
|
||||
222
src/llamafactory/train/hyper_parallel/trainer.py
Normal file
222
src/llamafactory/train/hyper_parallel/trainer.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""HyperParallel distributed trainer for LlamaFactory."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import types
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from hyper_parallel.integration.llamafactory import (
|
||||
HSDPModule,
|
||||
HyperParallelArguments,
|
||||
export_to_hf_format,
|
||||
fsdp2_prepare_model,
|
||||
hsdp_sync_stream,
|
||||
load_hsdp_model,
|
||||
load_hsdp_optimizer_and_scheduler,
|
||||
save_hsdp_checkpoint,
|
||||
wrap_optimizer_with_skip_dtensor_dispatch,
|
||||
)
|
||||
from hyper_parallel.integration.llamafactory import (
|
||||
clip_grad_norm_ as hp_clip_grad_norm_,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from ..sft.trainer import CustomSeq2SeqTrainer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HyperParallelTrainer(CustomSeq2SeqTrainer):
    """Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard.

    Inherits CustomSeq2SeqTrainer for training algorithm logic (loss, metrics,
    prediction, sampler, etc.) and only overrides HSDP-specific behavior.
    """

    def __init__(
        self,
        hp_args: HyperParallelArguments,
        finetuning_args=None,
        processor=None,
        ref_model: Optional[nn.Module] = None,
        **kwargs,
    ):
        """Build the trainer and prepare the optional reference model with HyperParallel.

        Args:
            hp_args: HyperParallel configuration, forwarded to every
                ``fsdp2_prepare_model`` call made by this trainer.
            finetuning_args: Forwarded unchanged to ``CustomSeq2SeqTrainer``.
            processor: Forwarded unchanged to ``CustomSeq2SeqTrainer``.
            ref_model: Optional reference model; prepared here instead of in the parent.
            **kwargs: Remaining keyword arguments for ``CustomSeq2SeqTrainer``.

        Raises:
            ValueError: If the accelerator does not report ``is_fsdp2``.
        """
        self._hp_args = hp_args

        # Let CustomSeq2SeqTrainer handle everything except ref_model —
        # Custom would prepare it with accelerate's fsdp2_prepare_model,
        # but we need HP's version instead.
        super().__init__(
            finetuning_args=finetuning_args,
            processor=processor,
            ref_model=None,
            **kwargs,
        )

        if not getattr(self.accelerator, "is_fsdp2", False):
            raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.")

        # Prepare ref_model with HP's fsdp2_prepare_model
        self.ref_model = ref_model
        if self.ref_model is not None:
            self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model, self._hp_args)

        # Saved so that _restore_accelerator_patches can undo the monkey patches.
        self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_
        self._orig_fsdp2_prepare_model = None
        self._accelerator_patches_active = False

    def _activate_accelerator_patches(self) -> None:
        """Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_.

        Replaces the module-level ``accelerate.accelerator.fsdp2_prepare_model`` and
        this trainer's ``accelerator.clip_grad_norm_``. Does nothing when the patches
        are already active.
        """
        if self._accelerator_patches_active:
            return

        import accelerate.accelerator as acc_module  # pylint: disable=C0415

        hp_args = self._hp_args

        self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model

        def _hp_fsdp2_prepare_model(accelerator, model):
            return fsdp2_prepare_model(accelerator, model, hp_args)

        acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model

        def _hp_clip_grad_norm(accelerator, parameters, max_norm, norm_type=2):
            if getattr(accelerator, "is_fsdp2", False):
                accelerator.unscale_gradients()
                # ``parameters`` may be a one-shot iterator (e.g. ``model.parameters()``).
                # Materialize it once and rebind the name so the fallback call below
                # still receives every parameter instead of an exhausted iterator.
                parameters = list(parameters)
                parameter_ids = {id(param) for param in parameters}
                for model in accelerator._models:  # pylint: disable=protected-access
                    if not isinstance(model, HSDPModule):
                        continue
                    model_param_ids = {id(param) for param in model.parameters()}
                    # Use HP clipping only when every given parameter belongs to this HSDP model.
                    if parameter_ids and parameter_ids.issubset(model_param_ids):
                        return hp_clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
            # NOTE(review): on the FSDP2 path gradients were already unscaled above;
            # confirm the original implementation tolerates being called afterwards.
            return self._orig_accelerator_clip_grad_norm(parameters, max_norm, norm_type=norm_type)

        self.accelerator.clip_grad_norm_ = types.MethodType(_hp_clip_grad_norm, self.accelerator)
        self._accelerator_patches_active = True

    def _restore_accelerator_patches(self) -> None:
        """Restore original Accelerate methods.

        Does nothing when the patches are not currently active.
        """
        if not self._accelerator_patches_active:
            return

        import accelerate.accelerator as acc_module  # pylint: disable=C0415

        if self._orig_fsdp2_prepare_model is not None:
            acc_module.fsdp2_prepare_model = self._orig_fsdp2_prepare_model
        self.accelerator.clip_grad_norm_ = self._orig_accelerator_clip_grad_norm
        self._accelerator_patches_active = False

    def _wrap_model(self, model: nn.Module, training: bool = True, dataloader=None) -> nn.Module:
        """Let Accelerate own FSDP2/HSDP wrapping so optimizer remapping stays correct.

        Returns ``model`` untouched when it is already an ``HSDPModule`` or when
        training under FSDP2; otherwise defers to the parent implementation.
        ``dataloader`` is accepted for signature compatibility and ignored.
        """
        del dataloader
        if isinstance(model, HSDPModule):
            return model
        if training and getattr(self.accelerator, "is_fsdp2", False):
            return model
        return super()._wrap_model(model, training=training)

    def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None):
        """Skip redundant device moves for HSDP-wrapped models.

        Returns ``model`` unchanged when it is an ``HSDPModule`` or when ``device``
        is ``None``; otherwise returns ``model.to(device)``.
        """
        if isinstance(model, HSDPModule):
            return model
        if device is None:
            return model
        return model.to(device)

    def train(self, *args, **kwargs):
        """Activate HP patches during training and restore afterwards."""
        self._activate_accelerator_patches()
        try:
            return super().train(*args, **kwargs)
        finally:
            # Always undo the module-level patch, even when training raises.
            self._restore_accelerator_patches()

    def training_step(
        self,
        model: nn.Module,
        inputs: dict[str, Any],
        num_items_in_batch: Optional[int] = None,
    ) -> torch.Tensor:
        """Standard training step with HSDP gradient synchronization.

        Args:
            model: Model to train; HSDP-specific hooks apply only to ``HSDPModule``.
            inputs: Batch passed through ``_prepare_inputs`` and then ``compute_loss``.
            num_items_in_batch: Forwarded to ``compute_loss``.

        Returns:
            The detached loss that was back-propagated.
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        # Mirror Accelerate's gradient-accumulation state onto the HSDP module.
        sync_gradients = getattr(self.accelerator, "sync_gradients", True)
        if isinstance(model, HSDPModule):
            model.set_is_last_backward(sync_gradients)
            model.set_requires_gradient_sync(sync_gradients)

        compute_loss_context_manager = getattr(self, "compute_loss_context_manager", nullcontext)
        with compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

        if self.args.n_gpu > 1:
            loss = loss.mean()

        # Scale by the accumulation steps only when neither the model nor a custom
        # loss function is flagged as handling loss kwargs itself.
        if not getattr(self, "model_accepts_loss_kwargs", False) and getattr(self, "compute_loss_func", None) is None:
            loss = loss / self.args.gradient_accumulation_steps

        self.accelerator.backward(loss)

        if isinstance(model, HSDPModule) and sync_gradients:
            hsdp_sync_stream()

        return loss.detach()

    def create_optimizer(self):
        """Create optimizer and wrap step with SkipDTensorDispatch."""
        optimizer = super().create_optimizer()
        wrap_optimizer_with_skip_dtensor_dispatch(optimizer)
        return optimizer

    def _save_optimizer_and_scheduler(self, output_dir: str) -> None:
        """Save model/optimizer shards per-rank and scheduler.

        The scheduler is requested only when ``self.args.should_save`` is true and
        a scheduler exists.
        """
        save_hsdp_checkpoint(
            model=self.model,
            optimizer=self.optimizer,
            lr_scheduler=self.lr_scheduler,
            output_dir=output_dir,
            should_save_scheduler=self.args.should_save and self.lr_scheduler is not None,
        )

    def _load_from_checkpoint(self, resume_from_checkpoint: str, model: Optional[nn.Module] = None) -> None:
        """Load model from HSDP sharded checkpoint.

        Falls back to the parent loader when ``load_hsdp_model`` reports that nothing
        was loaded. On success the checkpoint path is remembered for the later
        optimizer/scheduler load.
        """
        target = model if model is not None else self.model
        loaded = load_hsdp_model(target, resume_from_checkpoint)
        if not loaded:
            # Drop any path remembered by an earlier resume so the optimizer load
            # uses the checkpoint it is explicitly given rather than a stale one.
            self._pending_hsdp_checkpoint = None
            return super()._load_from_checkpoint(resume_from_checkpoint, model=model)
        self._pending_hsdp_checkpoint = resume_from_checkpoint
        return None

    def _load_optimizer_and_scheduler(self, checkpoint: Optional[str] = None) -> None:
        """Load optimizer/scheduler from per-rank checkpoint files.

        Prefers the path recorded by ``_load_from_checkpoint`` and falls back to
        ``checkpoint``; does nothing when neither is available.
        """
        ckpt_dir = getattr(self, "_pending_hsdp_checkpoint", None) or checkpoint
        if ckpt_dir is None:
            return
        load_hsdp_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, ckpt_dir)

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """Save model weights in HuggingFace-compatible format.

        Args:
            output_dir: Target directory; defaults to ``self.args.output_dir``.
            _internal_call: Accepted for signature compatibility; unused here.
        """
        save_dir = output_dir or self.args.output_dir
        os.makedirs(save_dir, exist_ok=True)
        export_to_hf_format(self.model, getattr(self, "processing_class", None), save_dir)
|
||||
@@ -12,8 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
|
||||
from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
@@ -21,9 +24,9 @@ from ...extras.misc import calculate_tps
|
||||
from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..callbacks import SaveProcessorCallback
|
||||
from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
|
||||
from ..trainer_utils import asft_loss_func, create_modelcard_and_push, create_ref_model, dft_loss_func, eaft_loss_func
|
||||
from ..trainer_utils import create_modelcard_and_push, create_ref_model
|
||||
from .trainer import HyperParallelTrainer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -35,6 +38,90 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _prepare_hp_args(finetuning_args: "FinetuningArguments", model_args: "ModelArguments"):
    r"""Build HyperParallel arguments and apply LlamaFactory-side overrides.

    If the resulting ``activation_mode`` is anything other than ``"none"``,
    ``model_args.disable_gradient_checkpointing`` is set to ``True`` in place so
    native gradient checkpointing is skipped and HP can install its own via
    ``setup_activation_optimization``.

    Raises:
        ImportError: If ``hyper_parallel`` is not available.
    """
    if is_hyper_parallel_available():
        # Imported lazily so this module can be loaded without hyper_parallel installed.
        from hyper_parallel.integration.llamafactory import HyperParallelArguments  # pylint: disable=C0415
    else:
        raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")

    hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
    uses_hp_activation = hp_args.activation_mode != "none"
    if uses_hp_activation:
        model_args.disable_gradient_checkpointing = True

    return hp_args
|
||||
|
||||
|
||||
def run_pt(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    r"""Run the pre-training ("pt") workflow with the HyperParallel trainer.

    Loads the tokenizer, dataset and model, trains and/or evaluates according to
    ``training_args.do_train`` / ``training_args.do_eval``, adds perplexity to the
    evaluation metrics, and finally creates the model card.
    """
    hp_args = _prepare_hp_args(finetuning_args, model_args)

    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    trainer = HyperParallelTrainer(
        hp_args=hp_args,
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
        **tokenizer_module,
    )

    # One metric prefix per evaluation split: "eval_<name>" for a dict of
    # datasets, plain "eval" otherwise.
    eval_dataset = dataset_module.get("eval_dataset")
    if isinstance(eval_dataset, dict):
        eval_prefixes = [f"eval_{name}" for name in eval_dataset]
    else:
        eval_prefixes = ["eval"]

    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            loss_keys = ["loss"] + [f"{prefix}_loss" for prefix in eval_prefixes]
            plot_loss(training_args.output_dir, keys=loss_keys)

    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")

        for prefix in eval_prefixes:
            try:
                perplexity = math.exp(metrics[f"{prefix}_loss"])
            except OverflowError:
                # A loss too large for exp() is reported as infinite perplexity.
                perplexity = float("inf")

            metrics[f"{prefix}_perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
||||
|
||||
|
||||
def run_sft(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
@@ -43,13 +130,7 @@ def run_sft(
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[list["TrainerCallback"]] = None,
|
||||
):
|
||||
if not is_hyper_parallel_available():
|
||||
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
|
||||
|
||||
from hyper_parallel.integration.llamafactory import ( # pylint: disable=C0415
|
||||
HyperParallelArguments,
|
||||
HyperParallelTrainer,
|
||||
)
|
||||
hp_args = _prepare_hp_args(finetuning_args, model_args)
|
||||
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
@@ -94,25 +175,6 @@ def run_sft(
|
||||
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
|
||||
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||
|
||||
hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
|
||||
|
||||
callbacks = list(callbacks or [])
|
||||
processor = tokenizer_module.get("processor")
|
||||
if processor is not None:
|
||||
callbacks.append(SaveProcessorCallback(processor))
|
||||
|
||||
compute_loss_func = None
|
||||
if finetuning_args.use_dft_loss:
|
||||
compute_loss_func = dft_loss_func
|
||||
elif finetuning_args.use_eaft_loss:
|
||||
compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func( # noqa: E731
|
||||
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
||||
)
|
||||
elif finetuning_args.use_asft_loss:
|
||||
from functools import partial
|
||||
|
||||
compute_loss_func = partial(asft_loss_func, asft_alpha=finetuning_args.asft_alpha)
|
||||
|
||||
trainer = HyperParallelTrainer(
|
||||
hp_args=hp_args,
|
||||
model=model,
|
||||
@@ -122,20 +184,11 @@ def run_sft(
|
||||
callbacks=callbacks,
|
||||
gen_kwargs=gen_kwargs,
|
||||
ref_model=ref_model,
|
||||
compute_loss_func=compute_loss_func,
|
||||
**dataset_module,
|
||||
**tokenizer_module,
|
||||
**metric_module,
|
||||
)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from types import MethodType
|
||||
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import]
|
||||
|
||||
trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
|
||||
trainer.add_callback(BAdamCallback)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
|
||||
@@ -88,12 +88,19 @@ def _training_function(config: dict[str, Any]) -> None:
|
||||
|
||||
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
|
||||
|
||||
if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
|
||||
if finetuning_args.stage in ["pt", "sft"] and finetuning_args.use_hyper_parallel:
|
||||
if not is_hyper_parallel_available():
|
||||
raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
|
||||
from .hyper_parallel import run_sft as run_sft_hp
|
||||
raise ImportError(
|
||||
"hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
|
||||
)
|
||||
if finetuning_args.stage == "pt":
|
||||
from .hyper_parallel import run_pt as run_pt_hp
|
||||
|
||||
run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
run_pt_hp(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
else:
|
||||
from .hyper_parallel import run_sft as run_sft_hp
|
||||
|
||||
run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
|
||||
elif finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
|
||||
if not is_mcore_adapter_available():
|
||||
|
||||
Reference in New Issue
Block a user