[fix] fp8: add Transformer Engine backend support (#9705)

Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
Santosh Bhavani
2025-12-31 18:18:02 -08:00
committed by GitHub
parent 6fe6bd290b
commit 355d5c5e5a
6 changed files with 128 additions and 70 deletions

View File

@@ -29,7 +29,7 @@ from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -55,13 +55,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
gen_kwargs: Optional[dict[str, Any]] = None,
**kwargs,
) -> None:
kwargs["processing_class"] = kwargs.pop("tokenizer")
# Configure FP8 environment if enabled
if model_args is not None and model_args.fp8:
configure_fp8_environment(model_args)
if is_transformers_version_greater_than("4.46"):
kwargs["processing_class"] = kwargs.pop("tokenizer")
else:
self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
training_args = kwargs.get("args")
if training_args.fp8:
configure_fp8_environment(training_args)
if getattr(training_args, "fp8_backend", "auto") == "te":
patch_accelerator_for_fp8()
super().__init__(**kwargs)
if processor is not None:
@@ -88,9 +88,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.compute_loss_func = dft_loss_func
# Verify FP8 status after trainer initialization (accelerator should be available)
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
verify_fp8_status(self.accelerator, model_args)
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args)
@override
def create_optimizer(self) -> "torch.optim.Optimizer":