From 88159688bb25bd8cce725357dd9d7bdd473dc9aa Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sun, 12 May 2024 00:02:49 +0800
Subject: [PATCH] fix llava config

Former-commit-id: b033232aeaa1890ec6946387608aad4779a7ba10
---
 src/llmtuner/model/adapter.py         | 3 +++
 src/llmtuner/model/loader.py          | 2 +-
 src/llmtuner/model/patcher.py         | 9 +++------
 src/llmtuner/model/utils/valuehead.py | 7 +------
 src/llmtuner/model/utils/visual.py    | 9 +++++++--
 5 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py
index d43e00f0..0ffb91c1 100644
--- a/src/llmtuner/model/adapter.py
+++ b/src/llmtuner/model/adapter.py
@@ -46,6 +46,9 @@ def init_adapter(
         if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
             model = model.float()
 
+        if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
+            model.vision_tower.requires_grad_(False)
+
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
         num_layers = (
diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index ead6178f..ea55de27 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -106,7 +106,7 @@ def load_model(
     """
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
-    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead)
+    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
 
     model = None
     lazy_load = False
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 31cba492..fd99bd3b 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -15,8 +15,8 @@ from .utils.longlora import configure_longlora
 from .utils.moe import add_z3_leaf_module, configure_moe
 from .utils.quantization import configure_quantization
 from .utils.rope import configure_rope
-from .utils.valuehead import configure_valuehead, prepare_valuehead_model
-from .utils.visual import autocast_projector_dtype
+from .utils.valuehead import prepare_valuehead_model
+from .utils.visual import autocast_projector_dtype, configure_hidden_size
 
 
 if TYPE_CHECKING:
@@ -40,7 +40,6 @@ def patch_config(
     model_args: "ModelArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
-    add_valuehead: bool,
 ) -> None:
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
@@ -50,9 +49,7 @@ def patch_config(
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
-
-    if add_valuehead:
-        configure_valuehead(config)
+    configure_hidden_size(config)
 
     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
diff --git a/src/llmtuner/model/utils/valuehead.py b/src/llmtuner/model/utils/valuehead.py
index a6180753..d813729e 100644
--- a/src/llmtuner/model/utils/valuehead.py
+++ b/src/llmtuner/model/utils/valuehead.py
@@ -8,7 +8,7 @@ from ...extras.logging import get_logger
 
 
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel
+    from transformers import PreTrainedModel
 
     from ...hparams import ModelArguments
 
@@ -16,11 +16,6 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def configure_valuehead(config: "PretrainedConfig") -> None:
-    if getattr(config, "model_type", None) == "llava":
-        setattr(config, "hidden_size", getattr(config.vision_config, "intermediate_size", None))
-
-
 def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.
diff --git a/src/llmtuner/model/utils/visual.py b/src/llmtuner/model/utils/visual.py
index cb51301b..b29a9ba5 100644
--- a/src/llmtuner/model/utils/visual.py
+++ b/src/llmtuner/model/utils/visual.py
@@ -6,7 +6,7 @@ from ...extras.logging import get_logger
 
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+    from transformers import PretrainedConfig, PreTrainedModel
 
     from ...hparams import ModelArguments
 
@@ -14,6 +14,11 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+def configure_hidden_size(config: "PretrainedConfig") -> None:
+    if getattr(config, "model_type", None) == "llava":
+        setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
+
+
 def autocast_projector_dtype(
     model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
 ) -> None:
@@ -22,7 +27,7 @@ def autocast_projector_dtype(
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)
 
-    if hasattr(model, mm_projector_name):
+    if hasattr(model, mm_projector_name) and getattr(model.config, "quantization_method", None):
         logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
         mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
         mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
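
Note (illustration, not part of the patch above): the core of the fix is that a LLaVA config nests the language-model settings under text_config, so the top-level hidden_size must be copied from there (the removed helper read vision_config.intermediate_size instead). A minimal standalone sketch of that behavior, assuming a transformers release that ships LlavaConfig; the LlavaConfig() instance below is illustrative only, whereas the real config comes from load_config() in loader.py:

    # Sketch only: assumes transformers>=4.36, which provides LlavaConfig.
    from transformers import LlavaConfig


    def configure_hidden_size(config) -> None:
        # Copy the language model's hidden size to the top level so components
        # that read config.hidden_size (e.g. a value head) see the right width.
        if getattr(config, "model_type", None) == "llava":
            setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))


    config = LlavaConfig()  # illustrative default config (CLIP vision tower + LLaMA text model)
    configure_hidden_size(config)
    assert config.hidden_size == config.text_config.hidden_size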