fix full/freeze tuning for mllm

hiyouga
2024-05-27 20:37:57 +08:00
parent 838f2fb3e4
commit 08564838bd
7 changed files with 76 additions and 61 deletions

src/llamafactory/model/utils/misc.py

@@ -1,9 +1,6 @@
 from typing import TYPE_CHECKING, List
 
-import torch
-
 from ...extras.logging import get_logger
-from .quantization import QuantizationMethod
 
 
 if TYPE_CHECKING:
@@ -13,29 +10,28 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
+def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
     r"""
     Finds all available modules to apply lora or galore.
     """
-    quantization_method = getattr(model, "quantization_method", None)
-    if quantization_method is None:
-        linear_cls = torch.nn.Linear
-    elif quantization_method == QuantizationMethod.BITS_AND_BYTES:
-        import bitsandbytes as bnb
-
-        linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
-    else:
-        raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
-
-    output_layer_names = ["lm_head"]
+    forbidden_modules = {"lm_head"}
+
     if model.config.model_type == "chatglm":
-        output_layer_names.append("output_layer")
+        forbidden_modules.add("output_layer")
     elif model.config.model_type == "internlm2":
-        output_layer_names.append("output")
+        forbidden_modules.add("output")
+    elif model.config.model_type in ["llava", "paligemma"]:
+        forbidden_modules.add("multi_modal_projector")
+
+    if freeze_vision_tower:
+        forbidden_modules.add("vision_tower")
 
     module_names = set()
     for name, module in model.named_modules():
-        if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
+        if any(forbidden_module in name for forbidden_module in forbidden_modules):
+            continue
+
+        if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
             module_names.add(name.split(".")[-1])
 
     logger.info("Found linear modules: {}".format(",".join(module_names)))
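
In short: instead of resolving one concrete linear class per quantization backend and testing isinstance, the rewritten function matches modules by class name and skips any module whose qualified name contains a forbidden substring, which is what lets it exclude the vision tower during full/freeze tuning of multimodal models. A minimal, runnable sketch of that selection rule (the toy model below is illustrative, not from the repository):

# Sketch of the new selection rule in find_all_linear_modules: walk
# named_modules(), skip any module whose qualified name contains a
# forbidden substring, and match linear layers by class name. The toy
# modules are illustrative only.
import torch


class ToyVisionTower(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)


class ToyMultimodalModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = ToyVisionTower()  # skipped when freeze_vision_tower=True
        self.multi_modal_projector = torch.nn.Linear(8, 8)  # always skipped for llava/paligemma
        self.q_proj = torch.nn.Linear(8, 8)  # language-model projection, kept
        self.lm_head = torch.nn.Linear(8, 8)  # output layer, always skipped


forbidden_modules = {"lm_head", "multi_modal_projector", "vision_tower"}
module_names = set()
for name, module in ToyMultimodalModel().named_modules():
    if any(forbidden in name for forbidden in forbidden_modules):
        continue
    if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
        module_names.add(name.split(".")[-1])

print(module_names)  # {'q_proj'}

Matching on the class-name string still catches quantized wrappers such as bnb.nn.Linear4bit and Linear8bitLt without importing bitsandbytes, which is why the first hunk can drop the torch and QuantizationMethod imports from this module.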

src/llamafactory/model/utils/quantization.py

@@ -35,6 +35,8 @@ class QuantizationMethod(str, Enum):
     AWQ = "awq"
     AQLM = "aqlm"
     QUANTO = "quanto"
+    EETQ = "eetq"
+    HQQ = "hqq"
 
 
 def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
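
For context, this enum tracks the quantization backends known to transformers; EETQ and HQQ are the two members added here. A sketch of the full enum after the change: the members above AWQ fall outside the hunk, so BITS_AND_BYTES is taken from its use in the misc.py hunk and GPTQ is an assumption.

# Sketch of QuantizationMethod after this commit; only EETQ and HQQ are new.
from enum import Enum


class QuantizationMethod(str, Enum):
    BITS_AND_BYTES = "bitsandbytes"  # referenced elsewhere in this diff
    GPTQ = "gptq"                    # assumed member, not shown in the hunk
    AWQ = "awq"
    AQLM = "aqlm"
    QUANTO = "quanto"
    EETQ = "eetq"  # added: Easy & Efficient Quantization for Transformers
    HQQ = "hqq"    # added: Half-Quadratic Quantization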

src/llamafactory/model/utils/visual.py

@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Tuple, List
+from typing import TYPE_CHECKING, Tuple
 
 import torch
 import transformers.models
@@ -82,8 +82,3 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
     if getattr(config, "is_yi_vl_derived_model", None):
         logger.info("Detected Yi-VL model, applying projector patch.")
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
-
-
-def filter_vision_tower_linear(target_modules: List[str]) -> str:
-    target_modules = f"^(?!.*vision_tower).*(?:{'|'.join(target_modules)}).*"
-    return target_modules
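
The deleted helper excluded vision-tower weights by wrapping the LoRA target list in a negative-lookahead regex; this commit replaces it with the forbidden_modules set in find_all_linear_modules. A small sketch of what the old pattern matched (the module names below are illustrative):

# Demonstrates the regex built by the removed filter_vision_tower_linear:
# match any module name containing one of the target suffixes, unless the
# name mentions "vision_tower" anywhere.
import re

target_modules = ["q_proj", "v_proj"]
pattern = f"^(?!.*vision_tower).*(?:{'|'.join(target_modules)}).*"

print(bool(re.match(pattern, "language_model.model.layers.0.self_attn.q_proj")))  # True
print(bool(re.match(pattern, "vision_tower.encoder.layers.0.self_attn.q_proj")))  # False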