diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index 05a3e298..8b8c6fb8 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -23,7 +23,7 @@ from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
-from .model_utils.visual import get_forbidden_modules, patch_target_modules
+from .model_utils.visual import COMPOSITE_MODELS, get_forbidden_modules, patch_target_modules


 if TYPE_CHECKING:
@@ -100,7 +100,7 @@ def _setup_freeze_tuning(
             hidden_modules.add(name.split(".1.")[-1].split(".")[0])

         if re.search(r"\.\d+\.", name) is None:
-            non_hidden_modules.add(name.split(".")[-2])
+            non_hidden_modules.add(name.split(".")[-2])  # remove weight/bias

     trainable_layers = []
     for module_name in finetuning_args.freeze_trainable_modules:
@@ -121,6 +121,10 @@ def _setup_freeze_tuning(
             trainable_layers.append(module_name)

+    model_type = getattr(model.config, "model_type", None)
+    if not finetuning_args.freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:
+        trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)
+
     forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
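
Context for the new hunk in _setup_freeze_tuning: for composite (vision-language) models, the multi-modal projector is now added to the trainable layers during freeze tuning unless finetuning_args.freeze_multi_modal_projector is set. Below is a minimal, standalone sketch of that lookup, assuming COMPOSITE_MODELS in model_utils/visual.py maps a model_type string to an entry exposing a projector_key attribute; the entry class and registry contents here are illustrative, not the actual definitions.

    from dataclasses import dataclass


    @dataclass
    class CompositeModelEntry:
        # illustrative stand-in: only projector_key is used by the new hunk
        projector_key: str


    # hypothetical registry contents, for illustration only
    COMPOSITE_MODELS = {
        "llava": CompositeModelEntry(projector_key="multi_modal_projector"),
    }


    def projector_trainable_layers(model_type, freeze_multi_modal_projector):
        """Mirror the added logic: return the projector key to unfreeze, if any."""
        trainable_layers = []
        if not freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:
            trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)
        return trainable_layers


    # e.g. projector_trainable_layers("llava", False) -> ["multi_modal_projector"]

Because the appended key is later matched by substring against parameter names (any(trainable_layer in name ...)), every parameter under the projector module becomes trainable while the rest of the model stays frozen.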