[misc] fix lora regex (#6944)

* fix lora regex

* fix

Former-commit-id: 1ada3ae5a3a14057341540c6d6ba985adf95f348
hoshi-hiyouga 2025-02-14 21:38:43 +08:00 committed by GitHub
parent 13e1b7ee2b
commit 2baf8bf03d
5 changed files with 32 additions and 31 deletions

View File

@@ -1147,7 +1147,7 @@ class Qwen2vlPlugin(BasePlugin):
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         fps_per_video = mm_inputs.pop("fps_per_video", [])
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and fps_per_video:
+        if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
             mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
         return mm_inputs
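
Note: the guarded line converts each video's sampling rate into the number of seconds spanned by one temporal grid step. A minimal sketch of that arithmetic, with an assumed temporal_patch_size of 2 (the real value is read from the Qwen2-VL image processor):

    # Assumed values for illustration only; the plugin reads temporal_patch_size
    # from the image processor and fps_per_video from the sampled videos.
    temporal_patch_size = 2
    fps_per_video = [2.0, 4.0]

    second_per_grid_ts = [temporal_patch_size / fps for fps in fps_per_video]
    print(second_per_grid_ts)  # [1.0, 0.5] -> seconds covered by each temporal grid step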

View File

@@ -1210,7 +1210,7 @@ register_model_group(
             DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-o-2_6",
         },
     },
-    template="minicpm_v",
+    template="minicpm_o",
     multimodal=True,
 )

View File

@@ -201,7 +201,7 @@ def _setup_lora_tuning(
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)

-        target_modules = patch_target_modules(model.config, finetuning_args, target_modules)
+        target_modules = patch_target_modules(model, finetuning_args, target_modules)

         if (
             finetuning_args.use_dora
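
Note: the call site now passes the whole model rather than model.config, because the rewritten patch_target_modules (below) walks model.named_modules(). Either return type works downstream: peft's LoraConfig accepts target_modules as a regex string or as a list of module names. A hedged sketch with placeholder values, not this repo's actual defaults:

    from peft import LoraConfig

    # Explicit module names, as the rewritten patch_target_modules now returns;
    # r and lora_alpha are illustrative hyperparameters.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["model.layers.0.self_attn.q_proj", "model.layers.0.self_attn.v_proj"],
    )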

View File

@@ -77,7 +77,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
         ):
             module_names.append(name)

-    logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
+    logger.info_rank0("Apply lora to layers: {}.".format(",".join(map(str, trainable_layer_ids))))
     return module_names

View File

@@ -16,7 +16,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple

 import torch
 import transformers
@@ -42,6 +42,7 @@ class CompositeModel:
     projector_key: str
     vision_model_keys: List[str]
     language_model_keys: List[str]
+    lora_conflict_keys: List[str]

     def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
         for key in self.projector_key.split("."):
@@ -58,15 +59,14 @@ def _register_composite_model(
     projector_key: Optional[str] = None,
     vision_model_keys: Optional[List[str]] = None,
     language_model_keys: Optional[List[str]] = None,
+    lora_conflict_keys: Optional[List[str]] = None,
 ):
-    projector_key = projector_key or "multi_modal_projector"
-    vision_model_keys = vision_model_keys or ["vision_tower"]
-    language_model_keys = language_model_keys or ["language_model"]
     COMPOSITE_MODELS[model_type] = CompositeModel(
         model_type=model_type,
-        projector_key=projector_key,
-        vision_model_keys=vision_model_keys,
-        language_model_keys=language_model_keys,
+        projector_key=projector_key or "multi_modal_projector",
+        vision_model_keys=vision_model_keys or ["vision_tower"],
+        language_model_keys=language_model_keys or ["language_model"],
+        lora_conflict_keys=lora_conflict_keys or [],
     )
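
Note: with the defaults now resolved inline, a registration that omits every optional argument yields the generic layout, and the new field falls back to an empty list. A quick sketch using a made-up model type:

    _register_composite_model(model_type="example_vlm")  # hypothetical entry

    composite = COMPOSITE_MODELS["example_vlm"]
    assert composite.projector_key == "multi_modal_projector"
    assert composite.vision_model_keys == ["vision_tower"]
    assert composite.language_model_keys == ["language_model"]
    assert composite.lora_conflict_keys == []  # new field: no conflicts by default
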
@@ -210,27 +210,23 @@ def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "P
 def patch_target_modules(
-    config: "PretrainedConfig", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
-) -> Union[str, List[str]]:
+    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
+) -> List[str]:
     r"""
     Freezes vision tower for VLM LoRA tuning.
     """
-    model_type = getattr(config, "model_type", None)
-    vit_model_type = getattr(getattr(config, "vision_config", None), "model_type", None)
-    if finetuning_args.freeze_vision_tower:
-        if model_type in COMPOSITE_MODELS:
-            vision_model_keys = COMPOSITE_MODELS[model_type].vision_model_keys
-            logger.info_rank0(f"Set vision model not trainable: {vision_model_keys}.")
-            vision_model_keys = "|".join(vision_model_keys)
-            target_modules = "|".join(target_modules)
-            return f"^(?!.*{vision_model_keys}).*(?:{target_modules}).*"
-        else:
-            return target_modules
-    else:
-        if model_type == "qwen2_vl":
-            return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
-        elif vit_model_type == "pixtral":
-            return "^(?!.*patch_conv).*(?:{}).*".format("|".join(target_modules))
-        else:
-            return target_modules
+    model_type = getattr(model.config, "model_type", None)
+    if model_type in COMPOSITE_MODELS:
+        forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
+        forbidden_modules.update(COMPOSITE_MODELS[model_type].lora_conflict_keys)
+        module_names = []
+        for name, _ in model.named_modules():
+            if any(target_module in name for target_module in target_modules) and not any(
+                forbidden_module in name for forbidden_module in forbidden_modules
+            ):
+                module_names.append(name)
+
+        # avoid attaching lora to Conv3D layer
+        return module_names
+    else:
+        return target_modules
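
Note: this hunk is the fix the commit title refers to. The old version built a negative-lookahead pattern like ^(?!.*vpm|apm|...).*(?:q_proj|...).*; joining keys with "|" without grouping lets the alternation bind inside the lookahead, so only the first key keeps its .* prefix and the remaining keys only match at the start of the name, while any key that does match anywhere in a path freezes that whole subtree. The rewrite drops regexes entirely: it enumerates model.named_modules() and keeps a name only if it contains a target substring and no forbidden key. A standalone sketch of that filter with made-up module paths:

    # Illustrative module paths; a real model supplies these via named_modules().
    module_paths = [
        "language_model.layers.0.self_attn.q_proj",
        "vision_tower.blocks.0.attn.qkv",
        "multi_modal_projector.linear_1",
    ]
    target_modules = ["q_proj", "qkv", "linear_1"]
    forbidden_modules = {"vision_tower", "multi_modal_projector"}

    picked = [
        name
        for name in module_paths
        if any(t in name for t in target_modules) and not any(f in name for f in forbidden_modules)
    ]
    print(picked)  # ['language_model.layers.0.self_attn.q_proj']
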
@@ -252,6 +248,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="minicpmv",
+    projector_key="resampler",
     vision_model_keys=["vpm"],
     language_model_keys=["llm"],
 )
@@ -259,8 +256,10 @@ _register_composite_model(
 _register_composite_model(
     model_type="minicpmo",
-    vision_model_keys=["vpm", "apm", "resampler", "tts"],
+    projector_key="resampler",
+    vision_model_keys=["vpm", "apm", "audio_avg_pooler", "audio_projection_layer", "tts"],
     language_model_keys=["llm"],
+    lora_conflict_keys=["audio_projection_layer"],
 )
@@ -291,6 +290,7 @@ _register_composite_model(
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
 )
@@ -299,4 +299,5 @@ _register_composite_model(
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
 )
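
Note: the lora_conflict_keys on the two Qwen2-VL entries keep LoRA off the vision patch embedding, whose projection is a Conv3D (per the in-diff comment "avoid attaching lora to Conv3D layer") but whose name a generic target such as "proj" would otherwise match. A short hedged check of the filter against illustrative Qwen2-VL-style paths:

    # Hypothetical module paths; "patch_embed" in lora_conflict_keys excludes the
    # Conv3D projection even though its name contains "proj".
    names = ["visual.patch_embed.proj", "visual.blocks.0.attn.qkv"]
    target_modules = ["proj", "qkv"]
    forbidden_modules = {"patch_embed"}

    picked = [
        n for n in names
        if any(t in n for t in target_modules) and not any(f in n for f in forbidden_modules)
    ]
    print(picked)  # ['visual.blocks.0.attn.qkv'] -- the patch embedding is skipped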