diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index 4204c6a4..65926197 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -92,7 +92,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen if getattr(model, "quantization_method", None): model_type = getattr(model.config, "model_type", None) - if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]: + if model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]: mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector") elif model_type == "qwen2_vl": mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger") @@ -108,15 +108,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None: Patches VLMs before loading them. """ model_type = getattr(config, "model_type", None) - if model_type in [ - "llava", - "llava_next", - "llava_next_video", - "mllama", - "paligemma", - "pixtral", - "video_llava", - ]: # required for ds zero3 and valuehead models + if model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]: + # required for ds zero3 and valuehead models setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) if getattr(config, "is_yi_vl_derived_model", None): @@ -130,13 +123,20 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni """ model_type = getattr(config, "model_type", None) forbidden_modules = set() - if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]: + if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: if finetuning_args.freeze_vision_tower: forbidden_modules.add("vision_tower") if finetuning_args.train_mm_proj_only: forbidden_modules.add("language_model") + elif model_type == "mllama": + if finetuning_args.freeze_vision_tower: + forbidden_modules.add("vision_model") + + if finetuning_args.train_mm_proj_only: + forbidden_modules.add("language_model") + elif model_type == "qwen2_vl": if finetuning_args.freeze_vision_tower: forbidden_modules.add("visual") @@ -189,8 +189,9 @@ def patch_target_modules( Freezes vision tower for VLM LoRA tuning. """ model_type = getattr(config, "model_type", None) + vit_model_type = getattr(getattr(config, "vision_config", None), "model_type", None) if finetuning_args.freeze_vision_tower: - if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]: + if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) elif model_type == "mllama": return "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules)) @@ -201,7 +202,7 @@ def patch_target_modules( else: if model_type == "qwen2_vl": return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules)) - elif model_type == "pixtral": + elif vit_model_type == "pixtral": return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules)) else: return target_modules