diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 2a1ecc943..b2547c510 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -487,7 +487,7 @@ class FinetuningArguments(
         metadata={
             "help": (
                 "Whether or not to use HyperParallel distributed training backend (FSDP/TP). "
-                "Only supported for the 'sft' stage with full fine-tuning."
+                "Only supported for the 'pt' and 'sft' stages with full fine-tuning."
             )
         },
     )
diff --git a/src/llamafactory/train/hyper_parallel/__init__.py b/src/llamafactory/train/hyper_parallel/__init__.py
index 6107a9ae7..88bbf20dd 100644
--- a/src/llamafactory/train/hyper_parallel/__init__.py
+++ b/src/llamafactory/train/hyper_parallel/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .workflow import run_sft
+from .workflow import run_pt, run_sft
 
 
-__all__ = ["run_sft"]
+__all__ = ["run_pt", "run_sft"]
diff --git a/src/llamafactory/train/hyper_parallel/trainer.py b/src/llamafactory/train/hyper_parallel/trainer.py
new file mode 100644
index 000000000..8ca131143
--- /dev/null
+++ b/src/llamafactory/train/hyper_parallel/trainer.py
@@ -0,0 +1,225 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""HyperParallel distributed trainer for LlamaFactory."""
+
+import logging
+import os
+import types
+from contextlib import nullcontext
+from typing import Any, Optional
+
+import torch
+from hyper_parallel.integration.llamafactory import (
+    HSDPModule,
+    HyperParallelArguments,
+    export_to_hf_format,
+    fsdp2_prepare_model,
+    hsdp_sync_stream,
+    load_hsdp_model,
+    load_hsdp_optimizer_and_scheduler,
+    save_hsdp_checkpoint,
+    wrap_optimizer_with_skip_dtensor_dispatch,
+)
+from hyper_parallel.integration.llamafactory import (
+    clip_grad_norm_ as hp_clip_grad_norm_,
+)
+from torch import nn
+
+from ..sft.trainer import CustomSeq2SeqTrainer
+
+
+logger = logging.getLogger(__name__)
+
+
+class HyperParallelTrainer(CustomSeq2SeqTrainer):
+    """Trainer that replaces Accelerate FSDP2 with HyperParallel fully_shard.
+
+    Inherits CustomSeq2SeqTrainer for training algorithm logic (loss, metrics,
+    prediction, sampler, etc.) and only overrides HSDP-specific behavior.
+    """
+
+    def __init__(
+        self,
+        hp_args: HyperParallelArguments,
+        finetuning_args=None,
+        processor=None,
+        ref_model: Optional[nn.Module] = None,
+        **kwargs,
+    ):
+        self._hp_args = hp_args
+
+        # Let CustomSeq2SeqTrainer handle everything except ref_model —
+        # Custom would prepare it with accelerate's fsdp2_prepare_model,
+        # but we need HP's version instead.
+        super().__init__(
+            finetuning_args=finetuning_args,
+            processor=processor,
+            ref_model=None,
+            **kwargs,
+        )
+
+        if not getattr(self.accelerator, "is_fsdp2", False):
+            raise ValueError("HyperParallel trainer requires Accelerate FSDP2 mode to be enabled.")
+
+        # Prepare ref_model with HP's fsdp2_prepare_model
+        self.ref_model = ref_model
+        if self.ref_model is not None:
+            self.ref_model = fsdp2_prepare_model(self.accelerator, self.ref_model, self._hp_args)
+
+        self._orig_accelerator_clip_grad_norm = self.accelerator.clip_grad_norm_
+        self._orig_fsdp2_prepare_model = None
+        self._accelerator_patches_active = False
+
+    def _activate_accelerator_patches(self) -> None:
+        """Patch Accelerate to use HyperParallel fsdp2_prepare_model and clip_grad_norm_."""
+        if self._accelerator_patches_active:
+            return
+
+        import accelerate.accelerator as acc_module  # pylint: disable=C0415
+
+        hp_args = self._hp_args
+
+        self._orig_fsdp2_prepare_model = acc_module.fsdp2_prepare_model
+
+        def _hp_fsdp2_prepare_model(accelerator, model):
+            return fsdp2_prepare_model(accelerator, model, hp_args)
+
+        acc_module.fsdp2_prepare_model = _hp_fsdp2_prepare_model
+
+        def _hp_clip_grad_norm(accelerator, parameters, max_norm, norm_type=2):
+            # ``parameters`` may be a one-shot iterator (e.g. ``model.parameters()``);
+            # materialize it once so the fallback below never sees an exhausted iterator.
+            parameter_list = list(parameters)
+            if getattr(accelerator, "is_fsdp2", False):
+                accelerator.unscale_gradients()
+                parameter_ids = {id(param) for param in parameter_list}
+                for model in accelerator._models:  # pylint: disable=protected-access
+                    if not isinstance(model, HSDPModule):
+                        continue
+                    model_param_ids = {id(param) for param in model.parameters()}
+                    if parameter_ids and parameter_ids.issubset(model_param_ids):
+                        return hp_clip_grad_norm_(parameter_list, max_norm, norm_type=norm_type)
+            return self._orig_accelerator_clip_grad_norm(parameter_list, max_norm, norm_type=norm_type)
+
+        self.accelerator.clip_grad_norm_ = types.MethodType(_hp_clip_grad_norm, self.accelerator)
+        self._accelerator_patches_active = True
+
+    def _restore_accelerator_patches(self) -> None:
+        """Restore original Accelerate methods."""
+        if not self._accelerator_patches_active:
+            return
+
+        import accelerate.accelerator as acc_module  # pylint: disable=C0415
+
+        if self._orig_fsdp2_prepare_model is not None:
+            acc_module.fsdp2_prepare_model = self._orig_fsdp2_prepare_model
+        self.accelerator.clip_grad_norm_ = self._orig_accelerator_clip_grad_norm
+        self._accelerator_patches_active = False
+
+    def _wrap_model(self, model: nn.Module, training: bool = True, dataloader=None) -> nn.Module:
+        """Let Accelerate own FSDP2/HSDP wrapping so optimizer remapping stays correct."""
+        del dataloader
+        if isinstance(model, HSDPModule):
+            return model
+        if training and getattr(self.accelerator, "is_fsdp2", False):
+            return model
+        return super()._wrap_model(model, training=training)
+
+    def _move_model_to_device(self, model: nn.Module, device: Optional[torch.device] = None):
+        """Skip redundant device moves for HSDP-wrapped models."""
+        if isinstance(model, HSDPModule):
+            return model
+        if device is None:
+            return model
+        return model.to(device)
+
+    def train(self, *args, **kwargs):
+        """Activate HP patches during training and restore afterwards."""
+        self._activate_accelerator_patches()
+        try:
+            return super().train(*args, **kwargs)
+        finally:
+            self._restore_accelerator_patches()
+
+    def training_step(
+        self,
+        model: nn.Module,
+        inputs: dict[str, Any],
+        num_items_in_batch: Optional[int] = None,
+    ) -> torch.Tensor:
+        """Standard training step with HSDP gradient synchronization."""
+        model.train()
+        inputs = self._prepare_inputs(inputs)
+
+        sync_gradients = getattr(self.accelerator, "sync_gradients", True)
+        if isinstance(model, HSDPModule):
+            model.set_is_last_backward(sync_gradients)
+            model.set_requires_gradient_sync(sync_gradients)
+
+        compute_loss_context_manager = getattr(self, "compute_loss_context_manager", nullcontext)
+        with compute_loss_context_manager():
+            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+
+        if self.args.n_gpu > 1:
+            loss = loss.mean()
+
+        if not getattr(self, "model_accepts_loss_kwargs", False) and getattr(self, "compute_loss_func", None) is None:
+            loss = loss / self.args.gradient_accumulation_steps
+
+        self.accelerator.backward(loss)
+
+        if isinstance(model, HSDPModule) and sync_gradients:
+            hsdp_sync_stream()
+
+        return loss.detach()
+
+    def create_optimizer(self):
+        """Create optimizer and wrap step with SkipDTensorDispatch."""
+        optimizer = super().create_optimizer()
+        wrap_optimizer_with_skip_dtensor_dispatch(optimizer)
+        return optimizer
+
+    def _save_optimizer_and_scheduler(self, output_dir: str) -> None:
+        """Save model/optimizer shards per-rank and scheduler."""
+        save_hsdp_checkpoint(
+            model=self.model,
+            optimizer=self.optimizer,
+            lr_scheduler=self.lr_scheduler,
+            output_dir=output_dir,
+            should_save_scheduler=self.args.should_save and self.lr_scheduler is not None,
+        )
+
+    def _load_from_checkpoint(self, resume_from_checkpoint: str, model: Optional[nn.Module] = None) -> None:
+        """Load model from HSDP sharded checkpoint."""
+        target = model if model is not None else self.model
+        loaded = load_hsdp_model(target, resume_from_checkpoint)
+        if not loaded:
+            return super()._load_from_checkpoint(resume_from_checkpoint, model=model)
+        # Remember the directory so _load_optimizer_and_scheduler reads the same HSDP checkpoint.
+        self._pending_hsdp_checkpoint = resume_from_checkpoint
+        return None
+
+    def _load_optimizer_and_scheduler(self, checkpoint: Optional[str] = None) -> None:
+        """Load optimizer/scheduler from per-rank checkpoint files."""
+        ckpt_dir = getattr(self, "_pending_hsdp_checkpoint", None) or checkpoint
+        if ckpt_dir is None:
+            return
+        load_hsdp_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, ckpt_dir)
+
+    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
+        """Save model weights in HuggingFace-compatible format."""
+        save_dir = output_dir or self.args.output_dir
+        os.makedirs(save_dir, exist_ok=True)
+        export_to_hf_format(self.model, getattr(self, "processing_class", None), save_dir)
diff --git a/src/llamafactory/train/hyper_parallel/workflow.py b/src/llamafactory/train/hyper_parallel/workflow.py
index 5929deef2..360a8ccb8 100644
--- a/src/llamafactory/train/hyper_parallel/workflow.py
+++ b/src/llamafactory/train/hyper_parallel/workflow.py
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import math
 from typing import TYPE_CHECKING, Optional
 
+from transformers import DataCollatorForLanguageModeling
+
 from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
@@ -21,9 +24,9 @@ from ...extras.misc import calculate_tps
 from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..callbacks import SaveProcessorCallback
 from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
-from ..trainer_utils import asft_loss_func, create_modelcard_and_push, create_ref_model, dft_loss_func, eaft_loss_func
+from ..trainer_utils import create_modelcard_and_push, create_ref_model
+from .trainer import HyperParallelTrainer
 
 
 if TYPE_CHECKING:
@@ -35,6 +38,90 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+def _prepare_hp_args(finetuning_args: "FinetuningArguments", model_args: "ModelArguments"):
+    r"""Load HyperParallel arguments and apply LlamaFactory-side overrides.
+
+    When activation optimization is enabled, skip native gradient checkpointing
+    so HP can install its own via ``setup_activation_optimization``.
+    """
+    if not is_hyper_parallel_available():
+        raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
+
+    from hyper_parallel.integration.llamafactory import HyperParallelArguments  # pylint: disable=C0415
+
+    hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
+    if hp_args.activation_mode != "none":
+        model_args.disable_gradient_checkpointing = True
+    return hp_args
+
+
+def run_pt(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
+):
+    hp_args = _prepare_hp_args(finetuning_args, model_args)
+
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+    trainer = HyperParallelTrainer(
+        hp_args=hp_args,
+        model=model,
+        args=training_args,
+        finetuning_args=finetuning_args,
+        data_collator=data_collator,
+        callbacks=callbacks,
+        **dataset_module,
+        **tokenizer_module,
+    )
+
+    if training_args.do_train:
+        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.save_model()
+        trainer.log_metrics("train", train_result.metrics)
+        trainer.save_metrics("train", train_result.metrics)
+        trainer.save_state()
+        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+            keys = ["loss"]
+            if isinstance(dataset_module.get("eval_dataset"), dict):
+                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
+            else:
+                keys += ["eval_loss"]
+
+            plot_loss(training_args.output_dir, keys=keys)
+
+    if training_args.do_eval:
+        metrics = trainer.evaluate(metric_key_prefix="eval")
+
+        if isinstance(dataset_module.get("eval_dataset"), dict):
+            for key in dataset_module["eval_dataset"].keys():
+                try:
+                    perplexity = math.exp(metrics[f"eval_{key}_loss"])
+                except OverflowError:
+                    perplexity = float("inf")
+
+                metrics[f"eval_{key}_perplexity"] = perplexity
+        else:
+            try:
+                perplexity = math.exp(metrics["eval_loss"])
+            except OverflowError:
+                perplexity = float("inf")
+
+            metrics["eval_perplexity"] = perplexity
+
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
+
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
+
+
 def run_sft(
     model_args: "ModelArguments",
     data_args: "DataArguments",
@@ -43,13 +130,7 @@ def run_sft(
     generating_args: "GeneratingArguments",
     callbacks: Optional[list["TrainerCallback"]] = None,
 ):
-    if not is_hyper_parallel_available():
-        raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
-
-    from hyper_parallel.integration.llamafactory import (  # pylint: disable=C0415
-        HyperParallelArguments,
-        HyperParallelTrainer,
-    )
+    hp_args = _prepare_hp_args(finetuning_args, model_args)
 
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
@@ -94,25 +175,6 @@ def run_sft(
         gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
 
-    hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
-
-    callbacks = list(callbacks or [])
-    processor = tokenizer_module.get("processor")
-    if processor is not None:
-        callbacks.append(SaveProcessorCallback(processor))
-
-    compute_loss_func = None
-    if finetuning_args.use_dft_loss:
-        compute_loss_func = dft_loss_func
-    elif finetuning_args.use_eaft_loss:
-        compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(  # noqa: E731
-            outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
-        )
-    elif finetuning_args.use_asft_loss:
-        from functools import partial
-
-        compute_loss_func = partial(asft_loss_func, asft_alpha=finetuning_args.asft_alpha)
-
     trainer = HyperParallelTrainer(
         hp_args=hp_args,
         model=model,
@@ -122,20 +184,11 @@ def run_sft(
         callbacks=callbacks,
         gen_kwargs=gen_kwargs,
         ref_model=ref_model,
-        compute_loss_func=compute_loss_func,
         **dataset_module,
         **tokenizer_module,
         **metric_module,
     )
 
-    if finetuning_args.use_badam:
-        from types import MethodType
-
-        from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore[import]
-
-        trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
-        trainer.add_callback(BAdamCallback)
-
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py
index 2079d840c..22f7dbf87 100644
--- a/src/llamafactory/train/tuner.py
+++ b/src/llamafactory/train/tuner.py
@@ -88,12 +88,19 @@ def _training_function(config: dict[str, Any]) -> None:
 
     callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last
 
-    if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
+    if finetuning_args.stage in ["pt", "sft"] and finetuning_args.use_hyper_parallel:
         if not is_hyper_parallel_available():
-            raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
-        from .hyper_parallel import run_sft as run_sft_hp
+            raise ImportError(
+                "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
+            )
+        if finetuning_args.stage == "pt":
+            from .hyper_parallel import run_pt as run_pt_hp
 
-        run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+            run_pt_hp(model_args, data_args, training_args, finetuning_args, callbacks)
+        else:
+            from .hyper_parallel import run_sft as run_sft_hp
+
+            run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
     elif finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
         if not is_mcore_adapter_available():