Mirror of https://github.com/hiyouga/LLaMA-Factory.git
[fix] fp8: add Transformer Engine backend support (#9705)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
@@ -298,23 +298,6 @@ class QuantizationArguments:
        default=None,
        metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
    )
    fp8: bool = field(
        default=False,
        metadata={
            "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
            "Requires PyTorch 2.7+ and Hopper architecture GPUs."
        },
    )
    fp8_backend: str = field(
        default="auto",
        metadata={
            "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
        },
    )
    fp8_enable_fsdp_float8_all_gather: bool = field(
        default=False,
        metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
    )


@dataclass
@@ -142,15 +142,6 @@ def _verify_model_args(
        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
        model_args.use_fast_tokenizer = False

    # Validate advanced training features
    if model_args.fp8 and model_args.quantization_bit is not None:
        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")

    if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
        model_args.fp8 = True


def _check_extra_dependencies(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
@@ -347,6 +338,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
    if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
        raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")

    if training_args.fp8 and training_args.quantization_bit is not None:
        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")

    if model_args.infer_backend != EngineName.HF:
        raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")
@@ -363,6 +357,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
    _verify_model_args(model_args, data_args, finetuning_args)
    _check_extra_dependencies(model_args, finetuning_args, training_args)

    if training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
        model_args.fp8 = True

    if (
        training_args.do_train
        and finetuning_args.finetuning_type == "lora"
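The two FP8 checks added to get_train_args above can be exercised on their own. Below is a hypothetical, self-contained sketch of the same rules (SimpleNamespace and the check_fp8_args helper are invented for illustration and are not part of this commit):

from types import SimpleNamespace


def check_fp8_args(args):
    # Same rules as the parser: FP8 cannot be combined with quantization,
    # and the FSDP float8 all-gather flag force-enables FP8.
    if args.fp8 and args.quantization_bit is not None:
        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
    if args.fp8_enable_fsdp_float8_all_gather and not args.fp8:
        args.fp8 = True
    return args


args = check_fp8_args(SimpleNamespace(fp8=False, quantization_bit=None, fp8_enable_fsdp_float8_all_gather=True))
assert args.fp8 is True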
@@ -92,7 +92,29 @@ class RayArguments:


@dataclass
class TrainingArguments(RayArguments, BaseTrainingArguments):
class Fp8Arguments:
    r"""Arguments pertaining to the FP8 training."""

    fp8: bool = field(
        default=False,
        metadata={
            "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
            "Requires PyTorch 2.7+ and Hopper architecture GPUs."
        },
    )
    fp8_backend: str = field(
        default="auto",
        metadata={
            "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
        },
    )
    fp8_enable_fsdp_float8_all_gather: bool = field(
        default=False,
        metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
    )


@dataclass
class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
    r"""Arguments pertaining to the trainer."""

    overwrite_output_dir: bool = field(
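As a hedged illustration of how the new mixin's fields surface as command-line flags (Fp8ArgsDemo below is a stand-in dataclass written only for this sketch, not the class added above), HfArgumentParser turns each field into an option:

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class Fp8ArgsDemo:
    fp8: bool = field(default=False)
    fp8_backend: str = field(default="auto")
    fp8_enable_fsdp_float8_all_gather: bool = field(default=False)


(demo_args,) = HfArgumentParser(Fp8ArgsDemo).parse_args_into_dataclasses(["--fp8", "true", "--fp8_backend", "te"])
assert demo_args.fp8 is True and demo_args.fp8_backend == "te"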
@@ -12,36 +12,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import types
from typing import TYPE_CHECKING, Any, Optional

from ..extras import logging


if TYPE_CHECKING:
    from ..hparams import ModelArguments
    from ..hparams import TrainingArguments


logger = logging.get_logger(__name__)


def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
    """Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.

    Args:
        model_args: Model arguments containing FP8 configuration
        training_args: Training arguments containing FP8 configuration

    Returns:
        List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
    """
    if not model_args.fp8:
    if not training_args.fp8:
        return []

    try:
        # Check if AORecipeKwargs is available (Accelerate 1.8.0+)
        from accelerate.utils import AORecipeKwargs

        backend = getattr(model_args, "fp8_backend", "auto")
        backend = getattr(training_args, "fp8_backend", "auto")
        logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")

    try:
        # Use Transformer Engine backend (optimal for Hopper GPUs)
        if backend == "te":
            from accelerate.utils import FP8RecipeKwargs

            logger.info_rank0("Using Transformer Engine FP8 backend")
            return [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]

        # Use TorchAO backend (default)
        from accelerate.utils import AORecipeKwargs

        # Create Float8LinearConfig if torchao backend is used
        config = None
        if backend == "torchao" or backend == "auto":
@@ -83,7 +93,7 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
            return True

        # Map FSDP all-gather setting if available (this affects the underlying implementation)
        if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
        if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
            logger.info_rank0("FSDP float8 all-gather optimization requested")

        return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
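A hedged usage sketch, not part of the commit: the handler list built by create_fp8_kwargs is meant to reach an Accelerator as kwargs_handlers together with fp8 mixed precision (the SimpleNamespace stands in for the real TrainingArguments, and FP8-capable hardware plus Accelerate >= 1.8 are assumed):

from types import SimpleNamespace

from accelerate import Accelerator

training_args = SimpleNamespace(fp8=True, fp8_backend="auto", fp8_enable_fsdp_float8_all_gather=False)
handlers = create_fp8_kwargs(training_args)  # [] when FP8 is disabled or AORecipeKwargs is unavailable
accelerator = Accelerator(
    mixed_precision="fp8" if handlers else "no",
    kwargs_handlers=handlers,
)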
@@ -92,19 +102,19 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
        return []


def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
def get_fp8_mixed_precision(training_args: "TrainingArguments") -> Optional[str]:
    """Get the mixed precision setting for Accelerate when using FP8.

    Args:
        model_args: Model arguments containing FP8 configuration
        training_args: Training arguments containing FP8 configuration

    Returns:
        "fp8" if FP8 is enabled, None otherwise
    """
    return "fp8" if model_args.fp8 else None
    return "fp8" if training_args.fp8 else None


def configure_fp8_environment(model_args: "ModelArguments") -> None:
def configure_fp8_environment(training_args: "TrainingArguments") -> None:
    """Configure FP8 environment for HuggingFace Accelerate.

    FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
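A brief usage note, and an assumption about the intended wiring rather than code from this commit: the helper's return value can be handed straight to Accelerate, because mixed_precision=None simply lets the Accelerator fall back to its default or environment setting:

from accelerate import Accelerator

accelerator = Accelerator(mixed_precision=get_fp8_mixed_precision(training_args))  # "fp8" or None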
@@ -112,11 +122,9 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
    variables and validates the FP8 configuration.

    Args:
        model_args: Model arguments containing FP8 configuration
        training_args: Training arguments containing FP8 configuration
    """
    import os

    if not model_args.fp8:
    if not training_args.fp8:
        return

    # Set mixed precision to fp8 for HuggingFace Accelerate
@@ -124,38 +132,38 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
    logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")

    # Configure FP8 backend and options
    backend = getattr(model_args, "fp8_backend", "auto")
    backend = getattr(training_args, "fp8_backend", "auto")
    if backend != "auto":
        os.environ["FP8_BACKEND"] = backend
        logger.info_rank0(f"Set FP8_BACKEND={backend}")

    # Create and validate FP8 recipe kwargs (for logging/debugging)
    fp8_kwargs = create_fp8_kwargs(model_args)
    fp8_kwargs = create_fp8_kwargs(training_args)
    logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")

    # Enable FSDP float8 all-gather optimization if requested
    if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
    if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
        os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
        logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")

    logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")


def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
def verify_fp8_status(accelerator, training_args: "TrainingArguments") -> None:
    """Verify that FP8 training is actually working after model preparation.

    Args:
        accelerator: The HuggingFace Accelerator instance
        model_args: Model arguments containing FP8 configuration
        training_args: Training arguments containing FP8 configuration
    """
    if not model_args.fp8:
    if not training_args.fp8:
        return

    # Check Accelerate's FP8 status
    fp8_enabled = getattr(accelerator, "fp8_enabled", False)
    fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")

    backend = getattr(model_args, "fp8_backend", "auto")
    backend = getattr(training_args, "fp8_backend", "auto")
    if backend == "torchao" or backend == "auto":
        logger.info_rank0(
            "FP8 training enabled with TorchAO backend. For optimal performance, "
@@ -169,3 +177,50 @@ def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:

    if not fp8_enabled:
        logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")


def patch_accelerator_for_fp8() -> None:
    """Patch Accelerator to inject FP8 recipe kwargs.

    This is needed because HuggingFace Trainer doesn't pass kwargs_handlers to Accelerator.
    We monkey-patch Accelerator.__init__ to inject the FP8 recipe and force mixed_precision='fp8'.
    """
    import transformer_engine.pytorch as te
    from accelerate import Accelerator

    # Guard against multiple patches
    if getattr(Accelerator, "_te_fp8_patched", False):
        return

    # Stub for Accelerate 1.12+ compatibility (te.fp8.check_mxfp8_support doesn't exist yet)
    if not hasattr(te, "fp8"):
        te.fp8 = types.ModuleType("fp8")
        te.fp8.check_mxfp8_support = lambda: (False, "MXFP8 not supported")

    try:
        from accelerate.utils import TERecipeKwargs as FP8Recipe

        use_te_recipe = True
    except ImportError:
        from accelerate.utils import FP8RecipeKwargs as FP8Recipe

        use_te_recipe = False

    original_init = Accelerator.__init__

    def patched_init(self, *args, **kwargs):
        if "kwargs_handlers" not in kwargs or not kwargs["kwargs_handlers"]:
            if use_te_recipe:
                kwargs["kwargs_handlers"] = [
                    FP8Recipe(fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
                ]
            else:
                kwargs["kwargs_handlers"] = [
                    FP8Recipe(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
                ]
            # Only force mixed_precision when we inject handlers
            kwargs["mixed_precision"] = "fp8"
        return original_init(self, *args, **kwargs)

    Accelerator.__init__ = patched_init
    Accelerator._te_fp8_patched = True
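For illustration only, here is a minimal, self-contained sketch of the guard-and-wrap pattern that patch_accelerator_for_fp8 applies to Accelerator.__init__ (the Engine class and patch_engine_defaults helper are invented for this demo):

class Engine:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


def patch_engine_defaults() -> None:
    if getattr(Engine, "_patched", False):  # guard against patching twice
        return

    original_init = Engine.__init__

    def patched_init(self, *args, **kwargs):
        kwargs.setdefault("mixed_precision", "fp8")  # inject a default only when the caller did not set one
        return original_init(self, *args, **kwargs)

    Engine.__init__ = patched_init
    Engine._patched = True


patch_engine_defaults()
patch_engine_defaults()  # second call is a no-op thanks to the class-level flag
assert Engine().kwargs["mixed_precision"] == "fp8"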
@@ -21,7 +21,7 @@ from typing_extensions import override
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -41,11 +41,13 @@ class CustomTrainer(Trainer):
        model_args: Optional["ModelArguments"] = None,
        **kwargs,
    ) -> None:
        # Configure FP8 environment if enabled
        if model_args is not None and model_args.fp8:
            configure_fp8_environment(model_args)
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")
        # Configure FP8 environment if enabled
        training_args = kwargs.get("args")
        if training_args.fp8:
            configure_fp8_environment(training_args)
            if getattr(training_args, "fp8_backend", "auto") == "te":
                patch_accelerator_for_fp8()

        super().__init__(**kwargs)
        if processor is not None:
@@ -64,9 +66,8 @@ class CustomTrainer(Trainer):
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        # Verify FP8 status after trainer initialization (accelerator should be available)
        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
            verify_fp8_status(self.accelerator, model_args)
        if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
            verify_fp8_status(self.accelerator, training_args)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
@@ -29,7 +29,7 @@ from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -55,13 +55,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
        gen_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        # Configure FP8 environment if enabled
        if model_args is not None and model_args.fp8:
            configure_fp8_environment(model_args)
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")
        else:
            self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
        # Configure FP8 environment if enabled
        training_args = kwargs.get("args")
        if training_args.fp8:
            configure_fp8_environment(training_args)
            if getattr(training_args, "fp8_backend", "auto") == "te":
                patch_accelerator_for_fp8()

        super().__init__(**kwargs)
        if processor is not None:
@@ -88,9 +88,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):

            self.compute_loss_func = dft_loss_func

        # Verify FP8 status after trainer initialization (accelerator should be available)
        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
            verify_fp8_status(self.accelerator, model_args)
        if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
            verify_fp8_status(self.accelerator, training_args)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":