From 355d5c5e5a7549705c426f6dbce66f164af22052 Mon Sep 17 00:00:00 2001
From: Santosh Bhavani
Date: Wed, 31 Dec 2025 18:18:02 -0800
Subject: [PATCH] [fix] fp8: add Transformer Engine backend support (#9705)

Co-authored-by: Yaowei Zheng
---
 src/llamafactory/hparams/model_args.py    |  17 ----
 src/llamafactory/hparams/parser.py        |  16 ++--
 src/llamafactory/hparams/training_args.py |  24 ++++-
 src/llamafactory/train/fp8_utils.py       | 109 +++++++++++++------
 src/llamafactory/train/pt/trainer.py      |  17 ++--
 src/llamafactory/train/sft/trainer.py     |  19 ++--
 6 files changed, 130 insertions(+), 72 deletions(-)

diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index aaa83057a..a245428fe 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -298,23 +298,6 @@ class QuantizationArguments:
         default=None,
         metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
     )
-    fp8: bool = field(
-        default=False,
-        metadata={
-            "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
-            "Requires PyTorch 2.7+ and Hopper architecture GPUs."
-        },
-    )
-    fp8_backend: str = field(
-        default="auto",
-        metadata={
-            "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
-        },
-    )
-    fp8_enable_fsdp_float8_all_gather: bool = field(
-        default=False,
-        metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
-    )
 
 
 @dataclass
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 5b262d68f..695b86dfa 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -142,15 +142,6 @@ def _verify_model_args(
         logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
         model_args.use_fast_tokenizer = False
 
-    # Validate advanced training features
-    if model_args.fp8 and model_args.quantization_bit is not None:
-        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
-
-    if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
-        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
-        model_args.fp8 = True
-
-
 def _check_extra_dependencies(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
@@ -347,6 +338,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
     if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
         raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")
 
+    if training_args.fp8 and model_args.quantization_bit is not None:
+        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
+
     if model_args.infer_backend != EngineName.HF:
         raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
 
@@ -363,6 +357,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)
 
+    if training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
+        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
Setting fp8=True.") + model_args.fp8 = True + if ( training_args.do_train and finetuning_args.finetuning_type == "lora" diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py index 86ac0802f..1a67b7f3c 100644 --- a/src/llamafactory/hparams/training_args.py +++ b/src/llamafactory/hparams/training_args.py @@ -92,7 +92,29 @@ class RayArguments: @dataclass -class TrainingArguments(RayArguments, BaseTrainingArguments): +class Fp8Arguments: + r"""Arguments pertaining to the FP8 training.""" + fp8: bool = field( + default=False, + metadata={ + "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. " + "Requires PyTorch 2.7+ and Hopper architecture GPUs." + }, + ) + fp8_backend: str = field( + default="auto", + metadata={ + "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend." + }, + ) + fp8_enable_fsdp_float8_all_gather: bool = field( + default=False, + metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."}, + ) + + +@dataclass +class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments): r"""Arguments pertaining to the trainer.""" overwrite_output_dir: bool = field( diff --git a/src/llamafactory/train/fp8_utils.py b/src/llamafactory/train/fp8_utils.py index ab8a8ee25..dfbb4ce6c 100644 --- a/src/llamafactory/train/fp8_utils.py +++ b/src/llamafactory/train/fp8_utils.py @@ -12,35 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import types from typing import TYPE_CHECKING, Any, Optional from ..extras import logging if TYPE_CHECKING: - from ..hparams import ModelArguments + from ..hparams import TrainingArguments + logger = logging.get_logger(__name__) -def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]: +def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]: """Create AORecipeKwargs for FP8 training with HuggingFace Accelerate. 
 
     Args:
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration
 
     Returns:
-        List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
+        List containing the FP8 recipe kwargs if FP8 is enabled and supported, an empty list otherwise
     """
-    if not model_args.fp8:
+    if not training_args.fp8:
         return []
 
-    try:
-        # Check if AORecipeKwargs is available (Accelerate 1.8.0+)
-        from accelerate.utils import AORecipeKwargs
+    backend = getattr(training_args, "fp8_backend", "auto")
+    logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
 
-        backend = getattr(model_args, "fp8_backend", "auto")
-        logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
+    try:
+        # Use Transformer Engine backend (optimal for Hopper GPUs)
+        if backend == "te":
+            from accelerate.utils import FP8RecipeKwargs
+
+            logger.info_rank0("Using Transformer Engine FP8 backend")
+            return [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]
+
+        # Use TorchAO backend (default)
+        from accelerate.utils import AORecipeKwargs
 
         # Create Float8LinearConfig if torchao backend is used
         config = None
@@ -83,7 +93,7 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
             return True
 
         # Map FSDP all-gather setting if available (this affects the underlying implementation)
-        if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
+        if getattr(training_args, "fp8_enable_fsdp_float8_all_gather", False):
             logger.info_rank0("FSDP float8 all-gather optimization requested")
 
         return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
@@ -92,19 +102,19 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
     return []
 
 
-def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
+def get_fp8_mixed_precision(training_args: "TrainingArguments") -> Optional[str]:
     """Get the mixed precision setting for Accelerate when using FP8.
 
     Args:
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration
 
     Returns:
         "fp8" if FP8 is enabled, None otherwise
     """
-    return "fp8" if model_args.fp8 else None
+    return "fp8" if training_args.fp8 else None
 
 
-def configure_fp8_environment(model_args: "ModelArguments") -> None:
+def configure_fp8_environment(training_args: "TrainingArguments") -> None:
     """Configure FP8 environment for HuggingFace Accelerate.
 
     FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
@@ -112,11 +122,9 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
     variables and validates the FP8 configuration.
 
     Args:
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration
     """
-    import os
-
-    if not model_args.fp8:
+    if not training_args.fp8:
         return
 
     # Set mixed precision to fp8 for HuggingFace Accelerate
@@ -124,38 +132,38 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
     logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")
 
     # Configure FP8 backend and options
-    backend = getattr(model_args, "fp8_backend", "auto")
+    backend = getattr(training_args, "fp8_backend", "auto")
     if backend != "auto":
         os.environ["FP8_BACKEND"] = backend
         logger.info_rank0(f"Set FP8_BACKEND={backend}")
 
     # Create and validate FP8 recipe kwargs (for logging/debugging)
-    fp8_kwargs = create_fp8_kwargs(model_args)
+    fp8_kwargs = create_fp8_kwargs(training_args)
     logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")
 
     # Enable FSDP float8 all-gather optimization if requested
-    if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
+    if getattr(training_args, "fp8_enable_fsdp_float8_all_gather", False):
         os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
         logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")
 
     logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")
 
 
-def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
+def verify_fp8_status(accelerator, training_args: "TrainingArguments") -> None:
     """Verify that FP8 training is actually working after model preparation.
 
     Args:
        accelerator: The HuggingFace Accelerator instance
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration
     """
-    if not model_args.fp8:
+    if not training_args.fp8:
         return
 
     # Check Accelerate's FP8 status
     fp8_enabled = getattr(accelerator, "fp8_enabled", False)
     fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")
 
-    backend = getattr(model_args, "fp8_backend", "auto")
+    backend = getattr(training_args, "fp8_backend", "auto")
     if backend == "torchao" or backend == "auto":
         logger.info_rank0(
             "FP8 training enabled with TorchAO backend. For optimal performance, "
@@ -169,3 +177,50 @@ def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
 
     if not fp8_enabled:
         logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
+
+
+def patch_accelerator_for_fp8() -> None:
+    """Patch Accelerator to inject FP8 recipe kwargs.
+
+    This is needed because HuggingFace Trainer doesn't pass kwargs_handlers to Accelerator.
+    We monkey-patch Accelerator.__init__ to inject the FP8 recipe and force mixed_precision='fp8'.
+    """
+    import transformer_engine.pytorch as te
+    from accelerate import Accelerator
+
+    # Guard against multiple patches
+    if getattr(Accelerator, "_te_fp8_patched", False):
+        return
+
+    # Stub for Accelerate 1.12+ compatibility (older Transformer Engine builds lack te.fp8.check_mxfp8_support)
+    if not hasattr(te, "fp8"):
+        te.fp8 = types.ModuleType("fp8")
+        te.fp8.check_mxfp8_support = lambda: (False, "MXFP8 not supported")
+
+    try:
+        from accelerate.utils import TERecipeKwargs as FP8Recipe
+
+        use_te_recipe = True
+    except ImportError:
+        from accelerate.utils import FP8RecipeKwargs as FP8Recipe
+
+        use_te_recipe = False
+
+    original_init = Accelerator.__init__
+
+    def patched_init(self, *args, **kwargs):
+        if not kwargs.get("kwargs_handlers"):
+            if use_te_recipe:
+                kwargs["kwargs_handlers"] = [
+                    FP8Recipe(fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
+                ]
+            else:
+                kwargs["kwargs_handlers"] = [
+                    FP8Recipe(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
+                ]
+            # Only force mixed_precision when we inject handlers
+            kwargs["mixed_precision"] = "fp8"
+        return original_init(self, *args, **kwargs)
+
+    Accelerator.__init__ = patched_init
+    Accelerator._te_fp8_patched = True
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index dadf496c3..459d03d09 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -21,7 +21,7 @@
 from typing_extensions import override
 
 from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..fp8_utils import configure_fp8_environment, verify_fp8_status
+from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
@@ -41,11 +41,13 @@ class CustomTrainer(Trainer):
         model_args: Optional["ModelArguments"] = None,
         **kwargs,
     ) -> None:
+        kwargs["processing_class"] = kwargs.pop("tokenizer")
         # Configure FP8 environment if enabled
-        if model_args is not None and model_args.fp8:
-            configure_fp8_environment(model_args)
-        if is_transformers_version_greater_than("4.46"):
-            kwargs["processing_class"] = kwargs.pop("tokenizer")
+        training_args = kwargs.get("args")
+        if training_args.fp8:
+            configure_fp8_environment(training_args)
+            if getattr(training_args, "fp8_backend", "auto") == "te":
+                patch_accelerator_for_fp8()
 
         super().__init__(**kwargs)
         if processor is not None:
@@ -64,9 +66,8 @@ class CustomTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        # Verify FP8 status after trainer initialization (accelerator should be available)
-        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
-            verify_fp8_status(self.accelerator, model_args)
+        if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
+            verify_fp8_status(self.accelerator, training_args)
 
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index ea66a8511..d369b462f 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -29,7 +29,7 @@
 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..fp8_utils import configure_fp8_environment, verify_fp8_status
+from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
@@ -55,13 +55,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         gen_kwargs: Optional[dict[str, Any]] = None,
         **kwargs,
     ) -> None:
+        kwargs["processing_class"] = kwargs.pop("tokenizer")
         # Configure FP8 environment if enabled
-        if model_args is not None and model_args.fp8:
-            configure_fp8_environment(model_args)
-        if is_transformers_version_greater_than("4.46"):
-            kwargs["processing_class"] = kwargs.pop("tokenizer")
-        else:
-            self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
+        training_args = kwargs.get("args")
+        if training_args.fp8:
+            configure_fp8_environment(training_args)
+            if getattr(training_args, "fp8_backend", "auto") == "te":
+                patch_accelerator_for_fp8()
 
         super().__init__(**kwargs)
         if processor is not None:
@@ -88,9 +88,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
 
         self.compute_loss_func = dft_loss_func
 
-        # Verify FP8 status after trainer initialization (accelerator should be available)
-        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
-            verify_fp8_status(self.accelerator, model_args)
+        if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
+            verify_fp8_status(self.accelerator, training_args)
 
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
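-- 
Notes (below the mail delimiter, so git am ignores them):

Usage sketch: with this change the FP8 switches live on TrainingArguments rather
than ModelArguments, so the TE path is driven entirely from the training config.
A minimal driver, assuming this patch is applied and that accelerate and
transformer_engine are installed on a Hopper GPU; the output_dir value is a
placeholder:

    import os

    from llamafactory.hparams import TrainingArguments
    from llamafactory.train.fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8

    # Illustrative values only; any real run would set the usual trainer fields too.
    args = TrainingArguments(output_dir="/tmp/fp8-run", fp8=True, fp8_backend="te")
    configure_fp8_environment(args)  # exports ACCELERATE_MIXED_PRECISION=fp8 and FP8_BACKEND=te
    patch_accelerator_for_fp8()  # the next Accelerator() built by Trainer picks up the TE recipe
    assert os.environ["ACCELERATE_MIXED_PRECISION"] == "fp8"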
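The recipe values injected by the patch (HYBRID format, amax_history_len=16,
amax_compute_algo="max") map onto Transformer Engine's delayed-scaling recipe.
For orientation, the same recipe written directly against the TE API, which is
roughly the object Accelerate builds from the recipe kwargs (assumes
transformer_engine is installed):

    from transformer_engine.common.recipe import DelayedScaling, Format

    recipe = DelayedScaling(
        fp8_format=Format.HYBRID,  # E4M3 in the forward pass, E5M2 for gradients in the backward pass
        amax_history_len=16,       # keep the last 16 amax observations per tensor
        amax_compute_algo="max",   # derive the scale from the max over that history window
    )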
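patch_accelerator_for_fp8() exists because HuggingFace Trainer constructs its
Accelerator internally and never forwards kwargs_handlers. A self-contained toy
version of the same injection pattern, runnable without Accelerate or
transformer_engine; every name below is a stand-in, not part of the patch:

    class Engine:  # stand-in for accelerate.Accelerator
        def __init__(self, mixed_precision="no", kwargs_handlers=None):
            self.mixed_precision = mixed_precision
            self.kwargs_handlers = kwargs_handlers or []

    def patch_engine() -> None:
        if getattr(Engine, "_patched", False):  # guard against double-patching
            return
        original_init = Engine.__init__

        def patched_init(self, *args, **kwargs):
            if not kwargs.get("kwargs_handlers"):  # inject only when the caller passed none
                kwargs["kwargs_handlers"] = ["fp8-recipe"]  # stand-in for the TE recipe kwargs
                kwargs["mixed_precision"] = "fp8"  # forced only on the injection path
            return original_init(self, *args, **kwargs)

        Engine.__init__ = patched_init
        Engine._patched = True

    patch_engine()
    engine = Engine()  # constructed by framework code we cannot modify
    assert engine.mixed_precision == "fp8" and engine.kwargs_handlers == ["fp8-recipe"]

The guard attribute mirrors the patch's Accelerator._te_fp8_patched flag: repeated
calls must not stack wrappers, or each Engine() would re-enter the chain once per call.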