From 1c44b60e3e12be94191b24419bfc8fe3242fca1f Mon Sep 17 00:00:00 2001 From: Ben Feuer Date: Tue, 30 Sep 2025 23:32:53 -0700 Subject: [PATCH] [feat] fp8 training (#8960) Co-authored-by: Benjamin Feuer Co-authored-by: Yaowei Zheng --- .../extras/fp8/llama3_fp8_deepspeed_sft.yaml | 48 +++++ examples/extras/fp8/llama3_fp8_fsdp_sft.yaml | 51 ++++++ setup.py | 3 + src/llamafactory/hparams/model_args.py | 17 ++ src/llamafactory/hparams/parser.py | 8 + src/llamafactory/train/fp8_utils.py | 171 ++++++++++++++++++ src/llamafactory/train/pt/trainer.py | 16 +- src/llamafactory/train/sft/trainer.py | 11 +- 8 files changed, 322 insertions(+), 3 deletions(-) create mode 100644 examples/extras/fp8/llama3_fp8_deepspeed_sft.yaml create mode 100644 examples/extras/fp8/llama3_fp8_fsdp_sft.yaml create mode 100644 src/llamafactory/train/fp8_utils.py diff --git a/examples/extras/fp8/llama3_fp8_deepspeed_sft.yaml b/examples/extras/fp8/llama3_fp8_deepspeed_sft.yaml new file mode 100644 index 00000000..1f92c52b --- /dev/null +++ b/examples/extras/fp8/llama3_fp8_deepspeed_sft.yaml @@ -0,0 +1,48 @@ +# FP8 training example with DeepSpeed ZeRO-3 +# This config demonstrates FP8 mixed precision training using HuggingFace Accelerate +# with DeepSpeed providing memory optimization (not FP8 handling) + +### Model configuration +model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct +trust_remote_code: true + +### Method configuration +stage: sft +do_train: true +finetuning_type: full + +### Dataset configuration +dataset: identity +template: llama3 +cutoff_len: 1024 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 + +### Output configuration +output_dir: saves/llama3-8b/fp8-deepspeed/sft +logging_steps: 10 +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +### Training configuration +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 +learning_rate: 5.0e-5 +num_train_epochs: 3.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +bf16: true + +### FP8 configuration +fp8: true +fp8_backend: torchao # Use TorchAO backend for FP8 +fp8_enable_fsdp_float8_all_gather: false # Not used with DeepSpeed + +### DeepSpeed configuration +deepspeed: examples/deepspeed/ds_z3_fp8_config.json + +### Logging configuration +report_to: wandb +run_name: llama3_fp8_deepspeed_sft \ No newline at end of file diff --git a/examples/extras/fp8/llama3_fp8_fsdp_sft.yaml b/examples/extras/fp8/llama3_fp8_fsdp_sft.yaml new file mode 100644 index 00000000..7590140d --- /dev/null +++ b/examples/extras/fp8/llama3_fp8_fsdp_sft.yaml @@ -0,0 +1,51 @@ +# FP8 training example with FSDP +# This config demonstrates FP8 mixed precision training using HuggingFace Accelerate +# with FSDP for distributed training and float8 all-gather optimization + +### Model configuration +model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct +trust_remote_code: true + +### Method configuration +stage: sft +do_train: true +finetuning_type: full + +### Dataset configuration +dataset: identity +template: llama3 +cutoff_len: 1024 +max_samples: 1000 +overwrite_cache: true +preprocessing_num_workers: 16 + +### Output configuration +output_dir: saves/llama3-8b/fp8-fsdp/sft +logging_steps: 10 +save_steps: 500 +plot_loss: true +overwrite_output_dir: true + +### Training configuration +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 +learning_rate: 5.0e-5 +num_train_epochs: 3.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +bf16: true + +### FP8 configuration +fp8: true +fp8_backend: torchao # Use TorchAO backend for FP8 
+fp8_enable_fsdp_float8_all_gather: true # Enable FSDP2 float8 all-gather optimization + +### FSDP configuration (using training arguments - no separate FSDP config file) +fsdp: + - full_shard + - auto_wrap +fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + +### Logging configuration +report_to: wandb +run_name: llama3_fp8_fsdp_sft \ No newline at end of file diff --git a/setup.py b/setup.py index 08ba557e..ea99f6f5 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,9 @@ extra_require = { ], "openmind": ["openmind"], "swanlab": ["swanlab"], + "fp8": ["torchao>=0.8.0", "accelerate>=1.10.0"], + "fp8-te": ["transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"], + "fp8-all": ["torchao>=0.8.0", "transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"], "dev": ["pre-commit", "ruff", "pytest", "build"], } diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index d2f7cc52..0576b729 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -213,6 +213,23 @@ class QuantizationArguments: default=None, metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."}, ) + fp8: bool = field( + default=False, + metadata={ + "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. " + "Requires PyTorch 2.7+ and Hopper architecture GPUs." + }, + ) + fp8_backend: str = field( + default="auto", + metadata={ + "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend." + }, + ) + fp8_enable_fsdp_float8_all_gather: bool = field( + default=False, + metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."}, + ) @dataclass diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index cd3ad9aa..40869a00 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -131,6 +131,14 @@ def _verify_model_args( logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.") model_args.use_fast_tokenizer = False + # Validate advanced training features + if model_args.fp8 and model_args.quantization_bit is not None: + raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.") + + if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8: + logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.") + model_args.fp8 = True + def _check_extra_dependencies( model_args: "ModelArguments", diff --git a/src/llamafactory/train/fp8_utils.py b/src/llamafactory/train/fp8_utils.py new file mode 100644 index 00000000..ab8a8ee2 --- /dev/null +++ b/src/llamafactory/train/fp8_utils.py @@ -0,0 +1,171 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING, Any, Optional + +from ..extras import logging + + +if TYPE_CHECKING: + from ..hparams import ModelArguments + +logger = logging.get_logger(__name__) + + +def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]: + """Create AORecipeKwargs for FP8 training with HuggingFace Accelerate. + + Args: + model_args: Model arguments containing FP8 configuration + + Returns: + List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise + """ + if not model_args.fp8: + return [] + + try: + # Check if AORecipeKwargs is available (Accelerate 1.8.0+) + from accelerate.utils import AORecipeKwargs + + backend = getattr(model_args, "fp8_backend", "auto") + logger.info_rank0(f"Creating FP8 configuration with backend: {backend}") + + # Create Float8LinearConfig if torchao backend is used + config = None + if backend == "torchao" or backend == "auto": + from torchao.float8 import Float8LinearConfig + + # Use rowwise scaling for better performance (as recommended by torchao) + # Configure alignment requirements for FP8 kernels + config = Float8LinearConfig.from_recipe_name("rowwise") + + # Enable alignment for better kernel performance + if hasattr(config, "enable_amax_init"): + config.enable_amax_init = True + if hasattr(config, "enable_pre_and_post_forward"): + config.enable_pre_and_post_forward = True + + # Create module filter function to skip problematic layers + # TorchAO FP8 requires dimensions divisible by 16 for optimal kernels + def module_filter_func(module, layer_name): + # Skip embedding and output layers for numerical stability + skip_layers = ["embed", "lm_head", "output", "classifier"] + if any(skip_name in layer_name.lower() for skip_name in skip_layers): + return False + + # Only convert Linear layers + if not (hasattr(module, "weight") and len(module.weight.shape) == 2): + return False + + # Check dimension alignment for FP8 kernels + weight = module.weight + in_features, out_features = weight.shape[1], weight.shape[0] + + # Skip layers with dimensions not divisible by 16 to avoid kernel errors + if in_features % 16 != 0 or out_features % 16 != 0: + logger.debug( + f"Skipping layer {layer_name} with dimensions {out_features}x{in_features} (not divisible by 16)" + ) + return False + + return True + + # Map FSDP all-gather setting if available (this affects the underlying implementation) + if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather: + logger.info_rank0("FSDP float8 all-gather optimization requested") + + return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)] + except Exception as e: + logger.info_rank0(f"Failed to create FP8 configuration: {e}") + return [] + + +def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]: + """Get the mixed precision setting for Accelerate when using FP8. + + Args: + model_args: Model arguments containing FP8 configuration + + Returns: + "fp8" if FP8 is enabled, None otherwise + """ + return "fp8" if model_args.fp8 else None + + +def configure_fp8_environment(model_args: "ModelArguments") -> None: + """Configure FP8 environment for HuggingFace Accelerate. + + FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether + DeepSpeed or FSDP is used for distributed training. This function sets up the environment + variables and validates the FP8 configuration. 
+ + Args: + model_args: Model arguments containing FP8 configuration + """ + import os + + if not model_args.fp8: + return + + # Set mixed precision to fp8 for HuggingFace Accelerate + os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8" + logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8") + + # Configure FP8 backend and options + backend = getattr(model_args, "fp8_backend", "auto") + if backend != "auto": + os.environ["FP8_BACKEND"] = backend + logger.info_rank0(f"Set FP8_BACKEND={backend}") + + # Create and validate FP8 recipe kwargs (for logging/debugging) + fp8_kwargs = create_fp8_kwargs(model_args) + logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items") + + # Enable FSDP float8 all-gather optimization if requested + if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather: + os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true" + logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true") + + logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate") + + +def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None: + """Verify that FP8 training is actually working after model preparation. + + Args: + accelerator: The HuggingFace Accelerator instance + model_args: Model arguments containing FP8 configuration + """ + if not model_args.fp8: + return + + # Check Accelerate's FP8 status + fp8_enabled = getattr(accelerator, "fp8_enabled", False) + fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN") + + backend = getattr(model_args, "fp8_backend", "auto") + if backend == "torchao" or backend == "auto": + logger.info_rank0( + "FP8 training enabled with TorchAO backend. For optimal performance, " + "ensure model layer dimensions are mostly divisible by 16. " + "If you encounter issues, try fp8_backend='te' with Transformer Engine." + ) + else: + logger.info_rank0(f"FP8 training enabled with {backend} backend.") + + logger.info_rank0(f"Accelerate FP8 status - enabled: {fp8_enabled}, backend: {fp8_backend_type}") + + if not fp8_enabled: + logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. 
FP8 may not be working.") diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 096cbf68..dadf496c 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -21,21 +21,29 @@ from typing_extensions import override from ...extras.packages import is_transformers_version_greater_than from ..callbacks import SaveProcessorCallback +from ..fp8_utils import configure_fp8_environment, verify_fp8_status from ..trainer_utils import create_custom_optimizer, create_custom_scheduler if TYPE_CHECKING: from transformers import ProcessorMixin - from ...hparams import FinetuningArguments + from ...hparams import FinetuningArguments, ModelArguments class CustomTrainer(Trainer): r"""Inherit Trainer for custom optimizer.""" def __init__( - self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs + self, + finetuning_args: "FinetuningArguments", + processor: Optional["ProcessorMixin"], + model_args: Optional["ModelArguments"] = None, + **kwargs, ) -> None: + # Configure FP8 environment if enabled + if model_args is not None and model_args.fp8: + configure_fp8_environment(model_args) if is_transformers_version_greater_than("4.46"): kwargs["processing_class"] = kwargs.pop("tokenizer") @@ -56,6 +64,10 @@ class CustomTrainer(Trainer): self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.add_callback(BAdamCallback) + # Verify FP8 status after trainer initialization (accelerator should be available) + if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"): + verify_fp8_status(self.accelerator, model_args) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None: diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index d378a3a3..ea66a851 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -29,6 +29,7 @@ from ...extras import logging from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_greater_than from ..callbacks import SaveProcessorCallback +from ..fp8_utils import configure_fp8_environment, verify_fp8_status from ..trainer_utils import create_custom_optimizer, create_custom_scheduler @@ -37,7 +38,7 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizer, ProcessorMixin from transformers.trainer import PredictionOutput - from ...hparams import FinetuningArguments + from ...hparams import FinetuningArguments, ModelArguments logger = logging.get_logger(__name__) @@ -50,9 +51,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], + model_args: Optional["ModelArguments"] = None, gen_kwargs: Optional[dict[str, Any]] = None, **kwargs, ) -> None: + # Configure FP8 environment if enabled + if model_args is not None and model_args.fp8: + configure_fp8_environment(model_args) if is_transformers_version_greater_than("4.46"): kwargs["processing_class"] = kwargs.pop("tokenizer") else: @@ -83,6 +88,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): self.compute_loss_func = dft_loss_func + # Verify FP8 status after trainer initialization (accelerator should be available) + if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"): + verify_fp8_status(self.accelerator, model_args) + @override def create_optimizer(self) -> "torch.optim.Optimizer": if self.optimizer is None:
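
The sketch below is not part of the diff above; it illustrates what create_fp8_kwargs() and the ACCELERATE_MIXED_PRECISION=fp8 environment variable ultimately hand to Accelerate when fp8: true is set. It assumes accelerate>=1.10.0 and torchao>=0.8.0 (the versions pinned in setup.py) plus an FP8-capable GPU such as Hopper; the helper name keep_fp8_eligible is illustrative and does not appear in the patch.

import torch
from torch import nn
from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs
from torchao.float8 import Float8LinearConfig


def keep_fp8_eligible(module: nn.Module, name: str) -> bool:
    # Same rules as the patch's module_filter_func: skip embedding/output
    # layers by name and keep only nn.Linear layers whose in/out features
    # are divisible by 16, so the FP8 kernels see aligned shapes.
    if any(key in name.lower() for key in ("embed", "lm_head", "output", "classifier")):
        return False
    if not isinstance(module, nn.Linear):
        return False
    out_features, in_features = module.weight.shape
    return in_features % 16 == 0 and out_features % 16 == 0


# Rowwise scaling, the same recipe the patch selects for the torchao backend.
config = Float8LinearConfig.from_recipe_name("rowwise")
accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[AORecipeKwargs(config=config, module_filter_func=keep_fp8_eligible)],
)

model = nn.Sequential(nn.Linear(4096, 14336), nn.Linear(14336, 4096))
optimizer = torch.optim.AdamW(model.parameters(), lr=5.0e-5)
model, optimizer = accelerator.prepare(model, optimizer)  # eligible Linear layers are swapped to Float8Linear here

In practice none of this is written by hand: setting fp8: true in a config such as examples/extras/fp8/llama3_fp8_fsdp_sft.yaml and launching it with llamafactory-cli train is expected to produce the same wiring through configure_fp8_environment() and create_fp8_kwargs().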
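
A related pre-flight check, also not part of the patch: before enabling fp8 on a new architecture, the same filter rules can be applied offline to see which Linear layers would stay in high precision. The function name audit_fp8_eligibility is hypothetical; for Meta-Llama-3-8B-Instruct all projection shapes are 16-aligned, so only the name-filtered lm_head should be reported.

from torch import nn
from transformers import AutoModelForCausalLM


def audit_fp8_eligibility(model: nn.Module) -> None:
    # Print every nn.Linear that the patch's filter would leave in high precision,
    # either because of its name or because its shape is not divisible by 16.
    skip_names = ("embed", "lm_head", "output", "classifier")
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        out_features, in_features = module.weight.shape
        if any(key in name.lower() for key in skip_names):
            print(f"skipped by name:      {name} ({out_features}x{in_features})")
        elif in_features % 16 or out_features % 16:
            print(f"skipped by alignment: {name} ({out_features}x{in_features})")


model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
audit_fp8_eligibility(model)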