diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index bc3ef676..7007b5a2 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -186,12 +186,10 @@ def patch_target_modules( """ model_type = getattr(config, "model_type", None) if finetuning_args.freeze_vision_tower: - if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: + if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]: return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) elif model_type == "qwen2_vl": return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules)) - elif model_type == "pixtral": - return "^(?!.*vision_encoder).*(?:{}).*".format("|".join(target_modules)) else: return target_modules else: