fix llava config

Author: hiyouga
Date: 2024-05-12 00:02:49 +08:00
Commit: b033232aea (parent 5da097f406)
5 changed files with 15 additions and 15 deletions


@@ -6,7 +6,7 @@ from ...extras.logging import get_logger
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+    from transformers import PretrainedConfig, PreTrainedModel
     from ...hparams import ModelArguments
@@ -14,6 +14,11 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
+def configure_hidden_size(config: "PretrainedConfig") -> None:
+    if getattr(config, "model_type", None) == "llava":
+        setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
+
+
 def autocast_projector_dtype(
     model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
 ) -> None:
@@ -22,7 +27,7 @@ def autocast_projector_dtype(
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)
 
-    if hasattr(model, mm_projector_name):
+    if hasattr(model, mm_projector_name) and getattr(model.config, "quantization_method", None):
         logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
         mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
         mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
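
The first hunk adds configure_hidden_size, which mirrors the language model's hidden_size from the nested text_config onto the composite LLaVA config. Below is a minimal sketch of what the helper does when applied to a transformers LlavaConfig; the helper body is taken from the hunk above, while the surrounding driver code (the default LlavaConfig and the print check) is illustrative and assumes a transformers release from around the time of this commit, where LlavaConfig does not expose hidden_size at the top level.

    from transformers import LlavaConfig, PretrainedConfig


    def configure_hidden_size(config: "PretrainedConfig") -> None:
        # Copy the language model's hidden size onto the composite LLaVA config
        # so downstream code that reads config.hidden_size keeps working.
        if getattr(config, "model_type", None) == "llava":
            setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))


    config = LlavaConfig()  # vision/text sub-configs fall back to library defaults
    configure_hidden_size(config)
    print(config.hidden_size == config.text_config.hidden_size)  # True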
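
The last hunk narrows when the dtype-casting forward hook is attached to the multimodal projector: only when the backbone reports a quantization_method. The sketch below illustrates the gated hook pattern with hypothetical stand-ins (a plain Linear for multi_modal_projector, a hard-coded quantization_method string, and bfloat16 in place of model_args.compute_dtype); it is not the project's code, just the PyTorch mechanism the patch relies on.

    import torch


    def cast_outputs_hook(module: "torch.nn.Module", args: tuple, output: "torch.Tensor") -> "torch.Tensor":
        # Returning a value from a forward hook replaces the module's output;
        # bfloat16 stands in for the configured compute dtype.
        return output.to(torch.bfloat16)


    projector = torch.nn.Linear(8, 8)       # stand-in for the model's multi_modal_projector
    quantization_method = "bitsandbytes"    # stand-in for getattr(model.config, "quantization_method", None)

    # Mirror the patched condition: only quantized models get the casting hook.
    if quantization_method:
        projector.register_forward_hook(cast_outputs_hook)

    print(projector(torch.randn(2, 8)).dtype)  # torch.bfloat16 once the hook is registered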