diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py
index 52cf9eb3..5f4b747e 100644
--- a/src/llamafactory/model/model_utils/misc.py
+++ b/src/llamafactory/model/model_utils/misc.py
@@ -34,13 +34,15 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         forbidden_modules.add("output_layer")
     elif model_type == "internlm2":
         forbidden_modules.add("output")
-    elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+    elif model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]:
         forbidden_modules.add("multi_modal_projector")
     elif model_type == "qwen2_vl":
         forbidden_modules.add("merger")
 
     if freeze_vision_tower:
-        if model_type == "qwen2_vl":
+        if model_type == "mllama":
+            forbidden_modules.add("vision_model")
+        elif model_type == "qwen2_vl":
             forbidden_modules.add("visual")
         else:
             forbidden_modules.add("vision_tower")
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index e93c5dc6..04f1eae8 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -191,6 +191,8 @@ def patch_target_modules(
     if finetuning_args.freeze_vision_tower:
         if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
             return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
+        elif model_type == "mllama":
+            return "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules))
         elif model_type == "qwen2_vl":
             return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
         else: