mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-22 22:02:51 +08:00)
Merge pull request #3835 from BUAADreamer/main
fix some features in LLaVA-style training
Former-commit-id: 838f2fb3e423a0471ff2898f737401e92bbafe2b
This commit is contained in: 605e70d0e1
@@ -38,6 +38,20 @@
       "assistant_tag": "assistant"
     }
   },
+  "mllm_pt_demo": {
+    "hf_hub_url": "BUAADreamer/mllm_pt_demo",
+    "formatting": "sharegpt",
+    "columns": {
+      "messages": "messages",
+      "images": "images"
+    },
+    "tags": {
+      "role_tag": "role",
+      "content_tag": "content",
+      "user_tag": "user",
+      "assistant_tag": "assistant"
+    }
+  },
   "alpaca_en": {
     "hf_hub_url": "llamafactory/alpaca_en",
     "ms_hub_url": "llamafactory/alpaca_en"
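For reference, a record matching the column and tag mapping above would look roughly like the following. This is a sketch with made-up values, not an actual row from BUAADreamer/mllm_pt_demo; whether the raw text carries an explicit <image> placeholder depends on the dataset.

# Hypothetical record in the sharegpt format declared above: the
# "messages" column holds role/content turns, "images" the paths of
# the pictures referenced by the conversation.
sample = {
    "messages": [
        {"role": "user", "content": "<image>What is shown in this picture?"},
        {"role": "assistant", "content": "A dog playing in a park."},
    ],
    "images": ["images/0001.jpg"],
}

# The "tags" block maps dict keys (role_tag, content_tag) and expected
# role values (user_tag, assistant_tag) for the loader.
for turn in sample["messages"]:
    assert turn["role"] in ("user", "assistant")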
@@ -85,6 +85,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not to use multimodal LLM that accepts visual inputs."},
     )
+    tune_mm_proj: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to fine-tune only the mm_projector for MLLM."},
+    )
     moe_aux_loss_coef: Optional[float] = field(
         default=None,
         metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
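A field declared this way surfaces directly as a CLI flag through HfArgumentParser. A minimal sketch, with the dataclass trimmed to the two flags relevant here (DemoModelArguments is an illustrative stand-in for the real ModelArguments):

# Minimal sketch: the new flag as parsed by HfArgumentParser.
# Only the two fields from the diff are kept; the rest is trimmed.
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DemoModelArguments:
    visual_inputs: bool = field(
        default=False,
        metadata={"help": "Whether or not to use multimodal LLM that accepts visual inputs."},
    )
    tune_mm_proj: bool = field(
        default=False,
        metadata={"help": "Whether or not to fine-tune only the mm_projector for MLLM."},
    )


(args,) = HfArgumentParser(DemoModelArguments).parse_args_into_dataclasses(
    args=["--visual_inputs", "true", "--tune_mm_proj", "true"]
)
assert args.visual_inputs and args.tune_mm_proj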
@@ -10,6 +10,7 @@ from ..extras.logging import get_logger
 from .utils.misc import find_all_linear_modules, find_expanded_modules
 from .utils.quantization import QuantizationMethod
 from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
+from .utils.visual import filter_vision_tower_linear


 if TYPE_CHECKING:
@@ -58,6 +59,9 @@ def init_adapter(
     if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
         model.vision_tower.requires_grad_(False)

+    if model_args.visual_inputs and hasattr(model, "language_model") and model_args.tune_mm_proj:  # freeze language model if only tuning mm_proj
+        model.language_model.requires_grad_(False)
+
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
         num_layers = (
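The effect of the new branch is easiest to see on a toy module. A minimal sketch, where LlavaToy and its submodule names merely stand in for the real LLaVA-style model:

# Toy sketch of the freeze behaviour added above: with tune_mm_proj
# set, both the vision tower and the language model are frozen, so
# only the projector remains trainable.
import torch.nn as nn


class LlavaToy(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.vision_tower = nn.Linear(8, 8)
        self.multi_modal_projector = nn.Linear(8, 8)
        self.language_model = nn.Linear(8, 8)


model = LlavaToy()
model.vision_tower.requires_grad_(False)    # always frozen for visual inputs
model.language_model.requires_grad_(False)  # frozen only when tune_mm_proj=True

trainable = [name for name, param in model.named_parameters() if param.requires_grad]
assert all(name.startswith("multi_modal_projector") for name in trainable)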
@@ -180,6 +184,9 @@ def init_adapter(
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)

+        if model_args.visual_inputs:
+            target_modules = filter_vision_tower_linear(target_modules)
+
         if (
             finetuning_args.use_dora
             and getattr(model, "quantization_method", None) is not None
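Why the filter is needed: linear-module discovery walks the whole model, so on a multimodal checkpoint it also collects vision tower layers. The helper below is an illustrative stand-in for find_all_linear_modules, not the repository's implementation:

# Illustrative stand-in: collecting linear-layer name suffixes over the
# whole model also captures vision tower layers such as fc1.
from typing import List

import torch.nn as nn


def collect_linear_suffixes(model: nn.Module) -> List[str]:
    suffixes = set()
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            suffixes.add(name.split(".")[-1])
    return sorted(suffixes)


toy = nn.ModuleDict({
    "language_model": nn.ModuleDict({"q_proj": nn.Linear(8, 8)}),
    "vision_tower": nn.ModuleDict({"fc1": nn.Linear(8, 8)}),
})
print(collect_linear_suffixes(toy))  # ['fc1', 'q_proj'] -- fc1 lives in the vision tower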
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Tuple, List

 import torch
 import transformers.models
@@ -82,3 +82,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
     if getattr(config, "is_yi_vl_derived_model", None):
         logger.info("Detected Yi-VL model, applying projector patch.")
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
+
+
+def filter_vision_tower_linear(target_modules: List[str]) -> str:
+    target_modules = f"^(?!.*vision_tower).*(?:{'|'.join(target_modules)}).*"
+    return target_modules
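The function returns a single regex string rather than a list of names; PEFT full-matches a string target_modules against each module name, so the negative lookahead keeps LoRA off the vision tower. A quick check of the pattern, with illustrative module names:

# Quick check of the regex built by filter_vision_tower_linear:
# names containing "vision_tower" fail the negative lookahead.
import re

pattern = f"^(?!.*vision_tower).*(?:{'|'.join(['q_proj', 'v_proj'])}).*"

assert re.fullmatch(pattern, "language_model.model.layers.0.self_attn.q_proj")
assert re.fullmatch(pattern, "vision_tower.encoder.layers.0.self_attn.q_proj") is None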