Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-01 11:12:50 +08:00

[data] fix qwen2vl pos ids (#8387)

Commit af2f75e688 (parent 9f2f12b0fe)
```diff
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional
 import numpy as np
 import torch
 import torch.nn.functional as F
+from peft import PeftModel
 from transformers import DataCollatorForSeq2Seq

 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
```
```diff
@@ -94,6 +95,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         if self.template is None:
             raise ValueError("Template is required for MultiModalDataCollator.")

+        if isinstance(self.model, PeftModel):
+            self.model = self.model.base_model.model
+
+        if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
+            self.get_rope_func = self.model.get_rope_index  # transformers < 4.52.0 or qwen2.5 omni
+        elif self.model is not None and hasattr(self.model, "model") and hasattr(self.model.model, "get_rope_index"):
+            self.get_rope_func = self.model.model.get_rope_index  # transformers >= 4.52.0
+        else:
+            self.get_rope_func = None
+
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         batch_images, batch_videos, batch_audios = [], [], []
         batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
```
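For reference, a minimal standalone restatement of the dispatch this hunk adds to the collator's init-time logic (the helper name is mine, not part of the commit): unwrap a PEFT wrapper, then find `get_rope_index` wherever the installed transformers version exposes it.

```python
from typing import Callable, Optional

from peft import PeftModel


def resolve_rope_func(model) -> Optional[Callable]:
    """Sketch of the collator's rope-function lookup (helper name is mine)."""
    if isinstance(model, PeftModel):
        model = model.base_model.model  # unwrap to the underlying transformers model
    if model is not None and hasattr(model, "get_rope_index"):
        return model.get_rope_index  # transformers < 4.52.0, or the qwen2.5 omni thinker
    if model is not None and hasattr(model, "model") and hasattr(model.model, "get_rope_index"):
        return model.model.get_rope_index  # transformers >= 4.52.0 moved it onto the inner model
    return None
```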
```diff
@@ -171,7 +182,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

         features: dict[str, torch.Tensor] = super().__call__(features)

-        if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
+        if self.get_rope_func is not None:
             rope_index_kwargs = {
                 "input_ids": features["input_ids"],
                 "image_grid_thw": mm_inputs.get("image_grid_thw"),
```
```diff
@@ -180,27 +191,29 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             }
             if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                 rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
-            if "video_second_per_grid" in mm_inputs:  # for qwen2omni
+            elif "video_second_per_grid" in mm_inputs:  # for qwen2.5 omni
                 rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")

-            if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
+            if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2.5 omni
                 rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
                 feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
-                if feature_attention_mask is not None:
-                    audio_feature_lengths = torch.sum(
-                        feature_attention_mask, dim=1
-                    )  # FIXME need to get video image lengths
+                if feature_attention_mask is not None:  # FIXME: need to get video image lengths
+                    audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
                     rope_index_kwargs["audio_seqlens"] = audio_feature_lengths  # prepare for input

-                delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
-                # avoid conflict
-                new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
-                features["position_ids"], features["rope_deltas"] = (
-                    new_position_ids.clone(),
-                    rope_deltas - delta0,
-                )  # avoid inplace operation FIXME
+                features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
+                features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
+                    dim=-1
+                ).unsqueeze(-1)
             else:  # for qwen2vl
-                features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
+                features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)

+        if (
+            self.model is not None
+            and getattr(self.model.config, "model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
+            and ("position_ids" not in features or features["position_ids"].dim() != 3)
+        ):
+            raise ValueError("Qwen2-VL/Qwen2.5-Omni model requires 3D position ids for mrope.")
+
         if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
             cross_attention_mask = mm_inputs.pop("cross_attention_mask")
```
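The replacement block above first calls the resolved rope function and then shifts `rope_deltas` by the per-sample padding count taken from the attention mask. A toy illustration of that arithmetic, with made-up tensors:

```python
import torch

# Made-up values, only to show the shift applied above: count pad tokens per sample
# from the attention mask and subtract that count from the deltas get_rope_index returned.
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],   # two pad tokens
                               [1, 1, 1, 1, 1, 1]])  # no padding
rope_deltas = torch.tensor([[3], [5]])               # placeholder deltas
pad_counts = (1 - attention_mask).sum(dim=-1).unsqueeze(-1)  # tensor([[2], [0]])
corrected = rope_deltas - pad_counts                          # tensor([[1], [5]])
```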
```diff
@@ -1609,13 +1609,13 @@ register_model_group(

 register_model_group(
     models={
-        "Mistral-Small-24B-Base-2503": {
-            DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Base-2503",
-            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Base-2503",
+        "Mistral-Small-3.1-24B-Base": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-3.1-24B-Base-2503",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-3.1-24B-Base-2503",
         },
-        "Mistral-Small-24B-Instruct-2503": {
-            DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Instruct-2503",
-            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Instruct-2503",
+        "Mistral-Small-3.1-24B-Instruct": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
         },
     },
     template="mistral_small",
```
```diff
@@ -169,10 +169,15 @@ def _check_extra_dependencies(
     if finetuning_args.plot_loss:
         check_version("matplotlib", mandatory=True)

-    if training_args is not None and training_args.predict_with_generate:
-        check_version("jieba", mandatory=True)
-        check_version("nltk", mandatory=True)
-        check_version("rouge_chinese", mandatory=True)
+    if training_args is not None:
+        if training_args.deepspeed:
+            # pin deepspeed version < 0.17 because of https://github.com/deepspeedai/DeepSpeed/issues/7347
+            check_version("deepspeed>=0.10.0,<=0.16.9", mandatory=True)
+
+        if training_args.predict_with_generate:
+            check_version("jieba", mandatory=True)
+            check_version("nltk", mandatory=True)
+            check_version("rouge_chinese", mandatory=True)


 def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
```
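The new dependency check pins deepspeed below 0.17 with a range spec. A hedged sketch of what a spec string like that has to enforce; the repo's own `check_version` helper may be implemented differently:

```python
from importlib.metadata import version

from packaging.requirements import Requirement


def satisfies(spec: str) -> bool:
    """Return True if the installed package falls inside the spec, e.g. "deepspeed>=0.10.0,<=0.16.9"."""
    req = Requirement(spec)
    return version(req.name) in req.specifier  # installed version is within the specifier set
```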
```diff
@@ -86,10 +86,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
             padding_side="right",
             **init_kwargs,
         )
-    except ValueError:  # try the fast one
+    except ValueError:  # try another one
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
-            use_fast=True,
+            use_fast=not model_args.use_fast_tokenizer,
             padding_side="right",
             **init_kwargs,
         )
```
```diff
@@ -97,12 +97,23 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         raise OSError("Failed to load tokenizer.") from e

     patch_tokenizer(tokenizer, model_args)

     try:
-        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-        patch_processor(processor, tokenizer, model_args)
+        processor = AutoProcessor.from_pretrained(
+            model_args.model_name_or_path,
+            use_fast=model_args.use_fast_tokenizer,
+            **init_kwargs,
+        )
+    except ValueError:  # try another one
+        processor = AutoProcessor.from_pretrained(
+            model_args.model_name_or_path,
+            use_fast=not model_args.use_fast_tokenizer,
+            **init_kwargs,
+        )
     except Exception as e:
-        logger.info_rank0(f"Failed to load processor: {e}.")
-        processor = None
+        raise OSError("Failed to load processor.") from e
+
+    patch_processor(processor, tokenizer, model_args)

     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
```
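Both the tokenizer and processor hunks now use the same fallback: honor `use_fast_tokenizer` first, and retry with the opposite implementation when the checkpoint only ships one of them and transformers raises a `ValueError`. Extracted as a small helper (the name is mine):

```python
from transformers import AutoProcessor


def load_processor_with_fallback(name_or_path: str, use_fast: bool, **kwargs):
    """Try the preferred fast/slow processor first, then fall back to the other one."""
    try:
        return AutoProcessor.from_pretrained(name_or_path, use_fast=use_fast, **kwargs)
    except ValueError:  # only one implementation is available for this checkpoint
        return AutoProcessor.from_pretrained(name_or_path, use_fast=not use_fast, **kwargs)
```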
```diff
@@ -287,7 +287,9 @@ _register_composite_model(
     model_type="qwen2_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
+    language_model_keys=["language_model", "lm_head"]
+    if is_transformers_version_greater_than("4.52.0")
+    else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )

```
```diff
@@ -296,6 +298,8 @@ _register_composite_model(
     model_type="qwen2_5_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
+    language_model_keys=["language_model", "lm_head"]
+    if is_transformers_version_greater_than("4.52.0")
+    else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
```
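`language_model_keys` now lists `lm_head` alongside `language_model` on transformers >= 4.52.0, apparently because the LM head is a separate top-level module there rather than part of the inner language model. These keys look like module-name prefixes; a hypothetical illustration (not code from the repo) of how such prefixes are usually matched against parameter names:

```python
# Hypothetical helper: prefixes such as "language_model" or "lm_head" would match
# parameter names like "language_model.layers.0.self_attn.q_proj.weight" or "lm_head.weight".
def params_under(model, prefixes):
    return [
        name
        for name, _ in model.named_parameters()
        if any(name == p or name.startswith(p + ".") for p in prefixes)
    ]
```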
```diff
@@ -16,6 +16,7 @@ import os

 import torch
 from PIL import Image
+from transformers import AutoConfig, AutoModelForVision2Seq

 from llamafactory.data import get_template_and_fix_tokenizer
 from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
```
```diff
@@ -72,12 +73,17 @@ def test_multimodal_collator():

 def test_multimodal_collator():
     model_args, data_args, *_ = get_infer_args(
-        {"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
+        {"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
     )
     tokenizer_module = load_tokenizer(model_args)
     template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    with torch.device("meta"):
+        model = AutoModelForVision2Seq.from_config(config)
+
     data_collator = MultiModalDataCollatorForSeq2Seq(
         template=template,
+        model=model,
         pad_to_multiple_of=4,
         label_pad_token_id=IGNORE_INDEX,
         **tokenizer_module,
```
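The test now builds the Qwen2-VL architecture on the meta device, so the collator has a real `get_rope_index` to call without downloading or allocating weights; computing rope indices should only need the config values and the batch tensors. Roughly the same recipe as a reusable helper (the name is mine):

```python
import torch
from transformers import AutoConfig, AutoModelForVision2Seq


def build_skeleton_model(name_or_path: str):
    """Instantiate the architecture with meta (shape-only) parameters; only config.json is fetched."""
    config = AutoConfig.from_pretrained(name_or_path)
    with torch.device("meta"):
        return AutoModelForVision2Seq.from_config(config)
```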
```diff
@@ -107,8 +113,15 @@ def test_multimodal_collator():
         "labels": [
             [0, 1, 2, 3, q, q, q, q, q, q, q, q],
         ],
+        "position_ids": [
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+        ],
+        "rope_deltas": [[-8]],
         **tokenizer_module["processor"].image_processor(fake_image),
     }
+    assert batch_input.keys() == expected_input.keys()
     for k in batch_input.keys():
         assert batch_input[k].eq(torch.tensor(expected_input[k])).all()

```
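The expected `position_ids` is 3-dimensional because Qwen2-VL-style mrope keeps separate temporal, height, and width position streams, which is exactly the shape the collator's new sanity check enforces. A toy check over the values copied from the test above:

```python
import torch

# Shape check only: (3 rope streams, batch, seq_len); values copied from the expected output.
position_ids = torch.tensor([
    [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
    [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
    [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
])
assert position_ids.dim() == 3 and position_ids.shape == (3, 1, 12)
```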
```diff
@@ -150,3 +163,7 @@ def test_4d_attention_mask():
     )
     assert list(attention_mask_computed.size()) == [2, 1, 6, 6]
     assert torch.all(attention_mask_computed == attention_mask_expected)
+
+
+if __name__ == "__main__":
+    test_multimodal_collator()
```
```diff
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.3.107
+0.9.3.108
```