From e80e50805cd501addada9cdbcd9e55d475b39b6e Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 15 May 2024 16:39:57 +0800 Subject: [PATCH] Update visual.py Former-commit-id: cbeef2aaea0577fd1929e7f156a2b8601b31814e --- src/llmtuner/model/utils/visual.py | 53 ++++++++++++++---------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/src/llmtuner/model/utils/visual.py b/src/llmtuner/model/utils/visual.py index 1f770861..9a5134ff 100644 --- a/src/llmtuner/model/utils/visual.py +++ b/src/llmtuner/model/utils/visual.py @@ -1,8 +1,8 @@ from typing import TYPE_CHECKING, Tuple import torch -import transformers -from torch import nn +import transformers.models +from transformers.activations import ACT2FN from ...extras.logging import get_logger @@ -16,9 +16,23 @@ if TYPE_CHECKING: logger = get_logger(__name__) -def configure_hidden_size(config: "PretrainedConfig") -> None: - if getattr(config, "model_type", None) == "llava": - setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) +class LlavaMultiModalProjector(torch.nn.Module): + def __init__(self, config: "LlavaConfig"): + super().__init__() + + self.linear_1 = torch.nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_2 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True) + self.linear_3 = torch.nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + self.linear_4 = torch.nn.LayerNorm(config.text_config.hidden_size, bias=True) + self.act = ACT2FN[config.projector_hidden_act] + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_3(hidden_states) + hidden_states = self.linear_4(hidden_states) + return hidden_states def autocast_projector_dtype( @@ -35,28 +49,9 @@ def autocast_projector_dtype( mm_projector.register_forward_hook(_mm_projector_forward_post_hook) -class LlavaMultiModalProjectorYiVL(nn.Module): - def __init__(self, config: "LlavaConfig"): - super().__init__() - self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True) - self.linear_2 = nn.LayerNorm(config.text_config.hidden_size, bias=True) - self.linear_3 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) - self.linear_4 = nn.LayerNorm(config.text_config.hidden_size, bias=True) - self.act = nn.GELU() +def configure_visual_model(config: "PretrainedConfig") -> None: + if getattr(config, "model_type", None) == "llava": + setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) - def forward(self, image_features): - dtype_ = self.linear_1.weight.dtype - hidden_states = self.linear_1(image_features) - hidden_states = self.linear_2(hidden_states) - hidden_states = self.act(hidden_states) - hidden_states = self.linear_3(hidden_states) - hidden_states = self.linear_4(hidden_states) - hidden_states = hidden_states.to(dtype_) - return hidden_states - - -def configure_visual(config: "PretrainedConfig", model_args: "ModelArguments") -> None: - logger = get_logger(__name__) - if model_args.visual_inputs and "Yi" in getattr(config.text_config, "_name_or_path", None): - transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorYiVL - logger.info("Patched Multimodal Projector for Yi-VL.") + if getattr(config, "is_yi_vl_derived_model", None): + transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjector