From f48d07cd6c5b4d9f8b7b43221b969e5cff73a77f Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 25 Nov 2024 19:43:42 +0800
Subject: [PATCH 1/2] fix #6136

Former-commit-id: 0516e556a71a22b8767b17734adb94eb127e7e6f
---
 src/llamafactory/model/model_utils/visual.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 04f1eae8..4204c6a4 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -112,6 +112,7 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
         "llava",
         "llava_next",
         "llava_next_video",
+        "mllama",
         "paligemma",
         "pixtral",
         "video_llava",

From a489f10986259217ab1fdc0093ea6159fb8e86d0 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 25 Nov 2024 20:06:06 +0800
Subject: [PATCH 2/2] fix visual patch

Former-commit-id: 75b586c31acf47d1bd28e04566ffd1d954e45596
---
 src/llamafactory/model/model_utils/visual.py | 27 ++++++++++++++-------------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 4204c6a4..65926197 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -92,7 +92,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
     if getattr(model, "quantization_method", None):
         model_type = getattr(model.config, "model_type", None)
-        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
+        if model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]:
             mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
         elif model_type == "qwen2_vl":
             mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
@@ -108,15 +108,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
     Patches VLMs before loading them.
     """
     model_type = getattr(config, "model_type", None)
-    if model_type in [
-        "llava",
-        "llava_next",
-        "llava_next_video",
-        "mllama",
-        "paligemma",
-        "pixtral",
-        "video_llava",
-    ]:  # required for ds zero3 and valuehead models
+    if model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]:
+        # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
 
     if getattr(config, "is_yi_vl_derived_model", None):
@@ -130,13 +123,20 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
     """
     model_type = getattr(config, "model_type", None)
     forbidden_modules = set()
-    if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
+    if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
         if finetuning_args.freeze_vision_tower:
             forbidden_modules.add("vision_tower")
 
         if finetuning_args.train_mm_proj_only:
             forbidden_modules.add("language_model")
 
+    elif model_type == "mllama":
+        if finetuning_args.freeze_vision_tower:
+            forbidden_modules.add("vision_model")
+
+        if finetuning_args.train_mm_proj_only:
+            forbidden_modules.add("language_model")
+
     elif model_type == "qwen2_vl":
         if finetuning_args.freeze_vision_tower:
             forbidden_modules.add("visual")
@@ -189,8 +189,9 @@ def patch_target_modules(
     Freezes vision tower for VLM LoRA tuning.
""" model_type = getattr(config, "model_type", None) + vit_model_type = getattr(getattr(config, "vision_config", None), "model_type", None) if finetuning_args.freeze_vision_tower: - if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]: + if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) elif model_type == "mllama": return "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules)) @@ -201,7 +202,7 @@ def patch_target_modules( else: if model_type == "qwen2_vl": return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules)) - elif model_type == "pixtral": + elif vit_model_type == "pixtral": return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules)) else: return target_modules