[feat] fp8 training (#8960)

Co-authored-by: Benjamin Feuer <penfever@gmail.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Ben Feuer
2025-09-30 23:32:53 -07:00
committed by GitHub
parent e2b1594d31
commit 1c44b60e3e
8 changed files with 322 additions and 3 deletions


@@ -21,21 +21,29 @@ from typing_extensions import override
 from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
+from ..fp8_utils import configure_fp8_environment, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

 if TYPE_CHECKING:
     from transformers import ProcessorMixin

-    from ...hparams import FinetuningArguments
+    from ...hparams import FinetuningArguments, ModelArguments


 class CustomTrainer(Trainer):
     r"""Inherit Trainer for custom optimizer."""

     def __init__(
-        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
+        self,
+        finetuning_args: "FinetuningArguments",
+        processor: Optional["ProcessorMixin"],
+        model_args: Optional["ModelArguments"] = None,
+        **kwargs,
     ) -> None:
+        # Configure FP8 environment if enabled
+        if model_args is not None and model_args.fp8:
+            configure_fp8_environment(model_args)
         if is_transformers_version_greater_than("4.46"):
             kwargs["processing_class"] = kwargs.pop("tokenizer")
@@ -56,6 +64,10 @@ class CustomTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+        # Verify FP8 status after trainer initialization (accelerator should be available)
+        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
+            verify_fp8_status(self.accelerator, model_args)
+
     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
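`verify_fp8_status` also comes from `fp8_utils` and runs once the trainer (and therefore its accelerator) exists. A hedged sketch of such a check, assuming it only confirms that Accelerate actually selected FP8:

```python
# Hypothetical sketch of verify_fp8_status; the committed body may differ.
import logging

logger = logging.getLogger(__name__)


def verify_fp8_status(accelerator, model_args) -> None:
    # Accelerator.mixed_precision reports the mode Accelerate actually chose.
    if accelerator.mixed_precision != "fp8":
        logger.warning("FP8 was requested, but the accelerator runs in %s.", accelerator.mixed_precision)
```

Callers thread `model_args` through the existing keyword arguments; a usage sketch based on the new signature (the surrounding `model`, `training_args`, `tokenizer`, `finetuning_args`, and `model_args` objects are assumed):

```python
trainer = CustomTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,  # remapped to processing_class on transformers >= 4.46
    finetuning_args=finetuning_args,
    processor=None,
    model_args=model_args,  # model_args.fp8 == True enables the FP8 path
)
```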