Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-22 22:02:51 +08:00)
add regex to tune only lm and mm_proj
Former-commit-id: 57eb13b75d8597d748e84d3549a0b08876b669db
This commit is contained in:
parent a6c2a2071d
commit 606240aec0
sites/paligemma-pt.yaml (new file, 49 lines)
@@ -0,0 +1,49 @@
# model
model_name_or_path: google/paligemma-3b-mix-448
visual_inputs: true
tune_mm_proj: true
#print_param_status: true

# method
stage: sft
do_train: true
finetuning_type: full

# ddp
ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z2_offload_config.json

# dataset
dataset: mllm_pt_demo
dataset_dir: data
template: gemma
cutoff_len: 2048
max_samples: 3
#val_size: 0.0001
overwrite_cache: true
preprocessing_num_workers: 16

# output
output_dir: saves/paligemma/full/sft_llava_pt_test
logging_steps: 1
save_steps: 50
plot_loss: true
overwrite_output_dir: true
#save_strategy: epoch
#save_total_limit: 2

# train
per_device_train_batch_size: 1
gradient_accumulation_steps: 16
learning_rate: 0.00001
num_train_epochs: 100
lr_scheduler_type: cosine
warmup_steps: 0.1
#bf16: true
pure_bf16: true

# eval
do_eval: false
#per_device_eval_batch_size: 1
#evaluation_strategy: steps
#eval_steps: 500
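
The flags above tie into the trainer changes further down in this commit: visual_inputs keeps the vision tower frozen, and tune_mm_proj additionally freezes the language model, so a finetuning_type: full run with this file effectively updates only the multi-modal projector. A minimal sketch of that rule (the trainable_parts helper is hypothetical, for illustration only):

# Illustrative only: mirrors the freezing rules added in init_adapter below.
# Requires PyYAML; trainable_parts is not part of LLaMA-Factory.
import yaml

def trainable_parts(cfg: dict) -> list:
    parts = ["vision_tower", "multi_modal_projector", "language_model"]
    if cfg.get("visual_inputs"):
        parts.remove("vision_tower")        # vision model is always frozen
        if cfg.get("tune_mm_proj"):
            parts.remove("language_model")  # only the projector keeps gradients
    return parts

with open("sites/paligemma-pt.yaml") as f:
    cfg = yaml.safe_load(f)

print(trainable_parts(cfg))  # expected: ['multi_modal_projector']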
sites/paligemma.yaml (new file, 49 lines)
@@ -0,0 +1,49 @@
# model
model_name_or_path: google/paligemma-3b-mix-448
visual_inputs: true
#print_param_status: true
use_fast_tokenizer: false

# method
stage: sft
do_train: true
finetuning_type: full

# ddp
ddp_timeout: 180000000
deepspeed: examples/deepspeed/ds_z2_offload_config.json

# dataset
dataset: mllm_demo
dataset_dir: data
template: gemma
cutoff_len: 2048
max_samples: 3
#val_size: 0.0001
overwrite_cache: true
preprocessing_num_workers: 16

# output
output_dir: saves/paligemma/full/sft_llava_1k
logging_steps: 1
save_steps: 50
plot_loss: true
overwrite_output_dir: true
#save_strategy: epoch
#save_total_limit: 2

# train
per_device_train_batch_size: 1
gradient_accumulation_steps: 16
learning_rate: 0.00001
num_train_epochs: 100
lr_scheduler_type: cosine
warmup_steps: 0.1
#bf16: true
pure_bf16: true

# eval
do_eval: false
#per_device_eval_batch_size: 1
#evaluation_strategy: steps
#eval_steps: 500
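
Compared with the projector-only file above, this config fine-tunes the full model (no tune_mm_proj) and sets use_fast_tokenizer: false. A rough sketch of what that flag amounts to, assuming it is forwarded to the use_fast argument when transformers loads the tokenizer:

# Sketch only: loads the slow (SentencePiece) tokenizer instead of the fast one.
# The checkpoint is gated, so this may require a logged-in Hugging Face token.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "google/paligemma-3b-mix-448",
    use_fast=False,
)
print(type(tokenizer).__name__)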
sites/paligemma_lora.yaml (new file, 40 lines)
@@ -0,0 +1,40 @@
### model
model_name_or_path: google/paligemma-3b-mix-448
visual_inputs: true
use_fast_tokenizer: false

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: q_proj,v_proj

### dataset
dataset: mllm_demo
template: gemma
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/paligemma/lora/sft_mllm
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 0.0001
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_steps: 0.1
fp16: true

### eval
val_size: 0.1
per_device_eval_batch_size: 1
evaluation_strategy: steps
eval_steps: 500
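
With visual_inputs: true, the comma-separated lora_target above is rewritten into a single regex that excludes the vision tower before the adapter is attached (see filter_vision_tower_linear added at the end of this commit). A small illustration, using assumed rank/alpha values and peft's support for a regex string as target_modules:

# Illustration only: rank/alpha are placeholder values, not taken from the config.
from peft import LoraConfig

targets = "q_proj,v_proj".split(",")
pattern = f"^(?!.*vision_tower).*(?:{'|'.join(targets)}).*"

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=pattern,  # a regex string keeps LoRA off vision_tower modules
    task_type="CAUSAL_LM",
)
print(lora_config.target_modules)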
@@ -10,6 +10,7 @@ from ..extras.logging import get_logger
 from .utils.misc import find_all_linear_modules, find_expanded_modules
 from .utils.quantization import QuantizationMethod
 from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
+from .utils.visual import filter_vision_tower_linear


 if TYPE_CHECKING:
@@ -58,6 +59,9 @@ def init_adapter(
     if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
         model.vision_tower.requires_grad_(False)

+    if model_args.visual_inputs and hasattr(model, "language_model") and model_args.tune_mm_proj:  # freeze language model if only tune mm_proj
+        model.language_model.requires_grad_(False)
+
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
         num_layers = (
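
A toy check of the freezing order above: with visual_inputs and tune_mm_proj set, only the multi-modal projector keeps requires_grad=True. The tiny module below merely mimics the attribute names of a LLaVA/PaliGemma-style model:

import torch.nn as nn

class ToyVLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(8, 8)
        self.multi_modal_projector = nn.Linear(8, 8)
        self.language_model = nn.Linear(8, 8)

model = ToyVLM()
model.vision_tower.requires_grad_(False)    # freeze vision model
model.language_model.requires_grad_(False)  # freeze language model (tune_mm_proj)

print([n for n, p in model.named_parameters() if p.requires_grad])
# ['multi_modal_projector.weight', 'multi_modal_projector.bias']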
@@ -180,6 +184,9 @@ def init_adapter(
         if finetuning_args.use_llama_pro:
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable)

+        if model_args.visual_inputs:
+            target_modules = filter_vision_tower_linear(target_modules)
+
         if (
             finetuning_args.use_dora
             and getattr(model, "quantization_method", None) is not None
@@ -163,11 +163,6 @@ def load_model(
     else:
         model.train()

-    if model_args.visual_inputs and model_args.tune_mm_proj:
-        lm_params = [param for name, param in model.named_parameters() if "language_model" in name]
-        for param in lm_params:
-            param.requires_grad_(False)
-
     trainable_params, all_param = count_parameters(model)
     if is_trainable:
         param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Tuple, List

 import torch
 import transformers.models
@@ -82,3 +82,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
     if getattr(config, "is_yi_vl_derived_model", None):
         logger.info("Detected Yi-VL model, applying projector patch.")
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL
+
+
+def filter_vision_tower_linear(target_modules: List[str]) -> str:
+    target_modules = f"^(?!.*vision_tower).*(?:{'|'.join(target_modules)}).*"
+    return target_modules
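
A quick sanity check of the pattern built by filter_vision_tower_linear: the negative lookahead rejects any module path containing vision_tower, while other modules that contain a target name still match (the paths below are made-up examples):

import re

targets = ["q_proj", "v_proj"]
pattern = f"^(?!.*vision_tower).*(?:{'|'.join(targets)}).*"  # same expression as above
print(pattern)  # ^(?!.*vision_tower).*(?:q_proj|v_proj).*

print(bool(re.match(pattern, "language_model.model.layers.0.self_attn.q_proj")))  # True
print(bool(re.match(pattern, "vision_tower.encoder.layers.0.self_attn.q_proj")))  # False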
|
Loading…
x
Reference in New Issue
Block a user