Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-20 13:50:35 +08:00)
fix full/freeze tuning for mllm
@@ -1,9 +1,6 @@
 from typing import TYPE_CHECKING, List
 
-import torch
-
 from ...extras.logging import get_logger
-from .quantization import QuantizationMethod
 
 
 if TYPE_CHECKING:
@@ -13,29 +10,28 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)
 
 
-def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
+def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
     r"""
     Finds all available modules to apply lora or galore.
     """
-    quantization_method = getattr(model, "quantization_method", None)
-    if quantization_method is None:
-        linear_cls = torch.nn.Linear
-    elif quantization_method == QuantizationMethod.BITS_AND_BYTES:
-        import bitsandbytes as bnb
-
-        linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
-    else:
-        raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
-
-    output_layer_names = ["lm_head"]
+    forbidden_modules = {"lm_head"}
     if model.config.model_type == "chatglm":
-        output_layer_names.append("output_layer")
+        forbidden_modules.add("output_layer")
     elif model.config.model_type == "internlm2":
-        output_layer_names.append("output")
+        forbidden_modules.add("output")
+    elif model.config.model_type in ["llava", "paligemma"]:
+        forbidden_modules.add("multi_modal_projector")
+
+    if freeze_vision_tower:
+        forbidden_modules.add("vision_tower")
 
     module_names = set()
     for name, module in model.named_modules():
-        if isinstance(module, linear_cls) and not any(output_layer in name for output_layer in output_layer_names):
+        if any(forbidden_module in name for forbidden_module in forbidden_modules):
+            continue
+
+        if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
             module_names.add(name.split(".")[-1])
 
     logger.info("Found linear modules: {}".format(",".join(module_names)))
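A minimal usage sketch of the new signature; the checkpoint name and the AutoModelForVision2Seq loader are illustrative assumptions, not part of this commit:

# Illustrative only: list candidate LoRA/GaLore target modules for a
# LLaVA-style model while keeping the vision tower frozen.
# Uses find_all_linear_modules as defined in the hunk above.
from transformers import AutoModelForVision2Seq

model = AutoModelForVision2Seq.from_pretrained("llava-hf/llava-1.5-7b-hf")  # assumed checkpoint
target_modules = find_all_linear_modules(model, freeze_vision_tower=True)
# Expected to include language-model projections such as "q_proj" or "down_proj",
# but nothing under "vision_tower" or "multi_modal_projector".
print(target_modules)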
@@ -35,6 +35,8 @@ class QuantizationMethod(str, Enum):
     AWQ = "awq"
     AQLM = "aqlm"
     QUANTO = "quanto"
+    EETQ = "eetq"
+    HQQ = "hqq"
 
 
 def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
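Since QuantizationMethod mixes in str, its members compare equal to plain string values, which is what lets callers branch on quantization settings read off a loaded model. A standalone sketch (abbreviated copy of the enum, for illustration only):

from enum import Enum

class QuantizationMethod(str, Enum):
    # abbreviated for illustration; the full enum is shown in the hunk above
    BITS_AND_BYTES = "bitsandbytes"
    HQQ = "hqq"

# str mixin: members compare equal to their raw string values
assert QuantizationMethod.HQQ == "hqq"
assert QuantizationMethod("bitsandbytes") is QuantizationMethod.BITS_AND_BYTES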
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Tuple, List
+from typing import TYPE_CHECKING, Tuple
 
 import torch
 import transformers.models
@@ -82,8 +82,3 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
     if getattr(config, "is_yi_vl_derived_model", None):
         logger.info("Detected Yi-VL model, applying projector patch.")
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
-
-
-def filter_vision_tower_linear(target_modules: List[str]) -> str:
-    target_modules = f"^(?!.*vision_tower).*(?:{'|'.join(target_modules)}).*"
-    return target_modules
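For reference, the removed helper wrapped the target module names in a negative-lookahead pattern that kept vision-tower modules out of the LoRA targets. A standalone sketch of that behavior (module names below are illustrative):

import re

target_modules = ["q_proj", "v_proj"]
pattern = f"^(?!.*vision_tower).*(?:{'|'.join(target_modules)}).*"

# Matches a language-model projection...
print(bool(re.match(pattern, "language_model.layers.0.self_attn.q_proj")))        # True
# ...but rejects anything under the vision tower.
print(bool(re.match(pattern, "vision_tower.encoder.layers.0.self_attn.q_proj")))  # False

With this commit, the same freeze behavior is handled up front by the freeze_vision_tower flag in find_all_linear_modules, so the regex wrapper is no longer needed.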