From 253752cccaf9fcc80989022d6a1f745c1b3e125f Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sat, 23 Nov 2024 18:34:15 +0000 Subject: [PATCH] add forbidden modules Former-commit-id: df477370dc67315effac1a8f48068c4e9c4067a5 --- src/llamafactory/model/model_utils/misc.py | 6 ++++-- src/llamafactory/model/model_utils/visual.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py index 52cf9eb3..5f4b747e 100644 --- a/src/llamafactory/model/model_utils/misc.py +++ b/src/llamafactory/model/model_utils/misc.py @@ -34,13 +34,15 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) forbidden_modules.add("output_layer") elif model_type == "internlm2": forbidden_modules.add("output") - elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: + elif model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]: forbidden_modules.add("multi_modal_projector") elif model_type == "qwen2_vl": forbidden_modules.add("merger") if freeze_vision_tower: - if model_type == "qwen2_vl": + if model_type == "mllama": + forbidden_modules.add("vision_model") + elif model_type == "qwen2_vl": forbidden_modules.add("visual") else: forbidden_modules.add("vision_tower") diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index e93c5dc6..04f1eae8 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -191,6 +191,8 @@ def patch_target_modules( if finetuning_args.freeze_vision_tower: if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]: return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) + elif model_type == "mllama": + return "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules)) elif model_type == "qwen2_vl": return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules)) else: