Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)

Commit 8a09b1e732 (parent efd60f0306): initial-commit
Former-commit-id: aeb85f200bd824748008dae6047c2607dfcdf174
data/dataset_info.json
@@ -38,6 +38,20 @@
       "assistant_tag": "assistant"
     }
   },
+  "qwen2vl_demo": {
+    "file_name": "qwen2vl_demo.json",
+    "formatting": "sharegpt",
+    "columns": {
+      "messages": "messages",
+      "images": "images"
+    },
+    "tags": {
+      "role_tag": "role",
+      "content_tag": "content",
+      "user_tag": "user",
+      "assistant_tag": "assistant"
+    }
+  },
   "alpaca_en": {
     "hf_hub_url": "llamafactory/alpaca_en",
     "ms_hub_url": "llamafactory/alpaca_en"
data/qwen2vl_demo.json (new file, +140 lines)

[
  {
    "messages": [
      {
        "content": "<|image_pad|>Who are they?",
        "role": "user"
      },
      {
        "content": "They're Kane and Gretzka from Bayern Munich.",
        "role": "assistant"
      },
      {
        "content": "What are they doing?",
        "role": "user"
      },
      {
        "content": "They are celebrating on the soccer field.",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/1.jpg"
    ]
  },
  {
    "messages": [
      {
        "content": "<|image_pad|>Who is he?",
        "role": "user"
      },
      {
        "content": "He's Thomas Muller from Bayern Munich.",
        "role": "assistant"
      },
      {
        "content": "<|image_pad|>Why is he on the ground?",
        "role": "user"
      },
      {
        "content": "Because he's sliding on his knees to celebrate.",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/2.jpg","mllm_demo_data/2.jpg"
    ]
  },
  {
    "messages": [
      {
        "content": "<|image_pad|>Please describe this image",
        "role": "user"
      },
      {
        "content": "Chinese astronaut Gui Haichao is giving a speech.",
        "role": "assistant"
      },
      {
        "content": "What has he accomplished?",
        "role": "user"
      },
      {
        "content": "He was appointed to be a payload specialist on Shenzhou 16 mission in June 2022, thus becoming the first Chinese civilian of Group 3 in space on 30 May 2023. He is responsible for the on-orbit operation of space science experimental payloads.",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/3.jpg"
    ]
  },
  {
    "messages": [
      {
        "content": "<|image_pad|>他们是谁?",
        "role": "user"
      },
      {
        "content": "他们是拜仁慕尼黑的凯恩和格雷茨卡。",
        "role": "assistant"
      },
      {
        "content": "<|image_pad|>他们在做什么?",
        "role": "user"
      },
      {
        "content": "他们在足球场上庆祝。",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/1.jpg","mllm_demo_data/1.jpg"
    ]
  },
  {
    "messages": [
      {
        "content": "<|image_pad|>他是谁?",
        "role": "user"
      },
      {
        "content": "他是来自拜仁慕尼黑的托马斯·穆勒。",
        "role": "assistant"
      },
      {
        "content": "他为什么在地上?",
        "role": "user"
      },
      {
        "content": "因为他正在双膝跪地滑行庆祝。",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/2.jpg"
    ]
  },
  {
    "messages": [
      {
        "content": "<|image_pad|>请描述这张图片",
        "role": "user"
      },
      {
        "content": "中国宇航员桂海潮正在讲话。",
        "role": "assistant"
      },
      {
        "content": "他取得过哪些成就?",
        "role": "user"
      },
      {
        "content": "他于2022年6月被任命为神舟十六号任务的有效载荷专家,从而成为2023年5月30日进入太空的首位平民宇航员。他负责在轨操作空间科学实验有效载荷。",
        "role": "assistant"
      }
    ],
    "images": [
      "mllm_demo_data/3.jpg"
    ]
  }
]
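Each <|image_pad|> placeholder in the user turns above is paired with one entry in that sample's "images" list, which is why the samples with two placeholders repeat the image path. A minimal sanity-check sketch, assuming the file is read as data/qwen2vl_demo.json relative to the repository root:

import json

# Assumed path relative to the repository root; adjust as needed.
with open("data/qwen2vl_demo.json", "r", encoding="utf-8") as f:
    samples = json.load(f)

for i, sample in enumerate(samples):
    # Count image placeholders across the user messages of this sample.
    num_pads = sum(
        msg["content"].count("<|image_pad|>") for msg in sample["messages"] if msg["role"] == "user"
    )
    assert num_pads == len(sample["images"]), f"sample {i}: {num_pads} placeholders vs {len(sample['images'])} images"

print("all placeholder/image counts match")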
examples/train_full/qwen2vl_full_sft.yaml (new file, +40 lines)

### model
model_name_or_path: qwen2-vl-hf/qwen2-vl-7b-hf
visual_inputs: true

### method
stage: sft
do_train: true
finetuning_type: full
deepspeed: examples/deepspeed/ds_z3_config.json

### dataset
dataset: qwen2vl_demo
template: qwen2vl
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/qwen2-vl-7b/full/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
examples/train_lora/qwen2vl_lora_sft.yaml (new file, +40 lines)

### model
model_name_or_path: qwen2-vl-hf/qwen2-vl-7b-hf
visual_inputs: true

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all

### dataset
dataset: qwen2vl_demo
template: qwen2vl
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/qwen2-vl-7b/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 2
gradient_accumulation_steps: 1
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
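The full and LoRA configs above differ only in the method block (finetuning_type and deepspeed vs. lora_target), the learning rate, the per-device batch size, and output_dir; either file is typically passed to the trainer entry point, e.g. llamafactory-cli train examples/train_lora/qwen2vl_lora_sft.yaml (assuming the standard LLaMA-Factory CLI). A small sketch that prints the differing keys, assuming PyYAML is installed and paths are resolved from the repository root:

import yaml  # PyYAML, assumed available

# Assumed paths relative to the repository root.
with open("examples/train_full/qwen2vl_full_sft.yaml") as f:
    full_cfg = yaml.safe_load(f)
with open("examples/train_lora/qwen2vl_lora_sft.yaml") as f:
    lora_cfg = yaml.safe_load(f)

# Print every key whose value differs (or that exists in only one config).
for key in sorted(set(full_cfg) | set(lora_cfg)):
    if full_cfg.get(key) != lora_cfg.get(key):
        print(f"{key}: full={full_cfg.get(key)!r} lora={lora_cfg.get(key)!r}")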
@@ -72,9 +72,37 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
     compute_dtype: "torch.dtype" = torch.float32
 
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+        image_grid_thw = None
+        if "image_grid_thw" in features[0]:
+            image_grid_thw_list = [
+                torch.Tensor(feature["image_grid_thw"]).long()
+                for feature in features
+                if feature["image_grid_thw"][0][0] > 0
+            ]
+            pixel_values_list = [
+                torch.Tensor(feature["pixel_values"]) for feature in features if feature["image_grid_thw"][0][0] > 0
+            ]
+            if image_grid_thw_list:
+                image_grid_thw = torch.cat(image_grid_thw_list, 0)
+            else:
+                # handle the case where the list is empty
+                image_grid_thw = None
+            if pixel_values_list:
+                pixel_values = torch.cat(pixel_values_list, 0)
+            else:
+                # handle the case where the list is empty
+                pixel_values = None
+            features = [
+                {key: feature[key] for key in feature if key not in ["image_grid_thw", "pixel_values"]}
+                for feature in features
+            ]
+
         features = super().__call__(features)
         if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
             features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
+        if image_grid_thw is not None:
+            features["image_grid_thw"] = image_grid_thw
+            features["pixel_values"] = pixel_values
 
         return features
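The collator concatenates the per-example image_grid_thw and pixel_values tensors along dimension 0, skipping dummy entries whose first grid dimension is zero (the marker set for text-only samples further below), and only re-attaches the merged tensors when at least one real image survived the filter. A standalone sketch of that filtering and concatenation, using made-up tensor shapes:

import torch

# Toy features mimicking what the collator receives; shapes are illustrative only.
features = [
    {"image_grid_thw": [[1, 2, 2]], "pixel_values": torch.ones(4, 8)},   # sample with a real image
    {"image_grid_thw": [[0, 2, 2]], "pixel_values": torch.zeros(4, 8)},  # dummy entry for a text-only sample
]

# Keep only samples whose grid has a non-zero first (temporal) dimension, as in the diff above.
grid_list = [torch.tensor(f["image_grid_thw"]).long() for f in features if f["image_grid_thw"][0][0] > 0]
pixel_list = [f["pixel_values"] for f in features if f["image_grid_thw"][0][0] > 0]

image_grid_thw = torch.cat(grid_list, 0) if grid_list else None
pixel_values = torch.cat(pixel_list, 0) if pixel_list else None
print(image_grid_thw.shape, pixel_values.shape)  # torch.Size([1, 3]) torch.Size([4, 8])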
@@ -78,6 +78,20 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
     return [0] * image_seq_length + [1] * (input_len - image_seq_length)
 
 
+def get_qwen2vl_image_inputs(images: Sequence["ImageObject"], processor: "ProcessorMixin") -> "NDArray":
+    r"""
+    Processes visual inputs; supports multiple images.
+    """
+    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+    if len(images) != 0:
+        image_inputs = image_processor(images=images, return_tensors="pt")
+    else:
+        image = Image.new("RGB", (56, 56), (255, 255, 255))
+        image_inputs = image_processor(images=[image], return_tensors="pt")
+        image_inputs["image_grid_thw"][0][0] = 0
+    return {"pixel_values": image_inputs["pixel_values"], "image_grid_thw": image_inputs["image_grid_thw"]}
+
+
 def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
     r"""
     Computes the real sequence length after truncation by the cutoff_len.
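For a text-only example, get_qwen2vl_image_inputs still runs the image processor on a 56x56 white placeholder so that pixel_values and image_grid_thw keep a consistent structure, then zeroes the first grid dimension so the collator above can recognize and drop the dummy entry. Rough arithmetic for that placeholder, assuming Qwen2-VL's default patch_size of 14 and merge_size of 2 (values not stated in this diff):

# Hypothetical defaults: patch_size=14, merge_size=2 (assumption, not part of this diff).
patch_size, merge_size = 14, 2
t, h, w = 1, 56 // patch_size, 56 // patch_size     # image_grid_thw would be [1, 4, 4]
num_pad_tokens = (t * h * w) // (merge_size ** 2)   # 16 // 4 = 4 <|image_pad|> tokens if it were kept
print([t, h, w], num_pad_tokens)
# Setting image_grid_thw[0][0] = 0 marks the entry as a dummy so downstream code can filter it out.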
@@ -17,10 +17,17 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, greedy_knapsack, infer_seqlen
+from .processor_utils import (
+    get_paligemma_token_type_ids,
+    get_pixel_values,
+    get_qwen2vl_image_inputs,
+    greedy_knapsack,
+    infer_seqlen,
+)
 
 
 if TYPE_CHECKING:
+    from PIL.Image import Image as ImageObject
     from transformers import PreTrainedTokenizer, ProcessorMixin
 
     from ...hparams import DataArguments
@@ -36,12 +43,31 @@ def _encode_supervised_example(
     system: Optional[str],
     tools: Optional[str],
     template: "Template",
+    images: Sequence["ImageObject"],
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
     train_on_prompt: bool,
     mask_history: bool,
 ) -> Tuple[List[int], List[int]]:
+    if "image_grid_thw" in processor.model_input_names:  # qwen2_vl models
+        image_processor = getattr(processor, "image_processor")
+        merge_length = image_processor.merge_size**2
+        if len(images) > 0:
+            image_grid_thw = get_qwen2vl_image_inputs(images, processor)["image_grid_thw"]
+            index = 0
+            for message in prompt:
+                content = message["content"]
+                while "<|image_pad|>" in content:
+                    content = content.replace(
+                        "<|image_pad|>",
+                        template.vision_start_token
+                        + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length)
+                        + template.vision_end_token,
+                        1,
+                    )
+                    index += 1
+                message["content"] = content.replace("<|placeholder|>", "<|image_pad|>")
     if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
         prompt[0]["content"] = template.image_token + prompt[0]["content"]
 
|
|||||||
model_inputs["pixel_values"] = []
|
model_inputs["pixel_values"] = []
|
||||||
if hasattr(processor, "image_seq_length"): # paligemma models
|
if hasattr(processor, "image_seq_length"): # paligemma models
|
||||||
model_inputs["token_type_ids"] = []
|
model_inputs["token_type_ids"] = []
|
||||||
|
if "image_grid_thw" in processor.model_input_names: # qwen2_vl models
|
||||||
|
model_inputs["image_grid_thw"] = []
|
||||||
|
|
||||||
for i in range(len(examples["prompt"])):
|
for i in range(len(examples["prompt"])):
|
||||||
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
||||||
@@ -129,9 +157,14 @@ def preprocess_supervised_dataset(
         model_inputs["attention_mask"].append([1] * len(input_ids))
         model_inputs["labels"].append(labels)
         if processor is not None:
-            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
+            if "image_grid_thw" in processor.model_input_names:  # qwen2_vl models
+                image_inputs = get_qwen2vl_image_inputs(examples["images"][i], processor)
+                model_inputs["pixel_values"].append(image_inputs["pixel_values"])
+                model_inputs["image_grid_thw"].append(image_inputs["image_grid_thw"])
+            else:
+                model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
+                if hasattr(processor, "image_seq_length"):  # paligemma models
+                    model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
 
     return model_inputs
@@ -42,6 +42,8 @@ class Template:
     default_system: str
     stop_words: List[str]
     image_token: str
+    vision_start_token: str
+    vision_end_token: str
     efficient_eos: bool
     replace_eos: bool
 
@@ -206,6 +208,8 @@ def _register_template(
     default_system: str = "",
     stop_words: Sequence[str] = [],
     image_token: str = "<image>",
+    vision_start_token: str = "<|vision_start|>",
+    vision_end_token: str = "<|vision_end|>",
     efficient_eos: bool = False,
     replace_eos: bool = False,
 ) -> None:
@@ -255,6 +259,8 @@ def _register_template(
         default_system=default_system,
         stop_words=stop_words,
         image_token=image_token,
+        vision_start_token=vision_start_token,
+        vision_end_token=vision_end_token,
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
     )
@@ -783,6 +789,21 @@ _register_template(
 )
 
 
+_register_template(
+    name="qwen2vl",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    default_system="You are a helpful assistant.",
+    image_token="<|image_pad|>",
+    vision_start_token="<|vision_start|>",
+    vision_end_token="<|vision_end|>",
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+)
+
+
 _register_template(
     name="sailor",
     format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
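The qwen2vl template keeps Qwen2's ChatML layout and only swaps in the multimodal special tokens, so a single user turn renders roughly as follows (string assembled by hand here rather than through the Template class, and the <|image_pad|> in the user text is later expanded by _encode_supervised_example):

# Hand-assembled rendering of one turn under the qwen2vl template (illustrative only).
system = "You are a helpful assistant."
user = "<|image_pad|>Who are they?"
prompt = (
    f"<|im_start|>system\n{system}<|im_end|>\n"
    f"<|im_start|>user\n{user}<|im_end|>\n"
    "<|im_start|>assistant\n"
)
print(prompt)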
@@ -212,7 +212,7 @@ def _setup_lora_tuning(
             target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)
 
         if model_args.visual_inputs and finetuning_args.freeze_vision_tower:
-            target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
+            target_modules = "^(?!.*(?:vision_tower|visual)).*(?:{}).*".format("|".join(target_modules))
 
         if (
             finetuning_args.use_dora
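The widened negative lookahead keeps LoRA off any module whose name mentions vision_tower or visual (the attribute under which Qwen2-VL exposes its vision encoder) while still matching the requested target modules in the language model. A quick check of the pattern on hypothetical module names:

import re

# Hypothetical module names; only the non-vision one should match.
target_modules = ["q_proj", "v_proj"]
pattern = "^(?!.*(?:vision_tower|visual)).*(?:{}).*".format("|".join(target_modules))

names = [
    "model.layers.0.self_attn.q_proj",                  # language model: should match
    "vision_tower.encoder.layers.0.self_attn.q_proj",   # llava-style vision tower: excluded
    "visual.blocks.0.attn.q_proj",                      # qwen2_vl-style vision encoder: excluded
]
print([name for name in names if re.match(pattern, name)])
# ['model.layers.0.self_attn.q_proj']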
@@ -36,6 +36,8 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         forbidden_modules.add("output")
     elif model.config.model_type in ["llava", "paligemma"]:
         forbidden_modules.add("multi_modal_projector")
+    elif model.config.model_type in ["qwen2_vl"]:
+        forbidden_modules.add("merger")
 
     if freeze_vision_tower:
         forbidden_modules.add("vision_tower")