[misc] lint (#9710)

Yaowei Zheng
2026-01-04 13:47:56 +08:00
committed by GitHub
parent 9ae62c6fc0
commit 8600530002
6 changed files with 15 additions and 13 deletions

View File

@@ -142,6 +142,7 @@ def _verify_model_args(
         logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
         model_args.use_fast_tokenizer = False
 
+
 def _check_extra_dependencies(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",

View File

@@ -94,6 +94,7 @@ class RayArguments:
 @dataclass
 class Fp8Arguments:
     r"""Arguments pertaining to the FP8 training."""
+
     fp8: bool = field(
         default=False,
         metadata={
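Note: the two hunks above are whitespace-only fixes: a blank line after the class docstring, and an extra blank line before a top-level def in the first file. A minimal sketch of the spacing being enforced, with illustrative names not taken from the commit:

from dataclasses import dataclass, field


@dataclass
class ExampleArguments:
    r"""Arguments pertaining to an example feature."""

    enabled: bool = field(
        default=False,
        metadata={"help": "Whether to enable the example feature."},
    )


def _check_example(args: ExampleArguments) -> None:
    # Two blank lines before a top-level def (PEP 8, pycodestyle E302);
    # the blank line after the class docstring mirrors the hunk above.
    if args.enabled:
        print("example feature enabled")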

View File

@@ -139,14 +139,13 @@ def patch_config(
         setattr(config.text_config, "topk_method", "greedy")
 
     architectures = getattr(config, "architectures", None)
-    if isinstance(architectures, list) and "InternVLChatModel" in architectures:
+    if isinstance(architectures, (list, tuple)) and "InternVLChatModel" in architectures:
         raise ValueError(
             "Please download the internvl models in a Hugging Face-compatible format "
             "(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
         )
-    if isinstance(architectures, (list, tuple)) and "LlavaLlamaForCausalLM" in architectures:
+    if isinstance(architectures, list) and "LlavaLlamaForCausalLM" in architectures:
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
 
     if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):

View File

@@ -93,7 +93,10 @@ def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
             return True
 
     # Map FSDP all-gather setting if available (this affects the underlying implementation)
-    if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
+    if (
+        hasattr(training_args, "fp8_enable_fsdp_float8_all_gather")
+        and training_args.fp8_enable_fsdp_float8_all_gather
+    ):
         logger.info_rank0("FSDP float8 all-gather optimization requested")
 
     return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
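Note: the only change above is wrapping an over-long condition in parentheses so each clause fits the line limit. The hasattr() guard suggests the flag may be absent on some TrainingArguments versions; an equivalent form uses getattr() with a default. A sketch with a stub object (everything except the flag name is illustrative):

class _StubArgs:
    """Stand-in for a TrainingArguments object; the FP8 flag may be absent."""


args = _StubArgs()

# The wrapped form from the diff:
if (
    hasattr(args, "fp8_enable_fsdp_float8_all_gather")
    and args.fp8_enable_fsdp_float8_all_gather
):
    print("FSDP float8 all-gather requested")

# Equivalent one-liner that also stays under the line limit:
if getattr(args, "fp8_enable_fsdp_float8_all_gather", False):
    print("FSDP float8 all-gather requested")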

View File

@@ -19,7 +19,6 @@ import torch
 from transformers import Trainer
 from typing_extensions import override
 
-from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -28,7 +27,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 if TYPE_CHECKING:
     from transformers import ProcessorMixin
 
-    from ...hparams import FinetuningArguments, ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
 
 
 class CustomTrainer(Trainer):
@@ -43,7 +42,7 @@ class CustomTrainer(Trainer):
     ) -> None:
         kwargs["processing_class"] = kwargs.pop("tokenizer")
         # Configure FP8 environment if enabled
-        training_args = kwargs.get("args")
+        training_args: TrainingArguments = kwargs.get("args")
         if training_args.fp8:
             configure_fp8_environment(training_args)
             if getattr(training_args, "fp8_backend", "auto") == "te":
@@ -66,7 +65,7 @@ class CustomTrainer(Trainer):
                 self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
+        if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
             verify_fp8_status(self.accelerator, training_args)
 
     @override
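Note: the substantive change in this file is annotating `training_args`. `kwargs.get("args")` is untyped, so the annotation, backed by the TYPE_CHECKING import added above, lets type checkers resolve attributes such as `.fp8` (defined by the Fp8Arguments dataclass earlier in this commit). A runnable sketch of the pattern; the stub class is illustrative:

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # The real file imports this relatively: `from ...hparams import TrainingArguments`.
    from llamafactory.hparams import TrainingArguments


def init_trainer(**kwargs: Any) -> None:
    # dict.get() returns an untyped value; the annotation tells the type
    # checker which attributes `training_args` carries.
    training_args: "TrainingArguments" = kwargs.get("args")
    if training_args.fp8:
        print("FP8 training enabled")


@dataclass
class _StubArgs:  # illustrative stand-in so the sketch runs on its own
    fp8: bool = True


init_trainer(args=_StubArgs())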

View File

@@ -27,7 +27,6 @@ from typing_extensions import override
 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -35,10 +34,10 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 if TYPE_CHECKING:
     from torch.utils.data import Dataset
-    from transformers import PreTrainedTokenizer, ProcessorMixin
+    from transformers import ProcessorMixin
     from transformers.trainer import PredictionOutput
 
-    from ...hparams import FinetuningArguments, ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
 
 
 logger = logging.get_logger(__name__)
@@ -57,7 +56,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     ) -> None:
         kwargs["processing_class"] = kwargs.pop("tokenizer")
         # Configure FP8 environment if enabled
-        training_args = kwargs.get("args")
+        training_args: TrainingArguments = kwargs.get("args")
         if training_args.fp8:
             configure_fp8_environment(training_args)
             if getattr(training_args, "fp8_backend", "auto") == "te":
@@ -88,7 +87,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.compute_loss_func = dft_loss_func
 
-        if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
+        if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
             verify_fp8_status(self.accelerator, training_args)
 
     @override
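Note: both trainer files also drop imports with no remaining use site (`is_transformers_version_greater_than`, and `PreTrainedTokenizer` from the TYPE_CHECKING block); this is the standard unused-import cleanup (pyflakes/ruff code F401). Names imported under TYPE_CHECKING exist only for annotations, so a name no longer referenced in any annotation is dead. A minimal sketch, with an illustrative class:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only names still referenced in annotations belong here; an unused one
    # (like PreTrainedTokenizer after this commit) gets flagged as F401.
    from transformers import ProcessorMixin


class ExampleTrainer:
    def __init__(self, processor: "ProcessorMixin | None" = None) -> None:
        # String annotation: nothing is imported at runtime.
        self.processor = processor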