[model] set mm_projectors for omni models (#10378)

This commit is contained in:
Kingsley
2026-04-10 18:12:57 +08:00
committed by GitHub
parent fa09c01c36
commit c109c061e5
4 changed files with 44 additions and 33 deletions

View File

@@ -125,7 +125,7 @@ def _setup_freeze_tuning(
model_type = getattr(model.config, "model_type", None)
if not finetuning_args.freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:
trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)
trainable_layers.extend(COMPOSITE_MODELS[model_type].projector_keys)
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
for name, param in model.named_parameters():

View File

@@ -45,7 +45,7 @@ def apply_liger_kernel(
from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
elif model_type == "gemma3_text":
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
elif model_type == "glm4":
elif model_type in ["glm", "glm4"]: # for glm4-9b, glm4-32B respectively
from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
elif model_type == "glm4v":
from liger_kernel.transformers import apply_liger_kernel_to_glm4v as apply_liger_kernel

View File

@@ -35,7 +35,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
forbidden_modules.add("output")
if model_type in COMPOSITE_MODELS:
forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key)
forbidden_modules.update(COMPOSITE_MODELS[model_type].projector_keys)
if freeze_vision_tower and model_type in COMPOSITE_MODELS:
forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys)

View File

@@ -39,16 +39,22 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
@dataclass
class CompositeModel:
model_type: str
projector_key: str
projector_keys: list[str]
vision_model_keys: list[str]
language_model_keys: list[str]
lora_conflict_keys: list[str]
def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
for key in self.projector_key.split("."):
module = getattr(module, key)
return module
def get_projectors(self, module: "torch.nn.Module") -> list["torch.nn.Module"]:
mm_projectors: list[torch.nn.Module] = []
for projector_key in self.projector_keys:
mm_projector = module
for key in projector_key.split("."):
mm_projector = getattr(mm_projector, key)
mm_projectors.append(mm_projector)
return mm_projectors
COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
@@ -56,7 +62,7 @@ COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
def _register_composite_model(
model_type: str,
projector_key: Optional[str] = None,
projector_keys: list[str] | None = None,
vision_model_keys: Optional[list[str]] = None,
language_model_keys: Optional[list[str]] = None,
lora_conflict_keys: Optional[list[str]] = None,
@@ -65,7 +71,7 @@ def _register_composite_model(
Args:
model_type: model type
projector_key: multi_modal_projector
projector_keys: multi_modal_projector
vision_model_keys: vision_tower
language_model_keys: language_model
lora_conflict_keys: None
@@ -73,7 +79,7 @@ def _register_composite_model(
"""
COMPOSITE_MODELS[model_type] = CompositeModel(
model_type=model_type,
projector_key=projector_key or "multi_modal_projector",
projector_keys=projector_keys or ["multi_modal_projector"],
vision_model_keys=vision_model_keys or ["vision_tower"],
language_model_keys=language_model_keys or ["language_model", "lm_head"],
lora_conflict_keys=lora_conflict_keys or [],
@@ -136,12 +142,16 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
if getattr(model, "quantization_method", None):
model_type = getattr(model.config, "model_type", None)
if model_type in COMPOSITE_MODELS:
mm_projector = COMPOSITE_MODELS[model_type].get_projector(model)
mm_projectors = COMPOSITE_MODELS[model_type].get_projectors(model)
else:
return
logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
logger.info_rank0(
f"Casting multimodal projector outputs in {model_args.compute_dtype}: "
f"{COMPOSITE_MODELS[model_type].projector_keys}."
)
for mm_projector in mm_projectors:
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
def configure_visual_model(config: "PretrainedConfig") -> None:
@@ -166,9 +176,9 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
forbidden_modules.update(vision_model_keys)
if finetuning_args.freeze_multi_modal_projector:
projector_key = COMPOSITE_MODELS[model_type].projector_key
logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.")
forbidden_modules.add(projector_key)
projector_keys = COMPOSITE_MODELS[model_type].projector_keys
logger.info_rank0(f"Set multi model projector not trainable: {projector_keys}.")
forbidden_modules.update(projector_keys)
if finetuning_args.freeze_language_model:
language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys
@@ -200,7 +210,7 @@ def patch_target_modules(
_register_composite_model(
model_type="dots_ocr",
projector_key="vision_tower.merger",
projector_keys=["vision_tower.merger"],
vision_model_keys=["vision_tower"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["merger"],
@@ -221,6 +231,7 @@ _register_composite_model(
_register_composite_model(
model_type="gemma4",
projector_keys=["embed_vision", "embed_audio"],
vision_model_keys=["vision_tower", "audio_tower"],
lora_conflict_keys=["per_layer_projection_norm"],
)
@@ -229,7 +240,7 @@ _register_composite_model(
# copied from qwen2vl
_register_composite_model(
model_type="glm4v",
projector_key="visual.merger",
projector_keys=["visual.merger"],
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -238,7 +249,7 @@ _register_composite_model(
_register_composite_model(
model_type="glm4v_moe",
projector_key="visual.merger",
projector_keys=["visual.merger"],
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -247,7 +258,7 @@ _register_composite_model(
_register_composite_model(
model_type="glm_ocr",
projector_key="visual.merger",
projector_keys=["visual.merger"],
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -264,7 +275,7 @@ _register_composite_model(
_register_composite_model(
model_type="Keye",
projector_key="mlp_AR",
projector_keys=["mlp_AR"],
vision_model_keys=["visual.vision_model.patch_embedding", "visual.vision_model.encoder"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["patch_embedding"],
@@ -299,7 +310,7 @@ _register_composite_model(
_register_composite_model(
model_type="minicpmv",
projector_key="resampler",
projector_keys=["resampler"],
vision_model_keys=["vpm"],
language_model_keys=["llm"],
)
@@ -307,7 +318,7 @@ _register_composite_model(
_register_composite_model(
model_type="minicpmo",
projector_key="resampler",
projector_keys=["resampler"],
vision_model_keys=["vpm", "apm", "audio_avg_pooler", "audio_projection_layer", "tts"],
language_model_keys=["llm"],
lora_conflict_keys=["audio_projection_layer"],
@@ -316,7 +327,7 @@ _register_composite_model(
_register_composite_model(
model_type="mistral3",
projector_key="model.multi_modal_projector",
projector_keys=["model.multi_modal_projector"],
)
@@ -339,7 +350,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen2_5_omni_thinker",
projector_key="visual.merger",
projector_keys=["visual.merger", "audio_tower.proj"],
vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
language_model_keys=["model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -348,7 +359,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen2_vl",
projector_key="visual.merger",
projector_keys=["visual.merger"],
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -357,7 +368,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen2_5_vl",
projector_key="visual.merger",
projector_keys=["visual.merger"],
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -366,7 +377,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_vl",
projector_key="visual.merger",
projector_keys=["visual.merger"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -375,7 +386,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_vl_moe",
projector_key="visual.merger",
projector_keys=["visual.merger"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -384,7 +395,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_omni_moe_thinker",
projector_key="visual.merger",
projector_keys=["visual.merger", "audio_tower.proj"],
vision_model_keys=[
"visual.pos_embed",
"visual.patch_embed",
@@ -399,7 +410,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_5",
projector_key="model.visual.merger",
projector_keys=["model.visual.merger"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],
@@ -408,7 +419,7 @@ _register_composite_model(
_register_composite_model(
model_type="qwen3_5_moe",
projector_key="model.visual.merger",
projector_keys=["model.visual.merger"],
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model", "lm_head"],
lora_conflict_keys=["patch_embed"],