mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-29 18:20:35 +08:00
[feat] fp8 training (#8960)
Co-authored-by: Benjamin Feuer <penfever@gmail.com> Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -29,6 +29,7 @@ from ...extras import logging
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.packages import is_transformers_version_greater_than
|
||||
from ..callbacks import SaveProcessorCallback
|
||||
from ..fp8_utils import configure_fp8_environment, verify_fp8_status
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||
|
||||
|
||||
@@ -37,7 +38,7 @@ if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer, ProcessorMixin
|
||||
from transformers.trainer import PredictionOutput
|
||||
|
||||
from ...hparams import FinetuningArguments
|
||||
from ...hparams import FinetuningArguments, ModelArguments
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -50,9 +51,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
self,
|
||||
finetuning_args: "FinetuningArguments",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
model_args: Optional["ModelArguments"] = None,
|
||||
gen_kwargs: Optional[dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
# Configure FP8 environment if enabled
|
||||
if model_args is not None and model_args.fp8:
|
||||
configure_fp8_environment(model_args)
|
||||
if is_transformers_version_greater_than("4.46"):
|
||||
kwargs["processing_class"] = kwargs.pop("tokenizer")
|
||||
else:
|
||||
@@ -83,6 +88,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
|
||||
self.compute_loss_func = dft_loss_func
|
||||
|
||||
# Verify FP8 status after trainer initialization (accelerator should be available)
|
||||
if model_args is not None and model_args.fp8 and hasattr(self, "accelerator"):
|
||||
verify_fp8_status(self.accelerator, model_args)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
|
||||
Reference in New Issue
Block a user