[feat] fp8 training (#8960)

Co-authored-by: Benjamin Feuer <penfever@gmail.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
Ben Feuer 2025-09-30 23:32:53 -07:00 committed by GitHub
parent e2b1594d31
commit 1c44b60e3e
8 changed files with 322 additions and 3 deletions

View File

@ -0,0 +1,48 @@
# FP8 training example with DeepSpeed ZeRO-3
# This config demonstrates FP8 mixed precision training using HuggingFace Accelerate
# with DeepSpeed providing memory optimization (not FP8 handling)
### Model configuration
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### Method configuration
stage: sft
do_train: true
finetuning_type: full
### Dataset configuration
dataset: identity
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### Output configuration
output_dir: saves/llama3-8b/fp8-deepspeed/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### Training configuration
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
### FP8 configuration
fp8: true
fp8_backend: torchao # Use TorchAO backend for FP8
fp8_enable_fsdp_float8_all_gather: false # Not used with DeepSpeed
### DeepSpeed configuration
deepspeed: examples/deepspeed/ds_z3_fp8_config.json
### Logging configuration
report_to: wandb
run_name: llama3_fp8_deepspeed_sft

View File

@ -0,0 +1,51 @@
# FP8 training example with FSDP
# This config demonstrates FP8 mixed precision training using HuggingFace Accelerate
# with FSDP for distributed training and float8 all-gather optimization
### Model configuration
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true
### Method configuration
stage: sft
do_train: true
finetuning_type: full
### Dataset configuration
dataset: identity
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
### Output configuration
output_dir: saves/llama3-8b/fp8-fsdp/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true
### Training configuration
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
### FP8 configuration
fp8: true
fp8_backend: torchao # Use TorchAO backend for FP8
fp8_enable_fsdp_float8_all_gather: true # Enable FSDP2 float8 all-gather optimization
### FSDP configuration (using training arguments - no separate FSDP config file)
fsdp:
- full_shard
- auto_wrap
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
### Logging configuration
report_to: wandb
run_name: llama3_fp8_fsdp_sft

View File

@ -70,6 +70,9 @@ extra_require = {
], ],
"openmind": ["openmind"], "openmind": ["openmind"],
"swanlab": ["swanlab"], "swanlab": ["swanlab"],
"fp8": ["torchao>=0.8.0", "accelerate>=1.10.0"],
"fp8-te": ["transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
"fp8-all": ["torchao>=0.8.0", "transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
"dev": ["pre-commit", "ruff", "pytest", "build"], "dev": ["pre-commit", "ruff", "pytest", "build"],
} }

View File

@ -213,6 +213,23 @@ class QuantizationArguments:
default=None, default=None,
metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."}, metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
) )
fp8: bool = field(
default=False,
metadata={
"help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
"Requires PyTorch 2.7+ and Hopper architecture GPUs."
},
)
fp8_backend: str = field(
default="auto",
metadata={
"help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
},
)
fp8_enable_fsdp_float8_all_gather: bool = field(
default=False,
metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
)
@dataclass @dataclass

View File

