Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2026-01-07 22:50:35 +08:00
[fix] fp8: add Transformer Engine backend support (#9705)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
@@ -298,23 +298,6 @@ class QuantizationArguments:
         default=None,
         metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
     )
-    fp8: bool = field(
-        default=False,
-        metadata={
-            "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
-            "Requires PyTorch 2.7+ and Hopper architecture GPUs."
-        },
-    )
-    fp8_backend: str = field(
-        default="auto",
-        metadata={
-            "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
-        },
-    )
-    fp8_enable_fsdp_float8_all_gather: bool = field(
-        default=False,
-        metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
-    )


 @dataclass
@@ -142,15 +142,6 @@ def _verify_model_args(
         logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
         model_args.use_fast_tokenizer = False

-    # Validate advanced training features
-    if model_args.fp8 and model_args.quantization_bit is not None:
-        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
-
-    if model_args.fp8_enable_fsdp_float8_all_gather and not model_args.fp8:
-        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
-        model_args.fp8 = True


 def _check_extra_dependencies(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
@@ -347,6 +338,9 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
     if training_args.deepspeed is not None and (finetuning_args.use_galore or finetuning_args.use_apollo):
         raise ValueError("GaLore and APOLLO are incompatible with DeepSpeed yet.")

+    if training_args.fp8 and training_args.quantization_bit is not None:
+        raise ValueError("FP8 training is not compatible with quantization. Please disable one of them.")
+
     if model_args.infer_backend != EngineName.HF:
         raise ValueError("vLLM/SGLang backend is only available for API, CLI and Web.")

@@ -363,6 +357,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
     _verify_model_args(model_args, data_args, finetuning_args)
     _check_extra_dependencies(model_args, finetuning_args, training_args)

+    if training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
+        logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
+        model_args.fp8 = True
+
     if (
         training_args.do_train
         and finetuning_args.finetuning_type == "lora"
@@ -92,7 +92,29 @@ class RayArguments:


 @dataclass
-class TrainingArguments(RayArguments, BaseTrainingArguments):
+class Fp8Arguments:
+    r"""Arguments pertaining to the FP8 training."""
+    fp8: bool = field(
+        default=False,
+        metadata={
+            "help": "Enable FP8 mixed precision training via HuggingFace Accelerate. "
+            "Requires PyTorch 2.7+ and Hopper architecture GPUs."
+        },
+    )
+    fp8_backend: str = field(
+        default="auto",
+        metadata={
+            "help": "FP8 backend to use ('auto', 'torchao', 'te', 'msamp'). 'auto' selects best available backend."
+        },
+    )
+    fp8_enable_fsdp_float8_all_gather: bool = field(
+        default=False,
+        metadata={"help": "Enable FP8 optimizations for FSDP2 all-gather operations."},
+    )
+
+
+@dataclass
+class TrainingArguments(Fp8Arguments, RayArguments, BaseTrainingArguments):
     r"""Arguments pertaining to the trainer."""

     overwrite_output_dir: bool = field(
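Note (not part of the diff): once Fp8Arguments is mixed into TrainingArguments, the new flags surface as ordinary HfArgumentParser options. A minimal, self-contained sketch of how such a mixin parses; the Fp8ArgumentsSketch class and the example flag values are illustrative assumptions, not code from this commit:

# Standalone sketch (illustrative, not from this commit).
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class Fp8ArgumentsSketch:  # mirrors the fields added in the hunk above
    fp8: bool = field(default=False)
    fp8_backend: str = field(default="auto")
    fp8_enable_fsdp_float8_all_gather: bool = field(default=False)


parser = HfArgumentParser(Fp8ArgumentsSketch)
(fp8_args,) = parser.parse_args_into_dataclasses(args=["--fp8", "--fp8_backend", "te"])
print(fp8_args.fp8, fp8_args.fp8_backend)  # True te

In a real run these values would arrive through the usual LLaMA-Factory YAML/CLI config rather than a hand-built argv list.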
@@ -12,35 +12,45 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+import types
 from typing import TYPE_CHECKING, Any, Optional

 from ..extras import logging


 if TYPE_CHECKING:
-    from ..hparams import ModelArguments
+    from ..hparams import TrainingArguments


 logger = logging.get_logger(__name__)


-def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
+def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
     """Create AORecipeKwargs for FP8 training with HuggingFace Accelerate.

     Args:
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration

     Returns:
         List containing AORecipeKwargs if FP8 is enabled and supported, empty list otherwise
     """
-    if not model_args.fp8:
+    if not training_args.fp8:
         return []

-    try:
-        # Check if AORecipeKwargs is available (Accelerate 1.8.0+)
-        from accelerate.utils import AORecipeKwargs
-
-        backend = getattr(model_args, "fp8_backend", "auto")
-        logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
+    backend = getattr(training_args, "fp8_backend", "auto")
+    logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")
+
+    try:
+        # Use Transformer Engine backend (optimal for Hopper GPUs)
+        if backend == "te":
+            from accelerate.utils import FP8RecipeKwargs
+
+            logger.info_rank0("Using Transformer Engine FP8 backend")
+            return [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]
+
+        # Use TorchAO backend (default)
+        from accelerate.utils import AORecipeKwargs

         # Create Float8LinearConfig if torchao backend is used
         config = None
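Note (not part of the diff): when an Accelerator is constructed by hand, recipe kwargs like the one returned above are consumed through kwargs_handlers together with mixed_precision="fp8". A minimal sketch of that path, which requires accelerate and, for the TE recipe to take effect, transformer_engine on a Hopper-class GPU:

# Illustrative sketch (not from this commit): hand-built Accelerator consuming the TE recipe.
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

handlers = [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=handlers)
# model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)  # fp8 layers are swapped in here

HF Trainer builds its Accelerator internally and does not forward kwargs_handlers, which is what the patch_accelerator_for_fp8 helper added further down works around.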
@@ -83,7 +93,7 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
             return True

         # Map FSDP all-gather setting if available (this affects the underlying implementation)
-        if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
+        if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
             logger.info_rank0("FSDP float8 all-gather optimization requested")

         return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
@@ -92,19 +102,19 @@ def create_fp8_kwargs(model_args: "ModelArguments") -> list[Any]:
         return []


-def get_fp8_mixed_precision(model_args: "ModelArguments") -> Optional[str]:
+def get_fp8_mixed_precision(training_args: "TrainingArguments") -> Optional[str]:
     """Get the mixed precision setting for Accelerate when using FP8.

     Args:
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration

     Returns:
         "fp8" if FP8 is enabled, None otherwise
     """
-    return "fp8" if model_args.fp8 else None
+    return "fp8" if training_args.fp8 else None


-def configure_fp8_environment(model_args: "ModelArguments") -> None:
+def configure_fp8_environment(training_args: "TrainingArguments") -> None:
     """Configure FP8 environment for HuggingFace Accelerate.

     FP8 training is handled entirely through HuggingFace Accelerate, regardless of whether
@@ -112,11 +122,9 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
     variables and validates the FP8 configuration.

     Args:
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration
     """
-    import os
-
-    if not model_args.fp8:
+    if not training_args.fp8:
         return

     # Set mixed precision to fp8 for HuggingFace Accelerate
@@ -124,38 +132,38 @@ def configure_fp8_environment(model_args: "ModelArguments") -> None:
     logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")

     # Configure FP8 backend and options
-    backend = getattr(model_args, "fp8_backend", "auto")
+    backend = getattr(training_args, "fp8_backend", "auto")
     if backend != "auto":
         os.environ["FP8_BACKEND"] = backend
         logger.info_rank0(f"Set FP8_BACKEND={backend}")

     # Create and validate FP8 recipe kwargs (for logging/debugging)
-    fp8_kwargs = create_fp8_kwargs(model_args)
+    fp8_kwargs = create_fp8_kwargs(training_args)
     logger.info_rank0(f"FP8 AORecipeKwargs created: {len(fp8_kwargs)} items")

     # Enable FSDP float8 all-gather optimization if requested
-    if hasattr(model_args, "fp8_enable_fsdp_float8_all_gather") and model_args.fp8_enable_fsdp_float8_all_gather:
+    if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
         os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
         logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")

     logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")


-def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:
+def verify_fp8_status(accelerator, training_args: "TrainingArguments") -> None:
     """Verify that FP8 training is actually working after model preparation.

     Args:
         accelerator: The HuggingFace Accelerator instance
-        model_args: Model arguments containing FP8 configuration
+        training_args: Training arguments containing FP8 configuration
     """
-    if not model_args.fp8:
+    if not training_args.fp8:
         return

     # Check Accelerate's FP8 status
     fp8_enabled = getattr(accelerator, "fp8_enabled", False)
     fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")

-    backend = getattr(model_args, "fp8_backend", "auto")
+    backend = getattr(training_args, "fp8_backend", "auto")
     if backend == "torchao" or backend == "auto":
         logger.info_rank0(
             "FP8 training enabled with TorchAO backend. For optimal performance, "
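Note (not part of the diff): a rough usage sketch of the environment setup above; the import path and the SimpleNamespace stand-in for TrainingArguments are assumptions made for illustration only:

# Rough usage sketch (import path and stand-in args are assumptions, not from the diff).
import os
from types import SimpleNamespace

from llamafactory.train.fp8_utils import configure_fp8_environment

fake_training_args = SimpleNamespace(fp8=True, fp8_backend="te", fp8_enable_fsdp_float8_all_gather=False)
configure_fp8_environment(fake_training_args)  # duck-typed stand-in; only attribute access is needed
print(os.environ.get("ACCELERATE_MIXED_PRECISION"))  # fp8
print(os.environ.get("FP8_BACKEND"))  # te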
@@ -169,3 +177,50 @@ def verify_fp8_status(accelerator, model_args: "ModelArguments") -> None:

     if not fp8_enabled:
         logger.info_rank0("WARNING: FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")
+
+
+def patch_accelerator_for_fp8() -> None:
+    """Patch Accelerator to inject FP8 recipe kwargs.
+
+    This is needed because HuggingFace Trainer doesn't pass kwargs_handlers to Accelerator.
+    We monkey-patch Accelerator.__init__ to inject the FP8 recipe and force mixed_precision='fp8'.
+    """
+    import transformer_engine.pytorch as te
+    from accelerate import Accelerator
+
+    # Guard against multiple patches
+    if getattr(Accelerator, "_te_fp8_patched", False):
+        return
+
+    # Stub for Accelerate 1.12+ compatibility (te.fp8.check_mxfp8_support doesn't exist yet)
+    if not hasattr(te, "fp8"):
+        te.fp8 = types.ModuleType("fp8")
+        te.fp8.check_mxfp8_support = lambda: (False, "MXFP8 not supported")
+
+    try:
+        from accelerate.utils import TERecipeKwargs as FP8Recipe
+
+        use_te_recipe = True
+    except ImportError:
+        from accelerate.utils import FP8RecipeKwargs as FP8Recipe
+
+        use_te_recipe = False
+
+    original_init = Accelerator.__init__
+
+    def patched_init(self, *args, **kwargs):
+        if "kwargs_handlers" not in kwargs or not kwargs["kwargs_handlers"]:
+            if use_te_recipe:
+                kwargs["kwargs_handlers"] = [
+                    FP8Recipe(fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
+                ]
+            else:
+                kwargs["kwargs_handlers"] = [
+                    FP8Recipe(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")
+                ]
+            # Only force mixed_precision when we inject handlers
+            kwargs["mixed_precision"] = "fp8"
+        return original_init(self, *args, **kwargs)
+
+    Accelerator.__init__ = patched_init
+    Accelerator._te_fp8_patched = True
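Note (not part of the diff): the wrap-and-guard pattern used by patch_accelerator_for_fp8, shown in isolation on a toy class so the two safeguards are easy to see: the idempotence flag prevents double wrapping, and injection happens only when the caller supplied no handlers.

# Toy illustration of the monkey-patch pattern above (assumption: not code from this commit).
class Engine:
    def __init__(self, *, handlers=None, precision="fp32"):
        self.handlers = handlers or []
        self.precision = precision


def patch_engine() -> None:
    if getattr(Engine, "_patched", False):  # idempotence guard, like _te_fp8_patched
        return

    original_init = Engine.__init__

    def patched_init(self, *args, **kwargs):
        if not kwargs.get("handlers"):  # inject only when the caller passed none
            kwargs["handlers"] = ["fp8-recipe"]
            kwargs["precision"] = "fp8"  # force the matching precision, as the real patch does
        return original_init(self, *args, **kwargs)

    Engine.__init__ = patched_init
    Engine._patched = True


patch_engine()
patch_engine()  # second call is a no-op thanks to the guard
print(Engine().precision)                     # fp8 (handlers were injected)
print(Engine(handlers=["custom"]).precision)  # fp32 (caller's handlers are respected)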
@@ -21,7 +21,7 @@ from typing_extensions import override

 from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..fp8_utils import configure_fp8_environment, verify_fp8_status
+from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

@@ -41,11 +41,13 @@ class CustomTrainer(Trainer):
         model_args: Optional["ModelArguments"] = None,
         **kwargs,
     ) -> None:
+        kwargs["processing_class"] = kwargs.pop("tokenizer")
         # Configure FP8 environment if enabled
-        if model_args is not None and model_args.fp8:
-            configure_fp8_environment(model_args)
-        if is_transformers_version_greater_than("4.46"):
-            kwargs["processing_class"] = kwargs.pop("tokenizer")
+        training_args = kwargs.get("args")
+        if training_args.fp8:
+            configure_fp8_environment(training_args)
+            if getattr(training_args, "fp8_backend", "auto") == "te":
+                patch_accelerator_for_fp8()

         super().__init__(**kwargs)
         if processor is not None:
@@ -64,9 +66,8 @@ class CustomTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

-        # Verify FP8 status after trainer initialization (accelerator should be available)
-        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
-            verify_fp8_status(self.accelerator, model_args)
+        if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
+            verify_fp8_status(self.accelerator, training_args)

     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
@@ -29,7 +29,7 @@ from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..fp8_utils import configure_fp8_environment, verify_fp8_status
+from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

@@ -55,13 +55,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         gen_kwargs: Optional[dict[str, Any]] = None,
         **kwargs,
     ) -> None:
+        kwargs["processing_class"] = kwargs.pop("tokenizer")
         # Configure FP8 environment if enabled
-        if model_args is not None and model_args.fp8:
-            configure_fp8_environment(model_args)
-        if is_transformers_version_greater_than("4.46"):
-            kwargs["processing_class"] = kwargs.pop("tokenizer")
-        else:
-            self.processing_class: PreTrainedTokenizer = kwargs.get("tokenizer")
+        training_args = kwargs.get("args")
+        if training_args.fp8:
+            configure_fp8_environment(training_args)
+            if getattr(training_args, "fp8_backend", "auto") == "te":
+                patch_accelerator_for_fp8()

         super().__init__(**kwargs)
         if processor is not None:
@@ -88,9 +88,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):

             self.compute_loss_func = dft_loss_func

-        # Verify FP8 status after trainer initialization (accelerator should be available)
-        if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
-            verify_fp8_status(self.accelerator, model_args)
+        if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
+            verify_fp8_status(self.accelerator, training_args)

     @override
     def create_optimizer(self) -> "torch.optim.Optimizer":