From 86005300026c94ea30347f746b821a46ccfe651b Mon Sep 17 00:00:00 2001
From: Yaowei Zheng
Date: Sun, 4 Jan 2026 13:47:56 +0800
Subject: [PATCH] [misc] lint (#9710)

---
 src/llamafactory/hparams/parser.py        | 1 +
 src/llamafactory/hparams/training_args.py | 1 +
 src/llamafactory/model/patcher.py         | 5 ++---
 src/llamafactory/train/fp8_utils.py       | 5 ++++-
 src/llamafactory/train/pt/trainer.py      | 7 +++----
 src/llamafactory/train/sft/trainer.py     | 9 ++++-----
 6 files changed, 15 insertions(+), 13 deletions(-)

diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 695b86dfa..b1ffbb706 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -142,6 +142,7 @@ def _verify_model_args(
         logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
         model_args.use_fast_tokenizer = False
 
+
 def _check_extra_dependencies(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py
index 1a67b7f3c..cb34154fa 100644
--- a/src/llamafactory/hparams/training_args.py
+++ b/src/llamafactory/hparams/training_args.py
@@ -94,6 +94,7 @@ class RayArguments:
 @dataclass
 class Fp8Arguments:
     r"""Arguments pertaining to the FP8 training."""
+
     fp8: bool = field(
         default=False,
         metadata={
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index d89cd6091..bf811043d 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -139,14 +139,13 @@ def patch_config(
             setattr(config.text_config, "topk_method", "greedy")
 
     architectures = getattr(config, "architectures", None)
-
-    if isinstance(architectures, (list, tuple)) and "InternVLChatModel" in architectures:
+    if isinstance(architectures, list) and "InternVLChatModel" in architectures:
         raise ValueError(
             "Please download the internvl models in a Hugging Face–compatible format "
             "(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
         )
 
-    if isinstance(architectures, (list, tuple)) and "LlavaLlamaForCausalLM" in architectures:
+    if isinstance(architectures, list) and "LlavaLlamaForCausalLM" in architectures:
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
 
     if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
diff --git a/src/llamafactory/train/fp8_utils.py b/src/llamafactory/train/fp8_utils.py
index dfbb4ce6c..33728fead 100644
--- a/src/llamafactory/train/fp8_utils.py
+++ b/src/llamafactory/train/fp8_utils.py
@@ -93,7 +93,10 @@ def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
         return True
 
     # Map FSDP all-gather setting if available (this affects the underlying implementation)
-    if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
+    if (
+        hasattr(training_args, "fp8_enable_fsdp_float8_all_gather")
+        and training_args.fp8_enable_fsdp_float8_all_gather
+    ):
         logger.info_rank0("FSDP float8 all-gather optimization requested")
 
     return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index 459d03d09..0a4bef3dd 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -19,7 +19,6 @@ import torch
 from transformers import Trainer
 from typing_extensions import override
 
-from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -28,7 +27,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 if TYPE_CHECKING:
     from transformers import ProcessorMixin
 
-    from ...hparams import FinetuningArguments, ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
 
 
 class CustomTrainer(Trainer):
@@ -43,7 +42,7 @@ class CustomTrainer(Trainer):
     ) -> None:
         kwargs["processing_class"] = kwargs.pop("tokenizer")
         # Configure FP8 environment if enabled
-        training_args = kwargs.get("args")
+        training_args: TrainingArguments = kwargs.get("args")
         if training_args.fp8:
             configure_fp8_environment(training_args)
             if getattr(training_args, "fp8_backend", "auto") == "te":
@@ -66,7 +65,7 @@ class CustomTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
-        if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
+        if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
             verify_fp8_status(self.accelerator, training_args)
 
     @override
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index d369b462f..0ee389b3c 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -27,7 +26,6 @@ from typing_extensions import override
 
 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -35,10 +34,10 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 if TYPE_CHECKING:
     from torch.utils.data import Dataset
-    from transformers import PreTrainedTokenizer, ProcessorMixin
+    from transformers import ProcessorMixin
     from transformers.trainer import PredictionOutput
 
-    from ...hparams import FinetuningArguments, ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
 
 
 logger = logging.get_logger(__name__)
@@ -57,7 +56,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     ) -> None:
         kwargs["processing_class"] = kwargs.pop("tokenizer")
         # Configure FP8 environment if enabled
-        training_args = kwargs.get("args")
+        training_args: TrainingArguments = kwargs.get("args")
         if training_args.fp8:
             configure_fp8_environment(training_args)
             if getattr(training_args, "fp8_backend", "auto") == "te":
@@ -88,7 +87,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
 
             self.compute_loss_func = dft_loss_func
 
-        if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
+        if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
             verify_fp8_status(self.accelerator, training_args)
 
     @override