From 2baf8bf03d475700a73f4066fe40b3cbc4c070c8 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Fri, 14 Feb 2025 21:38:43 +0800
Subject: [PATCH] [misc] fix lora regex (#6944)

* fix lora regex

* fix

Former-commit-id: 1ada3ae5a3a14057341540c6d6ba985adf95f348
---
 src/llamafactory/data/mm_plugin.py           |  2 +-
 src/llamafactory/extras/constants.py         |  2 +-
 src/llamafactory/model/adapter.py            |  2 +-
 src/llamafactory/model/model_utils/misc.py   |  2 +-
 src/llamafactory/model/model_utils/visual.py | 55 ++++++++++----------
 5 files changed, 32 insertions(+), 31 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 8d69b5a2..781130c4 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -1147,7 +1147,7 @@ class Qwen2vlPlugin(BasePlugin):
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         fps_per_video = mm_inputs.pop("fps_per_video", [])
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and fps_per_video:
+        if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
             mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
 
         return mm_inputs
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 520d3958..74369bb9 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -1210,7 +1210,7 @@ register_model_group(
             DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-o-2_6",
         },
     },
-    template="minicpm_v",
+    template="minicpm_o",
     multimodal=True,
 )
 
diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index 2602d5a3..399500e0 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -201,7 +201,7 @@ def _setup_lora_tuning(
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
 
-        target_modules = patch_target_modules(model.config, finetuning_args, target_modules)
+        target_modules = patch_target_modules(model, finetuning_args, target_modules)
 
         if (
             finetuning_args.use_dora
diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py
index e5f8ce5f..fc777ecb 100644
--- a/src/llamafactory/model/model_utils/misc.py
+++ b/src/llamafactory/model/model_utils/misc.py
@@ -77,7 +77,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
         ):
             module_names.append(name)
 
-    logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
+    logger.info_rank0("Apply lora to layers: {}.".format(",".join(map(str, trainable_layer_ids))))
     return module_names
 
 
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 57b6a31a..4a80a4e7 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple
 
 import torch
 import transformers
@@ -42,6 +42,7 @@ class CompositeModel:
     projector_key: str
     vision_model_keys: List[str]
     language_model_keys: List[str]
+    lora_conflict_keys: List[str]
 
     def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
         for key in self.projector_key.split("."):
@@ -58,15 +59,14 @@ def _register_composite_model(
     projector_key: Optional[str] = None,
     vision_model_keys: Optional[List[str]] = None,
     language_model_keys: Optional[List[str]] = None,
+    lora_conflict_keys: Optional[List[str]] = None,
 ):
-    projector_key = projector_key or "multi_modal_projector"
-    vision_model_keys = vision_model_keys or ["vision_tower"]
-    language_model_keys = language_model_keys or ["language_model"]
     COMPOSITE_MODELS[model_type] = CompositeModel(
         model_type=model_type,
-        projector_key=projector_key,
-        vision_model_keys=vision_model_keys,
-        language_model_keys=language_model_keys,
+        projector_key=projector_key or "multi_modal_projector",
+        vision_model_keys=vision_model_keys or ["vision_tower"],
+        language_model_keys=language_model_keys or ["language_model"],
+        lora_conflict_keys=lora_conflict_keys or [],
     )
 
 
@@ -210,29 +210,25 @@ def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "P
 
 
 def patch_target_modules(
-    config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
-) -> Union[str, List[str]]:
+    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
+) -> List[str]:
     r"""
     Freezes vision tower for VLM LoRA tuning.
     """
-    model_type = getattr(config, "model_type", None)
-    vit_model_type = getattr(getattr(config, "vision_config", None), "model_type", None)
-    if finetuning_args.freeze_vision_tower:
-        if model_type in COMPOSITE_MODELS:
-            vision_model_keys = COMPOSITE_MODELS[model_type].vision_model_keys
-            logger.info_rank0(f"Set vision model not trainable: {vision_model_keys}.")
-            vision_model_keys = "|".join(vision_model_keys)
-            target_modules = "|".join(target_modules)
-            return f"^(?!.*{vision_model_keys}).*(?:{target_modules}).*"
-        else:
-            return target_modules
+    model_type = getattr(model.config, "model_type", None)
+    if model_type in COMPOSITE_MODELS:
+        forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
+        forbidden_modules.update(COMPOSITE_MODELS[model_type].lora_conflict_keys)
+        module_names = []
+        for name, _ in model.named_modules():
+            if any(target_module in name for target_module in target_modules) and not any(
+                forbidden_module in name for forbidden_module in forbidden_modules
+            ):
+                module_names.append(name)
+
+        return module_names
     else:
-        if model_type == "qwen2_vl":  # avoid attaching lora to Conv3D layer
-            return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
-        elif vit_model_type == "pixtral":
-            return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
-        else:
-            return target_modules
+        return target_modules
 
 
 _register_composite_model(
@@ -252,6 +248,7 @@
 
 _register_composite_model(
     model_type="minicpmv",
+    projector_key="resampler",
     vision_model_keys=["vpm"],
     language_model_keys=["llm"],
 )
@@ -259,8 +256,10 @@
 
 _register_composite_model(
     model_type="minicpmo",
-    vision_model_keys=["vpm", "apm", "resampler", "tts"],
+    projector_key="resampler",
+    vision_model_keys=["vpm", "apm", "audio_avg_pooler", "audio_projection_layer", "tts"],
     language_model_keys=["llm"],
+    lora_conflict_keys=["audio_projection_layer"],
 )
 
 
@@ -291,6 +290,7 @@ _register_composite_model(
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
 )
 
 
@@ -299,4 +299,5 @@ _register_composite_model(
     projector_key="visual.merger",
    vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
 )