[misc] lint (#9710)

This commit is contained in:
Yaowei Zheng
2026-01-04 13:47:56 +08:00
committed by GitHub
parent 9ae62c6fc0
commit 8600530002
6 changed files with 15 additions and 13 deletions

View File

@@ -19,7 +19,6 @@ import torch
from transformers import Trainer
from typing_extensions import override
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -28,7 +27,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
from transformers import ProcessorMixin
from ...hparams import FinetuningArguments, ModelArguments
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
class CustomTrainer(Trainer):
@@ -43,7 +42,7 @@ class CustomTrainer(Trainer):
) -> None:
kwargs["processing_class"] = kwargs.pop("tokenizer")
# Configure FP8 environment if enabled
training_args = kwargs.get("args")
training_args: TrainingArguments = kwargs.get("args")
if training_args.fp8:
configure_fp8_environment(training_args)
if getattr(training_args, "fp8_backend", "auto") == "te":
@@ -66,7 +65,7 @@ class CustomTrainer(Trainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args)
@override