[data] fix qwen2vl pos ids (#8387)

This commit is contained in:
Yaowei Zheng 2025-06-17 00:48:54 +08:00 committed by GitHub
parent 9f2f12b0fe
commit af2f75e688
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 85 additions and 35 deletions

View File

@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional
import numpy as np
import torch
import torch.nn.functional as F
from peft import PeftModel
from transformers import DataCollatorForSeq2Seq
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
@ -94,6 +95,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if self.template is None:
raise ValueError("Template is required for MultiModalDataCollator.")
if isinstance(self.model, PeftModel):
self.model = self.model.base_model.model
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
self.get_rope_func = self.model.get_rope_index # transformers < 4.52.0 or qwen2.5 omni
elif self.model is not None and hasattr(self.model, "model") and hasattr(self.model.model, "get_rope_index"):
self.get_rope_func = self.model.model.get_rope_index # transformers >= 4.52.0
else:
self.get_rope_func = None
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_audios = [], [], []
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
@ -171,7 +182,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features: dict[str, torch.Tensor] = super().__call__(features)
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
if self.get_rope_func is not None:
rope_index_kwargs = {
"input_ids": features["input_ids"],
"image_grid_thw": mm_inputs.get("image_grid_thw"),
@ -180,27 +191,29 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
}
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
if "video_second_per_grid" in mm_inputs: # for qwen2omni
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2omni
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2.5 omni
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
if feature_attention_mask is not None:
audio_feature_lengths = torch.sum(
feature_attention_mask, dim=1
) # FIXME need to get video image lengths
if feature_attention_mask is not None: # FIXME: need to get video image lengths
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
# avoid conflict
new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
features["position_ids"], features["rope_deltas"] = (
new_position_ids.clone(),
rope_deltas - delta0,
) # avoid inplace operation FIXME
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
dim=-1
).unsqueeze(-1)
else: # for qwen2vl
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
if (
self.model is not None
and getattr(self.model.config, "model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
and ("position_ids" not in features or features["position_ids"].dim() != 3)
):
raise ValueError("Qwen2-VL/Qwen2.5-Omni model requires 3D position ids for mrope.")
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
cross_attention_mask = mm_inputs.pop("cross_attention_mask")

View File

@ -1609,13 +1609,13 @@ register_model_group(
register_model_group(
models={
"Mistral-Small-24B-Base-2503": {
DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Base-2503",
DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Base-2503",
"Mistral-Small-3.1-24B-Base": {
DownloadSource.DEFAULT: "mistralai/Mistral-Small-3.1-24B-Base-2503",
DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-3.1-24B-Base-2503",
},
"Mistral-Small-24B-Instruct-2503": {
DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Instruct-2503",
DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Instruct-2503",
"Mistral-Small-3.1-24B-Instruct": {
DownloadSource.DEFAULT: "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
},
},
template="mistral_small",

View File

@ -169,10 +169,15 @@ def _check_extra_dependencies(
if finetuning_args.plot_loss:
check_version("matplotlib", mandatory=True)
if training_args is not None and training_args.predict_with_generate:
check_version("jieba", mandatory=True)
check_version("nltk", mandatory=True)
check_version("rouge_chinese", mandatory=True)
if training_args is not None:
if training_args.deepspeed:
# pin deepspeed version < 0.17 because of https://github.com/deepspeedai/DeepSpeed/issues/7347
check_version("deepspeed>=0.10.0,<=0.16.9", mandatory=True)
if training_args.predict_with_generate:
check_version("jieba", mandatory=True)
check_version("nltk", mandatory=True)
check_version("rouge_chinese", mandatory=True)
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:

View File

@ -86,10 +86,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
padding_side="right",
**init_kwargs,
)
except ValueError: # try the fast one
except ValueError: # try another one
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=True,
use_fast=not model_args.use_fast_tokenizer,
padding_side="right",
**init_kwargs,
)
@ -97,12 +97,23 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
raise OSError("Failed to load tokenizer.") from e
patch_tokenizer(tokenizer, model_args)
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, tokenizer, model_args)
processor = AutoProcessor.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
**init_kwargs,
)
except ValueError: # try another one
processor = AutoProcessor.from_pretrained(
model_args.model_name_or_path,
use_fast=not model_args.use_fast_tokenizer,
**init_kwargs,
)
except Exception as e:
logger.info_rank0(f"Failed to load processor: {e}.")
processor = None
raise OSError("Failed to load processor.") from e
patch_processor(processor, tokenizer, model_args)
# Avoid load tokenizer, see:
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324

View File

@ -287,7 +287,9 @@ _register_composite_model(
model_type="qwen2_vl",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
language_model_keys=["language_model", "lm_head"]
if is_transformers_version_greater_than("4.52.0")
else ["model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)
@ -296,6 +298,8 @@ _register_composite_model(
model_type="qwen2_5_vl",
projector_key="visual.merger",
vision_model_keys=["visual.patch_embed", "visual.blocks"],
language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
language_model_keys=["language_model", "lm_head"]
if is_transformers_version_greater_than("4.52.0")
else ["model", "lm_head"],
lora_conflict_keys=["patch_embed"],
)

View File

@ -16,6 +16,7 @@ import os
import torch
from PIL import Image
from transformers import AutoConfig, AutoModelForVision2Seq
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
@ -72,12 +73,17 @@ def test_base_collator():
def test_multimodal_collator():
model_args, data_args, *_ = get_infer_args(
{"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
{"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
)
tokenizer_module = load_tokenizer(model_args)
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
with torch.device("meta"):
model = AutoModelForVision2Seq.from_config(config)
data_collator = MultiModalDataCollatorForSeq2Seq(
template=template,
model=model,
pad_to_multiple_of=4,
label_pad_token_id=IGNORE_INDEX,
**tokenizer_module,
@ -107,8 +113,15 @@ def test_multimodal_collator():
"labels": [
[0, 1, 2, 3, q, q, q, q, q, q, q, q],
],
"position_ids": [
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
],
"rope_deltas": [[-8]],
**tokenizer_module["processor"].image_processor(fake_image),
}
assert batch_input.keys() == expected_input.keys()
for k in batch_input.keys():
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
@ -150,3 +163,7 @@ def test_4d_attention_mask():
)
assert list(attention_mask_computed.size()) == [2, 1, 6, 6]
assert torch.all(attention_mask_computed == attention_mask_expected)
if __name__ == "__main__":
test_multimodal_collator()

View File

@ -1,2 +1,2 @@
# change if test fails or cache is outdated
0.9.3.107
0.9.3.108