mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-04-15 01:06:01 +08:00
[model] set mm_projectors for omni models (#10378)
This commit is contained in:
@@ -125,7 +125,7 @@ def _setup_freeze_tuning(
|
||||
|
||||
model_type = getattr(model.config, "model_type", None)
|
||||
if not finetuning_args.freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:
|
||||
trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)
|
||||
trainable_layers.extend(COMPOSITE_MODELS[model_type].projector_keys)
|
||||
|
||||
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
|
||||
for name, param in model.named_parameters():
|
||||
|
||||
@@ -45,7 +45,7 @@ def apply_liger_kernel(
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
|
||||
elif model_type == "gemma3_text":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
|
||||
elif model_type == "glm4":
|
||||
elif model_type in ["glm", "glm4"]: # for glm4-9b, glm4-32B respectively
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
|
||||
elif model_type == "glm4v":
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_glm4v as apply_liger_kernel
|
||||
|
||||
@@ -35,7 +35,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
|
||||
forbidden_modules.add("output")
|
||||
|
||||
if model_type in COMPOSITE_MODELS:
|
||||
forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key)
|
||||
forbidden_modules.update(COMPOSITE_MODELS[model_type].projector_keys)
|
||||
|
||||
if freeze_vision_tower and model_type in COMPOSITE_MODELS:
|
||||
forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys)
|
||||
|
||||
@@ -39,16 +39,22 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
|
||||
@dataclass
|
||||
class CompositeModel:
|
||||
model_type: str
|
||||
projector_key: str
|
||||
projector_keys: list[str]
|
||||
vision_model_keys: list[str]
|
||||
language_model_keys: list[str]
|
||||
lora_conflict_keys: list[str]
|
||||
|
||||
def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
|
||||
for key in self.projector_key.split("."):
|
||||
module = getattr(module, key)
|
||||
|
||||
return module
|
||||
def get_projectors(self, module: "torch.nn.Module") -> list["torch.nn.Module"]:
|
||||
mm_projectors: list[torch.nn.Module] = []
|
||||
for projector_key in self.projector_keys:
|
||||
mm_projector = module
|
||||
for key in projector_key.split("."):
|
||||
mm_projector = getattr(mm_projector, key)
|
||||
|
||||
mm_projectors.append(mm_projector)
|
||||
|
||||
return mm_projectors
|
||||
|
||||
|
||||
COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
|
||||
@@ -56,7 +62,7 @@ COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
|
||||
|
||||
def _register_composite_model(
|
||||
model_type: str,
|
||||
projector_key: Optional[str] = None,
|
||||
projector_keys: list[str] | None = None,
|
||||
vision_model_keys: Optional[list[str]] = None,
|
||||
language_model_keys: Optional[list[str]] = None,
|
||||
lora_conflict_keys: Optional[list[str]] = None,
|
||||
@@ -65,7 +71,7 @@ def _register_composite_model(
|
||||
|
||||
Args:
|
||||
model_type: model type
|
||||
projector_key: multi_modal_projector
|
||||
projector_keys: multi_modal_projector
|
||||
vision_model_keys: vision_tower
|
||||
language_model_keys: language_model
|
||||
lora_conflict_keys: None
|
||||
@@ -73,7 +79,7 @@ def _register_composite_model(
|
||||
"""
|
||||
COMPOSITE_MODELS[model_type] = CompositeModel(
|
||||
model_type=model_type,
|
||||
projector_key=projector_key or "multi_modal_projector",
|
||||
projector_keys=projector_keys or ["multi_modal_projector"],
|
||||
vision_model_keys=vision_model_keys or ["vision_tower"],
|
||||
language_model_keys=language_model_keys or ["language_model", "lm_head"],
|
||||
lora_conflict_keys=lora_conflict_keys or [],
|
||||
@@ -136,12 +142,16 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
|
||||
if getattr(model, "quantization_method", None):
|
||||
model_type = getattr(model.config, "model_type", None)
|
||||
if model_type in COMPOSITE_MODELS:
|
||||
mm_projector = COMPOSITE_MODELS[model_type].get_projector(model)
|
||||
mm_projectors = COMPOSITE_MODELS[model_type].get_projectors(model)
|
||||
else:
|
||||
return
|
||||
|
||||
logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
|
||||
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
|
||||
logger.info_rank0(
|
||||
f"Casting multimodal projector outputs in {model_args.compute_dtype}: "
|
||||
f"{COMPOSITE_MODELS[model_type].projector_keys}."
|
||||
)
|
||||
for mm_projector in mm_projectors:
|
||||
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
|
||||
|
||||
|
||||
def configure_visual_model(config: "PretrainedConfig") -> None:
|
||||
@@ -166,9 +176,9 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
|
||||
forbidden_modules.update(vision_model_keys)
|
||||
|
||||
if finetuning_args.freeze_multi_modal_projector:
|
||||
projector_key = COMPOSITE_MODELS[model_type].projector_key
|
||||
logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.")
|
||||
forbidden_modules.add(projector_key)
|
||||
projector_keys = COMPOSITE_MODELS[model_type].projector_keys
|
||||
logger.info_rank0(f"Set multi model projector not trainable: {projector_keys}.")
|
||||
forbidden_modules.update(projector_keys)
|
||||
|
||||
if finetuning_args.freeze_language_model:
|
||||
language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys
|
||||
@@ -200,7 +210,7 @@ def patch_target_modules(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="dots_ocr",
|
||||
projector_key="vision_tower.merger",
|
||||
projector_keys=["vision_tower.merger"],
|
||||
vision_model_keys=["vision_tower"],
|
||||
language_model_keys=["model", "lm_head"],
|
||||
lora_conflict_keys=["merger"],
|
||||
@@ -221,6 +231,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="gemma4",
|
||||
projector_keys=["embed_vision", "embed_audio"],
|
||||
vision_model_keys=["vision_tower", "audio_tower"],
|
||||
lora_conflict_keys=["per_layer_projection_norm"],
|
||||
)
|
||||
@@ -229,7 +240,7 @@ _register_composite_model(
|
||||
# copied from qwen2vl
|
||||
_register_composite_model(
|
||||
model_type="glm4v",
|
||||
projector_key="visual.merger",
|
||||
projector_keys=["visual.merger"],
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
@@ -238,7 +249,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="glm4v_moe",
|
||||
projector_key="visual.merger",
|
||||
projector_keys=["visual.merger"],
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
@@ -247,7 +258,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="glm_ocr",
|
||||
projector_key="visual.merger",
|
||||
projector_keys=["visual.merger"],
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
@@ -264,7 +275,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="Keye",
|
||||
projector_key="mlp_AR",
|
||||
projector_keys=["mlp_AR"],
|
||||
vision_model_keys=["visual.vision_model.patch_embedding", "visual.vision_model.encoder"],
|
||||
language_model_keys=["model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embedding"],
|
||||
@@ -299,7 +310,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="minicpmv",
|
||||
projector_key="resampler",
|
||||
projector_keys=["resampler"],
|
||||
vision_model_keys=["vpm"],
|
||||
language_model_keys=["llm"],
|
||||
)
|
||||
@@ -307,7 +318,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="minicpmo",
|
||||
projector_key="resampler",
|
||||
projector_keys=["resampler"],
|
||||
vision_model_keys=["vpm", "apm", "audio_avg_pooler", "audio_projection_layer", "tts"],
|
||||
language_model_keys=["llm"],
|
||||
lora_conflict_keys=["audio_projection_layer"],
|
||||
@@ -316,7 +327,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="mistral3",
|
||||
projector_key="model.multi_modal_projector",
|
||||
projector_keys=["model.multi_modal_projector"],
|
||||
)
|
||||
|
||||
|
||||
@@ -339,7 +350,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen2_5_omni_thinker",
|
||||
projector_key="visual.merger",
|
||||
projector_keys=["visual.merger", "audio_tower.proj"],
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
|
||||
language_model_keys=["model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
@@ -348,7 +359,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen2_vl",
|
||||
projector_key="visual.merger",
|
||||
projector_keys=["visual.merger"],
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
@@ -357,7 +368,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen2_5_vl",
|
||||
projector_key="visual.merger",
|
||||
projector_keys=["visual.merger"],
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
@@ -366,7 +377,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen3_vl",
|
||||
projector_key="visual.merger",
|
||||
projector_keys=["visual.merger"],
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
@@ -375,7 +386,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen3_vl_moe",
|
||||
projector_key="visual.merger",
|
||||
projector_keys=["visual.merger"],
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
@@ -384,7 +395,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen3_omni_moe_thinker",
|
||||
projector_key="visual.merger",
|
||||
projector_keys=["visual.merger", "audio_tower.proj"],
|
||||
vision_model_keys=[
|
||||
"visual.pos_embed",
|
||||
"visual.patch_embed",
|
||||
@@ -399,7 +410,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen3_5",
|
||||
projector_key="model.visual.merger",
|
||||
projector_keys=["model.visual.merger"],
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
@@ -408,7 +419,7 @@ _register_composite_model(
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen3_5_moe",
|
||||
projector_key="model.visual.merger",
|
||||
projector_keys=["model.visual.merger"],
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
|
||||
Reference in New Issue
Block a user