diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index 708db9429..a5850703d 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -125,7 +125,7 @@ def _setup_freeze_tuning(
 
     model_type = getattr(model.config, "model_type", None)
     if not finetuning_args.freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:
-        trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)
+        trainable_layers.extend(COMPOSITE_MODELS[model_type].projector_keys)
 
     forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
diff --git a/src/llamafactory/model/model_utils/liger_kernel.py b/src/llamafactory/model/model_utils/liger_kernel.py
index c6dec9900..658960e42 100644
--- a/src/llamafactory/model/model_utils/liger_kernel.py
+++ b/src/llamafactory/model/model_utils/liger_kernel.py
@@ -45,7 +45,7 @@ def apply_liger_kernel(
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
     elif model_type == "gemma3_text":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
-    elif model_type == "glm4":
+    elif model_type in ["glm", "glm4"]:  # for glm4-9b, glm4-32B respectively
         from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
     elif model_type == "glm4v":
         from liger_kernel.transformers import apply_liger_kernel_to_glm4v as apply_liger_kernel
diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py
index b0249b47c..e0a62566c 100644
--- a/src/llamafactory/model/model_utils/misc.py
+++ b/src/llamafactory/model/model_utils/misc.py
@@ -35,7 +35,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         forbidden_modules.add("output")
 
     if model_type in COMPOSITE_MODELS:
-        forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key)
+        forbidden_modules.update(COMPOSITE_MODELS[model_type].projector_keys)
 
     if freeze_vision_tower and model_type in COMPOSITE_MODELS:
         forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys)
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index df3eaa20c..28128c918 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -39,16 +39,22 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
 
 @dataclass
 class CompositeModel:
     model_type: str
-    projector_key: str
+    projector_keys: list[str]
     vision_model_keys: list[str]
     language_model_keys: list[str]
     lora_conflict_keys: list[str]
 
-    def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
-        for key in self.projector_key.split("."):
-            module = getattr(module, key)
-        return module
+    def get_projectors(self, module: "torch.nn.Module") -> list["torch.nn.Module"]:
+        mm_projectors: list[torch.nn.Module] = []
+        for projector_key in self.projector_keys:
+            mm_projector = module
+            for key in projector_key.split("."):
+                mm_projector = getattr(mm_projector, key)
+
+            mm_projectors.append(mm_projector)
+
+        return mm_projectors
 
 
 COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
@@ -56,7 +62,7 @@ COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
 
 def _register_composite_model(
     model_type: str,
-    projector_key: Optional[str] = None,
+    projector_keys: Optional[list[str]] = None,
     vision_model_keys: Optional[list[str]] = None,
     language_model_keys: Optional[list[str]] = None,
     lora_conflict_keys: Optional[list[str]] = None,
@@ -65,7 +71,7 @@ def _register_composite_model(
 
     Args:
         model_type: model type
-        projector_key: multi_modal_projector
+        projector_keys: multi_modal_projector
         vision_model_keys: vision_tower
         language_model_keys: language_model
         lora_conflict_keys: None
@@ -73,7 +79,7 @@ def _register_composite_model(
     """
    COMPOSITE_MODELS[model_type] = CompositeModel(
         model_type=model_type,
-        projector_key=projector_key or "multi_modal_projector",
+        projector_keys=projector_keys or ["multi_modal_projector"],
         vision_model_keys=vision_model_keys or ["vision_tower"],
         language_model_keys=language_model_keys or ["language_model", "lm_head"],
         lora_conflict_keys=lora_conflict_keys or [],
@@ -136,12 +142,16 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
     if getattr(model, "quantization_method", None):
         model_type = getattr(model.config, "model_type", None)
         if model_type in COMPOSITE_MODELS:
-            mm_projector = COMPOSITE_MODELS[model_type].get_projector(model)
+            mm_projectors = COMPOSITE_MODELS[model_type].get_projectors(model)
         else:
             return
 
-        logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
-        mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
+        logger.info_rank0(
+            f"Casting multimodal projector outputs in {model_args.compute_dtype}: "
+            f"{COMPOSITE_MODELS[model_type].projector_keys}."
+        )
+        for mm_projector in mm_projectors:
+            mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
 
 
 def configure_visual_model(config: "PretrainedConfig") -> None:
@@ -166,9 +176,9 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
         forbidden_modules.update(vision_model_keys)
 
     if finetuning_args.freeze_multi_modal_projector:
-        projector_key = COMPOSITE_MODELS[model_type].projector_key
-        logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.")
-        forbidden_modules.add(projector_key)
+        projector_keys = COMPOSITE_MODELS[model_type].projector_keys
+        logger.info_rank0(f"Set multi model projector not trainable: {projector_keys}.")
+        forbidden_modules.update(projector_keys)
 
     if finetuning_args.freeze_language_model:
         language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys
@@ -200,7 +210,7 @@ def patch_target_modules(
 
 _register_composite_model(
     model_type="dots_ocr",
-    projector_key="vision_tower.merger",
+    projector_keys=["vision_tower.merger"],
     vision_model_keys=["vision_tower"],
     language_model_keys=["model", "lm_head"],
     lora_conflict_keys=["merger"],
@@ -221,6 +231,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="gemma4",
+    projector_keys=["embed_vision", "embed_audio"],
     vision_model_keys=["vision_tower", "audio_tower"],
     lora_conflict_keys=["per_layer_projection_norm"],
 )
@@ -229,7 +240,7 @@ _register_composite_model(
 # copied from qwen2vl
 _register_composite_model(
     model_type="glm4v",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -238,7 +249,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="glm4v_moe",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -247,7 +258,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="glm_ocr",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -264,7 +275,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="Keye",
-    projector_key="mlp_AR",
+    projector_keys=["mlp_AR"],
     vision_model_keys=["visual.vision_model.patch_embedding", "visual.vision_model.encoder"],
     language_model_keys=["model", "lm_head"],
     lora_conflict_keys=["patch_embedding"],
@@ -299,7 +310,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="minicpmv",
-    projector_key="resampler",
+    projector_keys=["resampler"],
     vision_model_keys=["vpm"],
     language_model_keys=["llm"],
 )
@@ -307,7 +318,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="minicpmo",
-    projector_key="resampler",
+    projector_keys=["resampler"],
     vision_model_keys=["vpm", "apm", "audio_avg_pooler", "audio_projection_layer", "tts"],
     language_model_keys=["llm"],
     lora_conflict_keys=["audio_projection_layer"],
@@ -316,7 +327,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="mistral3",
-    projector_key="model.multi_modal_projector",
+    projector_keys=["model.multi_modal_projector"],
 )
 
 
@@ -339,7 +350,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="qwen2_5_omni_thinker",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger", "audio_tower.proj"],
     vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
     language_model_keys=["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -348,7 +359,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="qwen2_vl",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -357,7 +368,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="qwen2_5_vl",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -366,7 +377,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="qwen3_vl",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -375,7 +386,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="qwen3_vl_moe",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -384,7 +395,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="qwen3_omni_moe_thinker",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger", "audio_tower.proj"],
     vision_model_keys=[
         "visual.pos_embed",
         "visual.patch_embed",
@@ -399,7 +410,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="qwen3_5",
-    projector_key="model.visual.merger",
+    projector_keys=["model.visual.merger"],
     vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -408,7 +419,7 @@ _register_composite_model(
 
 _register_composite_model(
     model_type="qwen3_5_moe",
-    projector_key="model.visual.merger",
+    projector_keys=["model.visual.merger"],
     vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],