[misc] lint (#9710)

This commit is contained in:
Yaowei Zheng
2026-01-04 13:47:56 +08:00
committed by GitHub
parent 9ae62c6fc0
commit 8600530002
6 changed files with 15 additions and 13 deletions

View File

@@ -142,6 +142,7 @@ def _verify_model_args(
logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
model_args.use_fast_tokenizer = False
def _check_extra_dependencies(
model_args: "ModelArguments",
finetuning_args: "FinetuningArguments",

View File

@@ -94,6 +94,7 @@ class RayArguments:
@dataclass
class Fp8Arguments:
r"""Arguments pertaining to the FP8 training."""
fp8: bool = field(
default=False,
metadata={

View File

@@ -139,14 +139,13 @@ def patch_config(
setattr(config.text_config, "topk_method", "greedy")
architectures = getattr(config, "architectures", None)
if isinstance(architectures, (list, tuple)) and "InternVLChatModel" in architectures:
if isinstance(architectures, list) and "InternVLChatModel" in architectures:
raise ValueError(
"Please download the internvl models in a Hugging Facecompatible format "
"(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
)
if isinstance(architectures, (list, tuple)) and "LlavaLlamaForCausalLM" in architectures:
if isinstance(architectures, list) and "LlavaLlamaForCausalLM" in architectures:
raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):

View File

@@ -93,7 +93,10 @@ def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
return True
# Map FSDP all-gather setting if available (this affects the underlying implementation)
if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
if (
hasattr(training_args, "fp8_enable_fsdp_float8_all_gather")
and training_args.fp8_enable_fsdp_float8_all_gather
):
logger.info_rank0("FSDP float8 all-gather optimization requested")
return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]

View File

@@ -19,7 +19,6 @@ import torch
from transformers import Trainer
from typing_extensions import override
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -28,7 +27,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
from transformers import ProcessorMixin
from ...hparams import FinetuningArguments, ModelArguments
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
class CustomTrainer(Trainer):
@@ -43,7 +42,7 @@ class CustomTrainer(Trainer):
) -> None:
kwargs["processing_class"] = kwargs.pop("tokenizer")
# Configure FP8 environment if enabled
training_args = kwargs.get("args")
training_args: TrainingArguments = kwargs.get("args")
if training_args.fp8:
configure_fp8_environment(training_args)
if getattr(training_args, "fp8_backend", "auto") == "te":
@@ -66,7 +65,7 @@ class CustomTrainer(Trainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback)
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args)
@override

View File

@@ -27,7 +27,6 @@ from typing_extensions import override
from ...extras import logging
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -35,10 +34,10 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
if TYPE_CHECKING:
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer, ProcessorMixin
from transformers import ProcessorMixin
from transformers.trainer import PredictionOutput
from ...hparams import FinetuningArguments, ModelArguments
from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
logger = logging.get_logger(__name__)
@@ -57,7 +56,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
) -> None:
kwargs["processing_class"] = kwargs.pop("tokenizer")
# Configure FP8 environment if enabled
training_args = kwargs.get("args")
training_args: TrainingArguments = kwargs.get("args")
if training_args.fp8:
configure_fp8_environment(training_args)
if getattr(training_args, "fp8_backend", "auto") == "te":
@@ -88,7 +87,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.compute_loss_func = dft_loss_func
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args)
@override