@ -131,6 +131,14 @@ def _verify_model_args(
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.") logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False model_args.use_fast_tokenizer = False
# Validate advanced training features
if model_args.fp8 and model_args.quantization_bit is not None:
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
model_args.fp8 = True
def _check_extra_dependencies( def _check_extra_dependencies(
model_args: "ModelArguments", model_args: "ModelArguments",

View File

@ -0,0 +1,171 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Optional
from ..extras import logging
if TYPE_CHECKING:
from ..hparams import ModelArguments
logger = logging.get_logger(__name__)
def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
"""Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.
Args:
model_args: Model arguments containing FP8 configuration
Returns:
List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
"""
if not model_args.fp8:
return []
try:
# Check if AORecipeKwargs is available (Accelerate 1.8.0+)
from accelerate.utils import AORecipeKwargs
backend = getattr(model_args, "fp8_backend", "auto")
logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
# Create Float8LinearConfig if torchao backend is used
config = None
if backend == "torchao" or backend == "auto":
from torchao.float8 import Float8LinearConfig
# Use rowwise scaling for better performance (as recommended by torchao)
# Configure alignment requirements for FP8 kernels
config = Float8LinearConfig.from_recipe_name("rowwise")
# Enable alignment for better kernel performance
if hasattr(config, "enable_amax_init"):
config.enable_amax_init = True
if hasattr(config, "enable_pre_and_post_forward"):
config.enable_pre_and_post_forward = True
# Create module filter function to skip problematic layers
# TorchAO FP8 requires dimensions divisible by 16 for optimal kernels
def module_filter_func(module, layer_name):
# Skip embedding and output layers for numerical stability
skip_layers = ["embed", "lm_head", "output", "classifier"]
if any(skip_name in layer_name.lower() for skip_name in skip_layers):
return False
# Only convert Linear layers
if not (hasattr(module, "weight") and len(module.weight.shape) == 2):
return False
# Check dimension alignment for FP8 kernels
weight = module.weight
in_features, out_features = weight.shape[1], weight.shape[0]
# Skip layers with dimensions not divisible by 16 to avoid kernel errors
if in_features % 16 != 0 or out_features % 16 != 0:
logger.debug(
f"Skipping layer {layer_name} with dimensions {out_features}x{in_features} (not divisible by 16)"
)
return False
return True
# Map FSDP all-gather setting if available (this affects the underlying implementation)
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
logger.info_rank0("FSDP float8 all-gather optimization requested")
return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
except Exception as e:
logger.info_rank0(f"Failed to create FP8 configuration: {e}")
return []
def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
"""Get the mixed precision setting for Accelerate when using FP8.
Args:
model_args: Model arguments containing FP8 configuration
Returns:
"fp8" if FP8 is enabled, None otherwise
"""
return "fp8" if model_args.fp8 else None
def configure_fp8_environment(model_args: "ModelArguments") -> None:
"""Configure FP8 environment for HuggingFace Accelerate.
FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
DeepSpeed or FSDP is used for distributed training. This function sets up the environment
variables and validates the FP8 configuration.
Args:
model_args: Model arguments containing FP8 configuration
"""
import os
if not model_args.fp8:
return
# Set mixed precision to fp8 for HuggingFace Accelerate
os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")
# Configure FP8 backend and options
backend = getattr(model_args, "fp8_backend", "auto")
if backend != "auto":
os.environ["FP8_BACKEND"] = backend
logger.info_rank0(f"Set FP8_BACKEND={backend}")
# Create and validate FP8 recipe kwargs (for logging/debugging)
fp8_kwargs = create_fp8_kwargs(model_args)
logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")
# Enable FSDP float8 all-gather optimization if requested
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")
logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")
def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
"""Verify that FP8 training is actually working after model preparation.
Args:
accelerator: The HuggingFace Accelerator instance
model_args: Model arguments containing FP8 configuration
"""
if not model_args.fp8:
return
# Check Accelerate's FP8 status
fp8_enabled = getattr(accelerator, "fp8_enabled", False)
fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")
backend = getattr(model_args, "fp8_backend", "auto")
if backend == "torchao" or backend == "auto":
logger.info_rank0(
"FP8 training enabled with TorchAO backend. For optimal performance, "
"ensure model layer dimensions are mostly divisible by 16. "
"If you encounter issues, try fp8_backend='te' with Transformer Engine."
)
else:
logger.info_rank0(f"FP8 training enabled with {backend} backend.")
logger.info_rank0(f"Accelerate FP8 status - enabled: {fp8_enabled}, backend: {fp8_backend_type}")
if not fp8_enabled:
logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")

View File

@ -21,21 +21,29 @@ from typing_extensions import override
from ...extras.packages import is_transformers_version_greater_than from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import ProcessorMixin from transformers import ProcessorMixin
from ...hparams import FinetuningArguments from ...hparams import FinetuningArguments, ModelArguments
class CustomTrainer(Trainer): class CustomTrainer(Trainer):
r"""Inherit Trainer for custom optimizer.""" r"""Inherit Trainer for custom optimizer."""
def __init__( def __init__(
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs self,
finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"],
model_args: Optional["ModelArguments"] = None,
**kwargs,
) -> None: ) -> None:
# Configure FP8 environment if enabled
if model_args is not None and model_args.fp8:
configure_fp8_environment(model_args)
if is_transformers_version_greater_than("4.46"): if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer") kwargs["processing_class"] = kwargs.pop("tokenizer")
@ -56,6 +64,10 @@ class CustomTrainer(Trainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
# Verify FP8 status after trainer initialization (accelerator should be available)
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
verify_fp8_status(self.accelerator, model_args)
@override @override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:

View File

@ -29,6 +29,7 @@ from ...extras import logging
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@ -37,7 +38,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, ProcessorMixin from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers.trainer import PredictionOutput from transformers.trainer import PredictionOutput
from ...hparams import FinetuningArguments from ...hparams import FinetuningArguments, ModelArguments
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
@ -50,9 +51,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self, self,
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
processor: Optional["ProcessorMixin"], processor: Optional["ProcessorMixin"],
model_args: Optional["ModelArguments"] = None,
gen_kwargs: Optional[dict[str, Any]] = None, gen_kwargs: Optional[dict[str, Any]] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
# Configure FP8 environment if enabled
if model_args is not None and model_args.fp8:
configure_fp8_environment(model_args)
if is_transformers_version_greater_than("4.46"): if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer") kwargs["processing_class"] = kwargs.pop("tokenizer")
else: else:
@ -83,6 +88,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.compute_loss_func = dft_loss_func self.compute_loss_func = dft_loss_func
# Verify FP8 status after trainer initialization (accelerator should be available)
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
verify_fp8_status(self.accelerator, model_args)
@override @override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None: