Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2026-01-12 17:10:36 +08:00)
[misc] lint (#9710)
@@ -142,6 +142,7 @@ def _verify_model_args(
         logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
         model_args.use_fast_tokenizer = False
 
+
 def _check_extra_dependencies(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
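The only change in this hunk is the extra blank line: PEP 8 (and formatters such as ruff format or black) expect two blank lines between top-level definitions. A minimal illustration of the convention, with throwaway function names:

# Illustrative only: two blank lines separate top-level definitions after formatting.
def first_function() -> None:
    pass


def second_function() -> None:
    pass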
@@ -94,6 +94,7 @@ class RayArguments:
 @dataclass
 class Fp8Arguments:
     r"""Arguments pertaining to the FP8 training."""
+
     fp8: bool = field(
         default=False,
         metadata={
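The hunk above shows the start of a `dataclasses.field(..., metadata={...})` definition; argument classes written in this style are typically consumed by `transformers.HfArgumentParser`, which turns each field into a CLI flag and reads the `help` entry from its metadata. A self-contained sketch of the pattern; the class name and help text below are illustrative, not the repository's actual definition:

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DemoFp8Arguments:
    # Same field(default=..., metadata=...) pattern as the truncated hunk above.
    fp8: bool = field(
        default=False,
        metadata={"help": "Whether to enable FP8 mixed-precision training."},
    )


if __name__ == "__main__":
    (demo_args,) = HfArgumentParser(DemoFp8Arguments).parse_args_into_dataclasses()
    print(demo_args.fp8)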
@@ -139,14 +139,13 @@ def patch_config(
         setattr(config.text_config, "topk_method", "greedy")
 
     architectures = getattr(config, "architectures", None)
-    if isinstance(architectures, (list, tuple)) and "InternVLChatModel" in architectures:
+    if isinstance(architectures, list) and "InternVLChatModel" in architectures:
         raise ValueError(
             "Please download the internvl models in a Hugging Face–compatible format "
             "(for example, https://huggingface.co/OpenGVLab/InternVL3-8B-hf)."
         )
 
-    if isinstance(architectures, (list, tuple)) and "LlavaLlamaForCausalLM" in architectures:
+    if isinstance(architectures, list) and "LlavaLlamaForCausalLM" in architectures:
         raise ValueError("Please download llava models with hf-compatible format: https://huggingface.co/llava-hf")
 
     if getattr(config, "model_type", None) == "internlm3" and not is_transformers_version_greater_than("4.47.1"):
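For context, `architectures` on a transformers config is normally a list of model class-name strings, which is what both guards above inspect. A hedged sketch of the same check run against a standalone config; the repo id is reused from the error message above purely as an example, and loading it requires network access:

from transformers import AutoConfig

# Load only the config and look at the declared architectures, as patch_config does.
config = AutoConfig.from_pretrained("OpenGVLab/InternVL3-8B-hf")
architectures = getattr(config, "architectures", None)
print(architectures)

# Checkpoints exported in the non-HF-native InternVL format declare "InternVLChatModel"
# and would be rejected by the guard above.
if isinstance(architectures, list) and "InternVLChatModel" in architectures:
    raise ValueError("Please convert the checkpoint to a Hugging Face-compatible format.")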
@@ -93,7 +93,10 @@ def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
         return True
 
     # Map FSDP all-gather setting if available (this affects the underlying implementation)
-    if hasattr(training_args, "fp8_enable_fsdp_float8_all_gather") and training_args.fp8_enable_fsdp_float8_all_gather:
+    if (
+        hasattr(training_args, "fp8_enable_fsdp_float8_all_gather")
+        and training_args.fp8_enable_fsdp_float8_all_gather
+    ):
         logger.info_rank0("FSDP float8 all-gather optimization requested")
 
     return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
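`AORecipeKwargs` comes from Accelerate's FP8/torchao integration; the list returned here is ultimately handed to the `Accelerator` as `kwargs_handlers`. A hedged sketch of that wiring, assuming a recent Accelerate with torchao available; the filter below is a placeholder policy, not LLaMA-Factory's actual `module_filter_func`:

import torch
from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs


def demo_module_filter(module: torch.nn.Module, fqn: str) -> bool:
    # Placeholder policy: only convert Linear layers whose dimensions suit FP8 kernels.
    return isinstance(module, torch.nn.Linear) and module.in_features % 16 == 0


accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[AORecipeKwargs(module_filter_func=demo_module_filter)],
)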
@@ -19,7 +19,6 @@ import torch
 from transformers import Trainer
 from typing_extensions import override
 
-from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
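The removed line is the classic unused-import cleanup (`is_transformers_version_greater_than` is no longer referenced in this module), which ruff reports as F401 and can fix automatically. A small sketch of driving that check from Python, assuming ruff is installed in the environment:

import subprocess

# Equivalent to running `ruff check --select F401 --fix src/` from the shell.
subprocess.run(
    ["python", "-m", "ruff", "check", "--select", "F401", "--fix", "src/"],
    check=False,
)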
@@ -28,7 +27,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 if TYPE_CHECKING:
     from transformers import ProcessorMixin
 
-    from ...hparams import FinetuningArguments, ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
 
 
 class CustomTrainer(Trainer):
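`TrainingArguments` is added under the existing `if TYPE_CHECKING:` block because it is only needed for the annotation introduced below; type-only imports like this are skipped at runtime, which avoids import cycles and import-time cost. A minimal sketch of the pattern with illustrative names:

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # evaluated by type checkers only, never at runtime
    from transformers import TrainingArguments


def summarize(args: "TrainingArguments") -> str:
    # The quoted annotation keeps this valid at runtime even though the import was skipped.
    return f"output_dir={args.output_dir}"

Local variable annotations such as the one added in the next hunk are never evaluated at runtime, so they do not even need the quotes.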
@@ -43,7 +42,7 @@ class CustomTrainer(Trainer):
     ) -> None:
         kwargs["processing_class"] = kwargs.pop("tokenizer")
         # Configure FP8 environment if enabled
-        training_args = kwargs.get("args")
+        training_args: TrainingArguments = kwargs.get("args")
         if training_args.fp8:
             configure_fp8_environment(training_args)
             if getattr(training_args, "fp8_backend", "auto") == "te":
@@ -66,7 +65,7 @@ class CustomTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
         if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
             verify_fp8_status(self.accelerator, training_args)
 
     @override
@@ -27,7 +27,6 @@ from typing_extensions import override
 
 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..fp8_utils import configure_fp8_environment, patch_accelerator_for_fp8, verify_fp8_status
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -35,10 +34,10 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
 if TYPE_CHECKING:
     from torch.utils.data import Dataset
-    from transformers import PreTrainedTokenizer, ProcessorMixin
+    from transformers import ProcessorMixin
     from transformers.trainer import PredictionOutput
 
-    from ...hparams import FinetuningArguments, ModelArguments
+    from ...hparams import FinetuningArguments, ModelArguments, TrainingArguments
 
 
 logger = logging.get_logger(__name__)
@@ -57,7 +56,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     ) -> None:
         kwargs["processing_class"] = kwargs.pop("tokenizer")
         # Configure FP8 environment if enabled
-        training_args = kwargs.get("args")
+        training_args: TrainingArguments = kwargs.get("args")
         if training_args.fp8:
             configure_fp8_environment(training_args)
             if getattr(training_args, "fp8_backend", "auto") == "te":
@@ -88,7 +87,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
 
             self.compute_loss_func = dft_loss_func
 
         if training_args.fp8 and hasattr(self, "accelerator"):  # verify FP8 status after trainer initialization
             verify_fp8_status(self.accelerator, training_args)
 
     @override
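`verify_fp8_status` runs after the trainer (and therefore the `Accelerator`) exists, so a silently disabled FP8 path is surfaced instead of ignored. A hedged sketch of what such a post-initialization check can look like; this is not the repository's implementation, only an illustration of the idea:

from accelerate import Accelerator


def check_fp8_enabled(accelerator: Accelerator) -> None:
    # Accelerator.mixed_precision reports the precision mode actually in effect.
    if accelerator.mixed_precision != "fp8":
        raise RuntimeError(
            f"FP8 training was requested, but the accelerator is running with "
            f"mixed_precision={accelerator.mixed_precision!r}."
        )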