add regex of only tune lm and mm_proj

This commit is contained in:
BUAADreamer
2024-05-27 18:59:00 +08:00
parent 4bc7c10c00
commit 57eb13b75d
6 changed files with 151 additions and 6 deletions

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Tuple, List
import torch
import transformers.models
@@ -82,3 +82,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
if getattr(config, "is_yi_vl_derived_model", None):
logger.info("Detected Yi-VL model, applying projector patch.")
transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
def filter_vision_tower_linear(target_modules: List[str]) -> str:
target_modules = f"^(?!.*vision_tower).*(?:{'|'.join(target_modules)}).*"
return target_modules