mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-07-31 10:42:50 +08:00
[data] fix qwen2vl pos ids (#8387)
This commit is contained in:
parent
9f2f12b0fe
commit
af2f75e688
@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from peft import PeftModel
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||
@ -94,6 +95,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
if self.template is None:
|
||||
raise ValueError("Template is required for MultiModalDataCollator.")
|
||||
|
||||
if isinstance(self.model, PeftModel):
|
||||
self.model = self.model.base_model.model
|
||||
|
||||
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
|
||||
self.get_rope_func = self.model.get_rope_index # transformers < 4.52.0 or qwen2.5 omni
|
||||
elif self.model is not None and hasattr(self.model, "model") and hasattr(self.model.model, "get_rope_index"):
|
||||
self.get_rope_func = self.model.model.get_rope_index # transformers >= 4.52.0
|
||||
else:
|
||||
self.get_rope_func = None
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
|
||||
batch_images, batch_videos, batch_audios = [], [], []
|
||||
batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
|
||||
@ -171,7 +182,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
|
||||
features: dict[str, torch.Tensor] = super().__call__(features)
|
||||
|
||||
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
|
||||
if self.get_rope_func is not None:
|
||||
rope_index_kwargs = {
|
||||
"input_ids": features["input_ids"],
|
||||
"image_grid_thw": mm_inputs.get("image_grid_thw"),
|
||||
@ -180,27 +191,29 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
}
|
||||
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
|
||||
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
|
||||
if "video_second_per_grid" in mm_inputs: # for qwen2omni
|
||||
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
|
||||
rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
|
||||
|
||||
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2omni
|
||||
if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker": # for qwen2.5 omni
|
||||
rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
|
||||
feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
|
||||
if feature_attention_mask is not None:
|
||||
audio_feature_lengths = torch.sum(
|
||||
feature_attention_mask, dim=1
|
||||
) # FIXME need to get video image lengths
|
||||
if feature_attention_mask is not None: # FIXME: need to get video image lengths
|
||||
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
|
||||
rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input
|
||||
|
||||
delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
|
||||
# avoid conflict
|
||||
new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
|
||||
features["position_ids"], features["rope_deltas"] = (
|
||||
new_position_ids.clone(),
|
||||
rope_deltas - delta0,
|
||||
) # avoid inplace operation FIXME
|
||||
features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
|
||||
features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
|
||||
dim=-1
|
||||
).unsqueeze(-1)
|
||||
else: # for qwen2vl
|
||||
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
|
||||
features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
|
||||
|
||||
if (
|
||||
self.model is not None
|
||||
and getattr(self.model.config, "model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
|
||||
and ("position_ids" not in features or features["position_ids"].dim() != 3)
|
||||
):
|
||||
raise ValueError("Qwen2-VL/Qwen2.5-Omni model requires 3D position ids for mrope.")
|
||||
|
||||
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
|
||||
cross_attention_mask = mm_inputs.pop("cross_attention_mask")
|
||||
|
@ -1609,13 +1609,13 @@ register_model_group(
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Mistral-Small-24B-Base-2503": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Base-2503",
|
||||
DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Base-2503",
|
||||
"Mistral-Small-3.1-24B-Base": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-Small-3.1-24B-Base-2503",
|
||||
DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-3.1-24B-Base-2503",
|
||||
},
|
||||
"Mistral-Small-24B-Instruct-2503": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Instruct-2503",
|
||||
DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Instruct-2503",
|
||||
"Mistral-Small-3.1-24B-Instruct": {
|
||||
DownloadSource.DEFAULT: "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
|
||||
DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
|
||||
},
|
||||
},
|
||||
template="mistral_small",
|
||||
|
@ -169,10 +169,15 @@ def _check_extra_dependencies(
|
||||
if finetuning_args.plot_loss:
|
||||
check_version("matplotlib", mandatory=True)
|
||||
|
||||
if training_args is not None and training_args.predict_with_generate:
|
||||
check_version("jieba", mandatory=True)
|
||||
check_version("nltk", mandatory=True)
|
||||
check_version("rouge_chinese", mandatory=True)
|
||||
if training_args is not None:
|
||||
if training_args.deepspeed:
|
||||
# pin deepspeed version < 0.17 because of https://github.com/deepspeedai/DeepSpeed/issues/7347
|
||||
check_version("deepspeed>=0.10.0,<=0.16.9", mandatory=True)
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
check_version("jieba", mandatory=True)
|
||||
check_version("nltk", mandatory=True)
|
||||
check_version("rouge_chinese", mandatory=True)
|
||||
|
||||
|
||||
def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
||||
|
@ -86,10 +86,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
||||
padding_side="right",
|
||||
**init_kwargs,
|
||||
)
|
||||
except ValueError: # try the fast one
|
||||
except ValueError: # try another one
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=True,
|
||||
use_fast=not model_args.use_fast_tokenizer,
|
||||
padding_side="right",
|
||||
**init_kwargs,
|
||||
)
|
||||
@ -97,12 +97,23 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
|
||||
raise OSError("Failed to load tokenizer.") from e
|
||||
|
||||
patch_tokenizer(tokenizer, model_args)
|
||||
|
||||
try:
|
||||
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
|
||||
patch_processor(processor, tokenizer, model_args)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
**init_kwargs,
|
||||
)
|
||||
except ValueError: # try another one
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=not model_args.use_fast_tokenizer,
|
||||
**init_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info_rank0(f"Failed to load processor: {e}.")
|
||||
processor = None
|
||||
raise OSError("Failed to load processor.") from e
|
||||
|
||||
patch_processor(processor, tokenizer, model_args)
|
||||
|
||||
# Avoid load tokenizer, see:
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
|
||||
|
@ -287,7 +287,9 @@ _register_composite_model(
|
||||
model_type="qwen2_vl",
|
||||
projector_key="visual.merger",
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
|
||||
language_model_keys=["language_model", "lm_head"]
|
||||
if is_transformers_version_greater_than("4.52.0")
|
||||
else ["model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
|
||||
@ -296,6 +298,8 @@ _register_composite_model(
|
||||
model_type="qwen2_5_vl",
|
||||
projector_key="visual.merger",
|
||||
vision_model_keys=["visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
|
||||
language_model_keys=["language_model", "lm_head"]
|
||||
if is_transformers_version_greater_than("4.52.0")
|
||||
else ["model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
|
@ -16,6 +16,7 @@ import os
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoConfig, AutoModelForVision2Seq
|
||||
|
||||
from llamafactory.data import get_template_and_fix_tokenizer
|
||||
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
|
||||
@ -72,12 +73,17 @@ def test_base_collator():
|
||||
|
||||
def test_multimodal_collator():
|
||||
model_args, data_args, *_ = get_infer_args(
|
||||
{"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
|
||||
{"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
|
||||
)
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
|
||||
with torch.device("meta"):
|
||||
model = AutoModelForVision2Seq.from_config(config)
|
||||
|
||||
data_collator = MultiModalDataCollatorForSeq2Seq(
|
||||
template=template,
|
||||
model=model,
|
||||
pad_to_multiple_of=4,
|
||||
label_pad_token_id=IGNORE_INDEX,
|
||||
**tokenizer_module,
|
||||
@ -107,8 +113,15 @@ def test_multimodal_collator():
|
||||
"labels": [
|
||||
[0, 1, 2, 3, q, q, q, q, q, q, q, q],
|
||||
],
|
||||
"position_ids": [
|
||||
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
],
|
||||
"rope_deltas": [[-8]],
|
||||
**tokenizer_module["processor"].image_processor(fake_image),
|
||||
}
|
||||
assert batch_input.keys() == expected_input.keys()
|
||||
for k in batch_input.keys():
|
||||
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
|
||||
|
||||
@ -150,3 +163,7 @@ def test_4d_attention_mask():
|
||||
)
|
||||
assert list(attention_mask_computed.size()) == [2, 1, 6, 6]
|
||||
assert torch.all(attention_mask_computed == attention_mask_expected)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_multimodal_collator()
|
||||
|
@ -1,2 +1,2 @@
|
||||
# change if test fails or cache is outdated
|
||||
0.9.3.107
|
||||
0.9.3.108
|
||||
|
Loading…
x
Reference in New Issue
Block a user