[fix] fp8: add Transformer Engine backend support (#9705)

Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Author: Santosh Bhavani
Date: 2025-12-31 18:18:02 -08:00
Committed by: GitHub
parent 6fe6bd290b
commit 355d5c5e5a
6 changed files with 128 additions and 70 deletions

View File

@@ -298,23 +298,6 @@ class QuantizationArguments:
         default=None,
         metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
     )
-    fp8: bool = field(
-        default=False,
-        metadata={
-            "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
-            "Requires PyTorch 2.7+ and Hopper architecture GPUs."
-        },
-    )
-    fp8_backend: str = field(
-        default="auto",
-        metadata={
-            "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
-        },
-    )
-    fp8_enable_fsdp_float8_all_gather: bool = field(
-        default=False,
-        metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
-    )


 @dataclass

View File

@@ -142,15 +142,6 @@ def _verify_model_args(
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.") logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False model_args.use_fast_tokenizer = False
# Validate advanced training features
if model_args.fp8 and model_args.quantization_bit is not None:
raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
model_args.fp8 = True
def _check_extra_dependencies( def _check_extra_dependencies(
model_args: "ModelArguments", model_args: "ModelArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
@@ -347,6 +338,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
     if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
         raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")

+    if training_args.fp8 and training_args.quantization_bit is not None:
+        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
+
     if model_args.infer_backend != EngineName.HF:
         raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
@@ -363,6 +357,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)

+    if training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
+        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
+        model_args.fp8 = True
+
     if (
         training_args.do_train
         and finetuning_args.finetuning_type == "lora"
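For reference, the two checks added above can be reproduced in isolation; the helper below (`check_fp8_args` is a hypothetical name, not part of the patch) mirrors the new get_train_args behavior:

    # Standalone sketch of the added checks; mirrors, but is not, the project's code.
    def check_fp8_args(fp8: bool, quantization_bit: int | None, fp8_all_gather: bool) -> bool:
        """Return the effective fp8 flag after validation/coercion."""
        if fp8 and quantization_bit is not None:
            raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")

        if fp8_all_gather and not fp8:
            print("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
            fp8 = True

        return fp8

    assert check_fp8_args(fp8=False, quantization_bit=None, fp8_all_gather=True) is True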

View File

@@ -92,7 +92,29 @@ class RayArguments:
 @dataclass
-class TrainingArguments(RayArguments, BaseTrainingArguments):
+class Fp8Arguments:
+    r"""Arguments pertaining to the FP8 training."""
+
+    fp8: bool = field(
+        default=False,
+        metadata={
+            "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
+            "Requires PyTorch 2.7+ and Hopper architecture GPUs."
+        },
+    )
+    fp8_backend: str = field(
+        default="auto",
+        metadata={
+            "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
+        },
+    )
+    fp8_enable_fsdp_float8_all_gather: bool = field(
+        default=False,
+        metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
+    )
+
+
+@dataclass
+class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
     r"""Arguments pertaining to the trainer."""

     overwrite_output_dir: bool = field(
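Because `Fp8Arguments` is a plain dataclass mixin on `TrainingArguments`, the new flags flow through the usual `HfArgumentParser` machinery. A minimal sketch, using a local copy of the fields instead of importing the project's classes:

    # Sketch only: a local dataclass mirroring the Fp8Arguments fields above.
    from dataclasses import dataclass, field

    from transformers import HfArgumentParser

    @dataclass
    class Fp8ArgsSketch:
        fp8: bool = field(default=False)
        fp8_backend: str = field(default="auto")  # "auto", "torchao", "te", or "msamp"
        fp8_enable_fsdp_float8_all_gather: bool = field(default=False)

    parser = HfArgumentParser(Fp8ArgsSketch)
    (fp8_args,) = parser.parse_args_into_dataclasses(["--fp8", "true", "--fp8_backend", "te"])
    print(fp8_args)  # Fp8ArgsSketch(fp8=True, fp8_backend='te', fp8_enable_fsdp_float8_all_gather=False)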

View File

@@ -12,35 +12,45 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+import types
 from typing import TYPE_CHECKING, Any, Optional

 from ..extras import logging


 if TYPE_CHECKING:
-    from ..hparams import ModelArguments
+    from ..hparams import TrainingArguments


 logger = logging.get_logger(__name__)


-def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
+def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
     """Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.

     Args:
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration

     Returns:
         List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
     """
-    if not model_args.fp8:
+    if not training_args.fp8:
         return []

-    try:
-        # Check if AORecipeKwargs is available (Accelerate 1.8.0+)
-        from accelerate.utils import AORecipeKwargs
-
-        backend = getattr(model_args, "fp8_backend", "auto")
-        logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
+    backend = getattr(training_args, "fp8_backend", "auto")
+    logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
+
+    try:
+        # Use Transformer Engine backend (optimal for Hopper GPUs)
+        if backend == "te":
+            from accelerate.utils import FP8RecipeKwargs
+
+            logger.info_rank0("Using Transformer Engine FP8 backend")
+            return [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]
+
+        # Use TorchAO backend (default)
+        from accelerate.utils import AORecipeKwargs

         # Create Float8LinearConfig if torchao backend is used
         config = None
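For context on how these handlers are consumed (illustrative, not part of this diff): `FP8RecipeKwargs` and `AORecipeKwargs` are Accelerate kwargs handlers, and on FP8-capable hardware they end up on an `Accelerator` roughly like this:

    # Illustration only; assumes an FP8-capable GPU (e.g. Hopper) and torchao installed,
    # otherwise Accelerator will reject mixed_precision="fp8".
    import torch
    from accelerate import Accelerator
    from accelerate.utils import AORecipeKwargs

    accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])
    model = torch.nn.Linear(1024, 1024)
    optimizer = torch.optim.AdamW(model.parameters())
    model, optimizer = accelerator.prepare(model, optimizer)  # eligible Linear layers are swapped to float8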
@@ -83,7 +93,7 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
             return True

         # Map FSDP all-gather setting if available (this affects the underlying implementation)
-        if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
+        if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
             logger.info_rank0("FSDP float8 all-gather optimization requested")

         return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
@@ -92,19 +102,19 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
         return []


-def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
+def get_fp8_mixed_precision(training_args: "TrainingArguments") -> Optional[str]:
     """Get the mixed precision setting for Accelerate when using FP8.

     Args:
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration

     Returns:
         "fp8" if FP8 is enabled, None otherwise
     """
-    return "fp8" if model_args.fp8 else None
+    return "fp8" if training_args.fp8 else None


-def configure_fp8_environment(model_args: "ModelArguments") -> None:
+def configure_fp8_environment(training_args: "TrainingArguments") -> None:
     """Configure FP8 environment for HuggingFace Accelerate.

     FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
@@ -112,11 +122,9 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
     variables and validates the FP8 configuration.

     Args:
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration
     """
-    import os
-
-    if not model_args.fp8:
+    if not training_args.fp8:
         return

     # Set mixed precision to fp8 for HuggingFace Accelerate
@@ -124,38 +132,38 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8") logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")
# Configure FP8 backend and options # Configure FP8 backend and options
backend = getattr(model_args, "fp8_backend", "auto") backend = getattr(training_args, "fp8_backend", "auto")
if backend != "auto": if backend != "auto":
os.environ["FP8_BACKEND"] = backend os.environ["FP8_BACKEND"] = backend
logger.info_rank0(f"Set FP8_BACKEND={backend}") logger.info_rank0(f"Set FP8_BACKEND={backend}")
# Create and validate FP8 recipe kwargs (for logging/debugging) # Create and validate FP8 recipe kwargs (for logging/debugging)
fp8_kwargs = create_fp8_kwargs(model_args) fp8_kwargs = create_fp8_kwargs(training_args)
logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items") logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")
# Enable FSDP float8 all-gather optimization if requested # Enable FSDP float8 all-gather optimization if requested
if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather: if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true" os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true") logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")
logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate") logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")
def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None: def verify_fp8_status(accelerator, training_args: "TrainingArguments") -> None:
"""Verify that FP8 training is actually working after model preparation. """Verify that FP8 training is actually working after model preparation.
Args: Args:
accelerator: The HuggingFace Accelerator instance accelerator: The HuggingFace Accelerator instance
model_args: Model arguments containing FP8 configuration training_args: Training arguments containing FP8 configuration
""" """
if not model_args.fp8: if not training_args.fp8:
return return
# Check Accelerate's FP8 status # Check Accelerate's FP8 status
fp8_enabled = getattr(accelerator, "fp8_enabled", False) fp8_enabled = getattr(accelerator, "fp8_enabled", False)
fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN") fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")
backend = getattr(model_args, "fp8_backend", "auto") backend = getattr(training_args, "fp8_backend", "auto")
if backend == "torchao" or backend == "auto": if backend == "torchao" or backend == "auto":
logger.info_rank0( logger.info_rank0(
"FP8 training enabled with TorchAO backend. For optimal performance, " "FP8 training enabled with TorchAO backend. For optimal performance, "
@@ -169,3 +177,50 @@ def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
     if not fp8_enabled:
         logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
+
+
+def patch_accelerator_for_fp8() -> None:
+    """Patch Accelerator to inject FP8 recipe kwargs.
+
+    This is needed because HuggingFace Trainer doesn't pass kwargs_handlers to Accelerator.
+    We monkey-patch Accelerator.__init__ to inject the FP8 recipe and force mixed_precision='fp8'.
+    """
+    import transformer_engine.pytorch as te
+    from accelerate import Accelerator
+
+    # Guard against multiple patches
+    if getattr(Accelerator, "_te_fp8_patched", False):
+        return
+
+    # Stub for Accelerate 1.12+ compatibility (te.fp8.check_mxfp8_support doesn't exist yet)
+    if not hasattr(te, "fp8"):
+        te.fp8 = types.ModuleType("fp8")
+        te.fp8.check_mxfp8_support = lambda: (False, "MXFP8 not supported")
+
+    try:
+        from accelerate.utils import TERecipeKwargs as FP8Recipe
+
+        use_te_recipe = True
+    except ImportError:
+        from accelerate.utils import FP8RecipeKwargs as FP8Recipe
+
+        use_te_recipe = False
+
+    original_init = Accelerator.__init__
+
+    def patched_init(self, *args, **kwargs):
+        if "kwargs_handlers" not in kwargs or not kwargs["kwargs_handlers"]:
+            if use_te_recipe:
+                kwargs["kwargs_handlers"] = [
+                    FP8Recipe(fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
+                ]
+            else:
+                kwargs["kwargs_handlers"] = [
+                    FP8Recipe(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
+                ]
+
+            # Only force mixed_precision when we inject handlers
+            kwargs["mixed_precision"] = "fp8"
+
+        return original_init(self, *args, **kwargs)
+
+    Accelerator.__init__ = patched_init
+    Accelerator._te_fp8_patched = True
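Note the ordering this patch relies on: `patch_accelerator_for_fp8()` must run before `Trainer` constructs its internal `Accelerator`, which is why the trainers below call it ahead of `super().__init__()`. A rough sketch of the intended effect, assuming the module lives at `llamafactory.train.fp8_utils` and Transformer Engine is installed on a Hopper-class GPU:

    # Sketch (assumed module path and a TE/Hopper environment).
    from accelerate import Accelerator

    from llamafactory.train.fp8_utils import patch_accelerator_for_fp8

    patch_accelerator_for_fp8()         # monkey-patches Accelerator.__init__ exactly once
    accelerator = Accelerator()         # now receives the TE recipe and mixed_precision="fp8"
    print(accelerator.mixed_precision)  # expected: "fp8"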

View File

@@ -21,7 +21,7 @@ from typing_extensions import override
 from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..fp8_utils import configure_fp8_environment, verify_fp8_status
+from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -41,11 +41,13 @@ class CustomTrainer(Trainer):
model_args: Optional["ModelArguments"] = None, model_args: Optional["ModelArguments"] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
kwargs["processing_class"] = kwargs.pop("tokenizer")
# Configure FP8 environment if enabled # Configure FP8 environment if enabled
if model_args is not None and model_args.fp8: training_args = kwargs.get("args")
configure_fp8_environment(model_args) if training_args.fp8:
if is_transformers_version_greater_than("4.46"): configure_fp8_environment(training_args)
kwargs["processing_class"] = kwargs.pop("tokenizer") if getattr(training_args, "fp8_backend", "auto") == "te":
patch_accelerator_for_fp8()
super().__init__(**kwargs) super().__init__(**kwargs)
if processor is not None: if processor is not None:
@@ -64,9 +66,8 @@ class CustomTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

-        # Verify FP8 status after trainer initialization (accelerator should be available)
-        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
-            verify_fp8_status(self.accelerator, model_args)
+        if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
+            verify_fp8_status(self.accelerator, training_args)

     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
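The `kwargs.get("args")` lookup works because the training workflow builds the trainer with keyword arguments (`args=training_args`, `tokenizer=tokenizer`, ...). Condensed into a hypothetical subclass (`Fp8AwareTrainer` is illustrative, not the project's class), the timing constraint is: FP8 setup before `super().__init__()` creates the Accelerator, verification after:

    # Pattern sketch only; assumes llamafactory.train.fp8_utils as the module path.
    from transformers import Trainer

    from llamafactory.train.fp8_utils import (
        configure_fp8_environment,
        patch_accelerator_for_fp8,
        verify_fp8_status,
    )

    class Fp8AwareTrainer(Trainer):
        def __init__(self, **kwargs) -> None:
            kwargs["processing_class"] = kwargs.pop("tokenizer")
            training_args = kwargs.get("args")  # TrainingArguments now carries the fp8 fields
            if getattr(training_args, "fp8", False):
                configure_fp8_environment(training_args)  # env vars must be set before the Accelerator exists
                if getattr(training_args, "fp8_backend", "auto") == "te":
                    patch_accelerator_for_fp8()
            super().__init__(**kwargs)  # Trainer builds self.accelerator here
            if getattr(training_args, "fp8", False) and hasattr(self, "accelerator"):
                verify_fp8_status(self.accelerator, training_args)  # only meaningful once accelerator exists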

View File

@@ -29,7 +29,7 @@ from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..fp8_utils import configure_fp8_environment, verify_fp8_status
+from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -55,13 +55,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         gen_kwargs: Optional[dict[str, Any]] = None,
         **kwargs,
     ) -> None:
+        kwargs["processing_class"] = kwargs.pop("tokenizer")
+
         # Configure FP8 environment if enabled
-        if model_args is not None and model_args.fp8:
-            configure_fp8_environment(model_args)
-
-        if is_transformers_version_greater_than("4.46"):
-            kwargs["processing_class"] = kwargs.pop("tokenizer")
-        else:
-            self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
+        training_args = kwargs.get("args")
+        if training_args.fp8:
+            configure_fp8_environment(training_args)
+            if getattr(training_args, "fp8_backend", "auto") == "te":
+                patch_accelerator_for_fp8()

         super().__init__(**kwargs)

         if processor is not None:
@@ -88,9 +88,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.compute_loss_func = dft_loss_func

-        # Verify FP8 status after trainer initialization (accelerator should be available)
-        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
-            verify_fp8_status(self.accelerator, model_args)
+        if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
+            verify_fp8_status(self.accelerator, training_args)

     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":