Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 09:52:14 +08:00)

[feat] fp8 training (#8960)

Co-authored-by: Benjamin Feuer <penfever@gmail.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
  parent: e2b1594d31
  commit: 1c44b60e3e

examples/extras/fp8/llama3_fp8_deepspeed_sft.yaml (new file, 48 lines)
@@ -0,0 +1,48 @@
# FP8 training example with DeepSpeed ZeRO-3
# This config demonstrates FP8 mixed precision training using HuggingFace Accelerate
# with DeepSpeed providing memory optimization (not FP8 handling)

### Model configuration
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true

### Method configuration
stage: sft
do_train: true
finetuning_type: full

### Dataset configuration
dataset: identity
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### Output configuration
output_dir: saves/llama3-8b/fp8-deepspeed/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### Training configuration
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true

### FP8 configuration
fp8: true
fp8_backend: torchao  # Use TorchAO backend for FP8
fp8_enable_fsdp_float8_all_gather: false  # Not used with DeepSpeed

### DeepSpeed configuration
deepspeed: examples/deepspeed/ds_z3_fp8_config.json

### Logging configuration
report_to: wandb
run_name: llama3_fp8_deepspeed_sft
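Note (illustrative, not part of this diff): the config above deliberately splits responsibilities, with DeepSpeed ZeRO-3 handling parameter/optimizer sharding while Accelerate and TorchAO apply the FP8 recipe. A rough Python sketch of that division of labor, assuming accelerate>=1.10.0, deepspeed, and torchao are installed and an FP8-capable (Hopper-class) GPU is available:

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs, DeepSpeedPlugin

accelerator = Accelerator(
    mixed_precision="fp8",               # fp8: true -> FP8 is handled by Accelerate
    kwargs_handlers=[AORecipeKwargs()],  # fp8_backend: torchao (default float8 recipe)
    deepspeed_plugin=DeepSpeedPlugin(
        hf_ds_config="examples/deepspeed/ds_z3_fp8_config.json"  # deepspeed: ...
    ),
)
# accelerator.prepare(model, optimizer, dataloader) would then swap eligible nn.Linear
# layers to float8 while DeepSpeed shards parameters, gradients, and optimizer state.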
							
								
								
									
examples/extras/fp8/llama3_fp8_fsdp_sft.yaml (new file, 51 lines)
@@ -0,0 +1,51 @@
# FP8 training example with FSDP
# This config demonstrates FP8 mixed precision training using HuggingFace Accelerate
# with FSDP for distributed training and float8 all-gather optimization

### Model configuration
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
trust_remote_code: true

### Method configuration
stage: sft
do_train: true
finetuning_type: full

### Dataset configuration
dataset: identity
template: llama3
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### Output configuration
output_dir: saves/llama3-8b/fp8-fsdp/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### Training configuration
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 5.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true

### FP8 configuration
fp8: true
fp8_backend: torchao  # Use TorchAO backend for FP8
fp8_enable_fsdp_float8_all_gather: true  # Enable FSDP2 float8 all-gather optimization

### FSDP configuration (using training arguments - no separate FSDP config file)
fsdp:
  - full_shard
  - auto_wrap
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer

### Logging configuration
report_to: wandb
run_name: llama3_fp8_fsdp_sft
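Note (illustrative, not part of this diff): fp8_enable_fsdp_float8_all_gather refers to TorchAO's FSDP2 integration, where sharded weights are all-gathered in float8 (1 byte per element) instead of bf16 (2 bytes), roughly halving that communication volume. A sketch of the knob at the TorchAO level, assuming torchao>=0.8.0; LLaMA-Factory itself routes FP8 through Accelerate rather than calling torchao directly:

import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Weights are all-gathered by FSDP2 in float8 instead of bf16 when this flag is set.
config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)

# Toy model; dimensions divisible by 16, as TorchAO's FP8 kernels prefer.
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096))
convert_to_float8_training(model, config=config)  # swaps nn.Linear -> Float8Linear
# torch.distributed.fsdp.fully_shard(...) would then be applied per transformer block.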
							
								
								
									
setup.py (3 lines changed)
@@ -70,6 +70,9 @@ extra_require = {
    ],
    "openmind": ["openmind"],
    "swanlab": ["swanlab"],
    "fp8": ["torchao>=0.8.0", "accelerate>=1.10.0"],
    "fp8-te": ["transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
    "fp8-all": ["torchao>=0.8.0", "transformer_engine[pytorch]>=2.0.0", "accelerate>=1.10.0"],
    "dev": ["pre-commit", "ruff", "pytest", "build"],
}
@@ -213,6 +213,23 @@ class QuantizationArguments:
        default=None,
        metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
    )
    fp8: bool = field(
        default=False,
        metadata={
            "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
            "Requires PyTorch 2.7+ and Hopper architecture GPUs."
        },
    )
    fp8_backend: str = field(
        default="auto",
        metadata={
            "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
        },
    )
    fp8_enable_fsdp_float8_all_gather: bool = field(
        default=False,
        metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
    )


@dataclass
@@ -131,6 +131,14 @@ def _verify_model_args(
        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
        model_args.use_fast_tokenizer = False

    # Validate advanced training features
    if model_args.fp8 and model_args.quantization_bit is not None:
        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")

    if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
        model_args.fp8 = True


def _check_extra_dependencies(
    model_args: "ModelArguments",
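Note (illustrative, not part of this diff): the two added checks make FP8 mutually exclusive with quantization and treat the FSDP float8 all-gather flag as implying fp8. A tiny self-contained illustration using a hypothetical stand-in for ModelArguments (the real dataclass has many more fields):

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeModelArgs:  # hypothetical stand-in, for illustration only
    fp8: bool = False
    fp8_enable_fsdp_float8_all_gather: bool = False
    quantization_bit: Optional[int] = None

def verify(args: FakeModelArgs) -> None:
    # mirrors the checks added to _verify_model_args
    if args.fp8 and args.quantization_bit is not None:
        raise ValueError("FP8 training is not compatible with quantization.")
    if args.fp8_enable_fsdp_float8_all_gather and not args.fp8:
        args.fp8 = True  # the real code also logs a warning

verify(FakeModelArgs(fp8=True))  # accepted
args = FakeModelArgs(fp8_enable_fsdp_float8_all_gather=True)
verify(args)
assert args.fp8  # the all-gather flag implies fp8
try:
    verify(FakeModelArgs(fp8=True, quantization_bit=4))
except ValueError as err:
    print(err)  # FP8 combined with quantization is rejected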
							
								
								
									
src/llamafactory/train/fp8_utils.py (new file, 171 lines)
@@ -0,0 +1,171 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Optional

from ..extras import logging


if TYPE_CHECKING:
    from ..hparams import ModelArguments

logger = logging.get_logger(__name__)


def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
    """Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.

    Args:
        model_args: Model arguments containing FP8 configuration

    Returns:
        List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
    """
    if not model_args.fp8:
        return []

    try:
        # Check if AORecipeKwargs is available (Accelerate 1.8.0+)
        from accelerate.utils import AORecipeKwargs

        backend = getattr(model_args, "fp8_backend", "auto")
        logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")

        # Create Float8LinearConfig if torchao backend is used
        config = None
        if backend == "torchao" or backend == "auto":
            from torchao.float8 import Float8LinearConfig

            # Use rowwise scaling for better performance (as recommended by torchao)
            # Configure alignment requirements for FP8 kernels
            config = Float8LinearConfig.from_recipe_name("rowwise")

            # Enable alignment for better kernel performance
            if hasattr(config, "enable_amax_init"):
                config.enable_amax_init = True
            if hasattr(config, "enable_pre_and_post_forward"):
                config.enable_pre_and_post_forward = True

        # Create module filter function to skip problematic layers
        # TorchAO FP8 requires dimensions divisible by 16 for optimal kernels
        def module_filter_func(module, layer_name):
            # Skip embedding and output layers for numerical stability
            skip_layers = ["embed", "lm_head", "output", "classifier"]
            if any(skip_name in layer_name.lower() for skip_name in skip_layers):
                return False

            # Only convert Linear layers
            if not (hasattr(module, "weight") and len(module.weight.shape) == 2):
                return False

            # Check dimension alignment for FP8 kernels
            weight = module.weight
            in_features, out_features = weight.shape[1], weight.shape[0]

            # Skip layers with dimensions not divisible by 16 to avoid kernel errors
            if in_features % 16 != 0 or out_features % 16 != 0:
                logger.debug(
                    f"Skipping layer {layer_name} with dimensions {out_features}x{in_features} (not divisible by 16)"
                )
                return False

            return True

        # Map FSDP all-gather setting if available (this affects the underlying implementation)
        if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
            logger.info_rank0("FSDP float8 all-gather optimization requested")

        return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
    except Exception as e:
        logger.info_rank0(f"Failed to create FP8 configuration: {e}")
        return []


def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
    """Get the mixed precision setting for Accelerate when using FP8.

    Args:
        model_args: Model arguments containing FP8 configuration

    Returns:
        "fp8" if FP8 is enabled, None otherwise
    """
    return "fp8" if model_args.fp8 else None


def configure_fp8_environment(model_args: "ModelArguments") -> None:
    """Configure FP8 environment for HuggingFace Accelerate.

    FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
    DeepSpeed or FSDP is used for distributed training. This function sets up the environment
    variables and validates the FP8 configuration.

    Args:
        model_args: Model arguments containing FP8 configuration
    """
    import os

    if not model_args.fp8:
        return

    # Set mixed precision to fp8 for HuggingFace Accelerate
    os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
    logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")

    # Configure FP8 backend and options
    backend = getattr(model_args, "fp8_backend", "auto")
    if backend != "auto":
        os.environ["FP8_BACKEND"] = backend
        logger.info_rank0(f"Set FP8_BACKEND={backend}")

    # Create and validate FP8 recipe kwargs (for logging/debugging)
    fp8_kwargs = create_fp8_kwargs(model_args)
    logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")

    # Enable FSDP float8 all-gather optimization if requested
    if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
        os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
        logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")

    logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")


def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
    """Verify that FP8 training is actually working after model preparation.

    Args:
        accelerator: The HuggingFace Accelerator instance
        model_args: Model arguments containing FP8 configuration
    """
    if not model_args.fp8:
        return

    # Check Accelerate's FP8 status
    fp8_enabled = getattr(accelerator, "fp8_enabled", False)
    fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")

    backend = getattr(model_args, "fp8_backend", "auto")
    if backend == "torchao" or backend == "auto":
        logger.info_rank0(
            "FP8 training enabled with TorchAO backend. For optimal performance, "
            "ensure model layer dimensions are mostly divisible by 16. "
            "If you encounter issues, try fp8_backend='te' with Transformer Engine."
        )
    else:
        logger.info_rank0(f"FP8 training enabled with {backend} backend.")

    logger.info_rank0(f"Accelerate FP8 status - enabled: {fp8_enabled}, backend: {fp8_backend_type}")

    if not fp8_enabled:
        logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
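Note (illustrative, not part of this diff): the helpers above are designed to feed HuggingFace Accelerate, with create_fp8_kwargs supplying kwargs_handlers and get_fp8_mixed_precision selecting fp8 mixed precision; inside LLaMA-Factory the HF Trainer constructs the Accelerator itself. A hedged sketch of driving them directly, assuming accelerate>=1.8.0 (for AORecipeKwargs), torchao installed, and an FP8-capable GPU:

from accelerate import Accelerator

from llamafactory.hparams import ModelArguments
from llamafactory.train.fp8_utils import (
    configure_fp8_environment,
    create_fp8_kwargs,
    get_fp8_mixed_precision,
    verify_fp8_status,
)

# Hypothetical direct construction; normally these arguments come from the YAML/CLI parser.
model_args = ModelArguments(
    model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct", fp8=True, fp8_backend="torchao"
)
configure_fp8_environment(model_args)  # sets ACCELERATE_MIXED_PRECISION=fp8, FP8_BACKEND=torchao

accelerator = Accelerator(
    mixed_precision=get_fp8_mixed_precision(model_args),  # "fp8"
    kwargs_handlers=create_fp8_kwargs(model_args),        # [AORecipeKwargs(config=..., module_filter_func=...)]
)
verify_fp8_status(accelerator, model_args)  # logs whether Accelerate actually engaged FP8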
@@ -21,21 +21,29 @@ from typing_extensions import override

from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


if TYPE_CHECKING:
    from transformers import ProcessorMixin

    from ...hparams import FinetuningArguments
    from ...hparams import FinetuningArguments, ModelArguments


class CustomTrainer(Trainer):
    r"""Inherit Trainer for custom optimizer."""

    def __init__(
        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
        self,
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        model_args: Optional["ModelArguments"] = None,
        **kwargs,
    ) -> None:
        # Configure FP8 environment if enabled
        if model_args is not None and model_args.fp8:
            configure_fp8_environment(model_args)
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")

@@ -56,6 +64,10 @@ class CustomTrainer(Trainer):
            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
            self.add_callback(BAdamCallback)

        # Verify FP8 status after trainer initialization (accelerator should be available)
        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
            verify_fp8_status(self.accelerator, model_args)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None:
@@ -29,6 +29,7 @@ from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

@@ -37,7 +38,7 @@ if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.trainer import PredictionOutput

    from ...hparams import FinetuningArguments
    from ...hparams import FinetuningArguments, ModelArguments


logger = logging.get_logger(__name__)

@@ -50,9 +51,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
        self,
        finetuning_args: "FinetuningArguments",
        processor: Optional["ProcessorMixin"],
        model_args: Optional["ModelArguments"] = None,
        gen_kwargs: Optional[dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        # Configure FP8 environment if enabled
        if model_args is not None and model_args.fp8:
            configure_fp8_environment(model_args)
        if is_transformers_version_greater_than("4.46"):
            kwargs["processing_class"] = kwargs.pop("tokenizer")
        else:

@@ -83,6 +88,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):

            self.compute_loss_func = dft_loss_func

        # Verify FP8 status after trainer initialization (accelerator should be available)
        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
            verify_fp8_status(self.accelerator, model_args)

    @override
    def create_optimizer(self) -> "torch.optim.Optimizer":
        if self.optimizer is None: