Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

Commit: [data] fix qwen2vl pos ids (#8387)
Parent: 31874e4f62
Commit: 3a3bae1cfe
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional
 import numpy as np
 import torch
 import torch.nn.functional as F
+from peft import PeftModel
 from transformers import DataCollatorForSeq2Seq
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
@@ -94,6 +95,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         if self.template is None:
             raise ValueError("Template is required for MultiModalDataCollator.")
 
+        if isinstance(self.model, PeftModel):
+            self.model = self.model.base_model.model
+
+        if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
+            self.get_rope_func = self.model.get_rope_index  # transformers < 4.52.0 or qwen2.5 omni
+        elif self.model is not None and hasattr(self.model, "model") and hasattr(self.model.model, "get_rope_index"):
+            self.get_rope_func = self.model.model.get_rope_index  # transformers >= 4.52.0
+        else:
+            self.get_rope_func = None
+
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         batch_images, batch_videos, batch_audios = [], [], []
         batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
@@ -171,7 +182,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 
         features: dict[str, torch.Tensor] = super().__call__(features)
 
-        if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
+        if self.get_rope_func is not None:
             rope_index_kwargs = {
                 "input_ids": features["input_ids"],
                 "image_grid_thw": mm_inputs.get("image_grid_thw"),
@@ -180,27 +191,29 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             }
             if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                 rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
-            if "video_second_per_grid" in mm_inputs:  # for qwen2omni
+            elif "video_second_per_grid" in mm_inputs:  # for qwen2.5 omni
                 rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
 
-            if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
+            if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2.5 omni
                 rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
                 feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
-                if feature_attention_mask is not None:
-                    audio_feature_lengths = torch.sum(
-                        feature_attention_mask, dim=1
-                    )  # FIXME need to get video image lengths
+                if feature_attention_mask is not None:  # FIXME: need to get video image lengths
+                    audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
                     rope_index_kwargs["audio_seqlens"] = audio_feature_lengths  # prepare for input
 
-                delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
-                # avoid conflict
-                new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
-                features["position_ids"], features["rope_deltas"] = (
-                    new_position_ids.clone(),
-                    rope_deltas - delta0,
-                )  # avoid inplace operation FIXME
+                features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
+                features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
+                    dim=-1
+                ).unsqueeze(-1)
             else:  # for qwen2vl
-                features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
+                features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
 
+        if (
+            self.model is not None
+            and getattr(self.model.config, "model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
+            and ("position_ids" not in features or features["position_ids"].dim() != 3)
+        ):
+            raise ValueError("Qwen2-VL/Qwen2.5-Omni model requires 3D position ids for mrope.")
+
         if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
             cross_attention_mask = mm_inputs.pop("cross_attention_mask")
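Taken together, the collator changes above resolve `get_rope_index` once in `__post_init__` (after unwrapping a `PeftModel`) and then shift the returned `rope_deltas` by the amount of left padding in each sample. A minimal standalone sketch of those two pieces, with made-up helper names, illustrative only and not the exact collator code:

```python
import torch


def resolve_rope_func(model):
    """Pick the right get_rope_index entry point for the installed transformers version."""
    if model is None:
        return None
    if hasattr(model, "get_rope_index"):  # transformers < 4.52.0 or Qwen2.5-Omni thinker
        return model.get_rope_index
    if hasattr(model, "model") and hasattr(model.model, "get_rope_index"):  # transformers >= 4.52.0
        return model.model.get_rope_index
    return None


def shift_rope_deltas(rope_deltas: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Subtract the per-sample count of padding tokens so the deltas match the padded batch."""
    pad_lengths = (1 - attention_mask).sum(dim=-1).unsqueeze(-1)
    return rope_deltas - pad_lengths
```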
@@ -1609,13 +1609,13 @@ register_model_group(
 
 register_model_group(
     models={
-        "Mistral-Small-24B-Base-2503": {
-            DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Base-2503",
-            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Base-2503",
+        "Mistral-Small-3.1-24B-Base": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-3.1-24B-Base-2503",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-3.1-24B-Base-2503",
         },
-        "Mistral-Small-24B-Instruct-2503": {
-            DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Instruct-2503",
-            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Instruct-2503",
+        "Mistral-Small-3.1-24B-Instruct": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
         },
     },
     template="mistral_small",
@@ -169,10 +169,15 @@ def _check_extra_dependencies(
     if finetuning_args.plot_loss:
         check_version("matplotlib", mandatory=True)
 
-    if training_args is not None and training_args.predict_with_generate:
-        check_version("jieba", mandatory=True)
-        check_version("nltk", mandatory=True)
-        check_version("rouge_chinese", mandatory=True)
+    if training_args is not None:
+        if training_args.deepspeed:
+            # pin deepspeed version < 0.17 because of https://github.com/deepspeedai/DeepSpeed/issues/7347
+            check_version("deepspeed>=0.10.0,<=0.16.9", mandatory=True)
+
+        if training_args.predict_with_generate:
+            check_version("jieba", mandatory=True)
+            check_version("nltk", mandatory=True)
+            check_version("rouge_chinese", mandatory=True)
 
 
 def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
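The new DeepSpeed branch pins the dependency to the range `>=0.10.0,<=0.16.9`. A rough sketch of how such a range check can be implemented with `packaging` (a hypothetical helper for illustration, not the repo's actual `check_version`):

```python
from importlib.metadata import PackageNotFoundError, version

from packaging.requirements import Requirement


def require_version_range(requirement: str) -> None:
    """Raise if the installed package does not satisfy a spec like 'deepspeed>=0.10.0,<=0.16.9'."""
    req = Requirement(requirement)
    try:
        installed = version(req.name)
    except PackageNotFoundError as e:
        raise RuntimeError(f"{req.name} is required but not installed.") from e

    if not req.specifier.contains(installed, prereleases=True):
        raise RuntimeError(f"Found {req.name}=={installed}, which does not satisfy '{requirement}'.")
```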
@@ -86,10 +86,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
             padding_side="right",
             **init_kwargs,
         )
-    except ValueError:  # try the fast one
+    except ValueError:  # try another one
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
-            use_fast=True,
+            use_fast=not model_args.use_fast_tokenizer,
             padding_side="right",
             **init_kwargs,
         )
@@ -97,12 +97,23 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         raise OSError("Failed to load tokenizer.") from e
 
     patch_tokenizer(tokenizer, model_args)
+
     try:
-        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-        patch_processor(processor, tokenizer, model_args)
+        processor = AutoProcessor.from_pretrained(
+            model_args.model_name_or_path,
+            use_fast=model_args.use_fast_tokenizer,
+            **init_kwargs,
+        )
+    except ValueError:  # try another one
+        processor = AutoProcessor.from_pretrained(
+            model_args.model_name_or_path,
+            use_fast=not model_args.use_fast_tokenizer,
+            **init_kwargs,
+        )
     except Exception as e:
-        logger.info_rank0(f"Failed to load processor: {e}.")
-        processor = None
+        raise OSError("Failed to load processor.") from e
+
+    patch_processor(processor, tokenizer, model_args)
 
     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
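With this change the loader tries the preferred `use_fast` setting first and retries with the opposite one on `ValueError`, for both the tokenizer and the processor, and a processor failure now raises instead of silently yielding `None`. A standalone sketch of the retry pattern under those assumptions (the model id and the `PREFER_FAST` flag are placeholders standing in for `model_args`):

```python
from transformers import AutoProcessor, AutoTokenizer

MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"  # placeholder model id for illustration
PREFER_FAST = True  # stands in for model_args.use_fast_tokenizer

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=PREFER_FAST, padding_side="right")
except ValueError:  # the preferred implementation is unavailable, retry with the other one
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=not PREFER_FAST, padding_side="right")

try:
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=PREFER_FAST)
except ValueError:  # same fallback for the processor
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=not PREFER_FAST)
```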
@@ -287,7 +287,9 @@ _register_composite_model(
     model_type="qwen2_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
+    language_model_keys=["language_model", "lm_head"]
+    if is_transformers_version_greater_than("4.52.0")
+    else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
 
@@ -296,6 +298,8 @@ _register_composite_model(
     model_type="qwen2_5_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
+    language_model_keys=["language_model", "lm_head"]
+    if is_transformers_version_greater_than("4.52.0")
+    else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
@@ -16,6 +16,7 @@ import os
 
 import torch
 from PIL import Image
+from transformers import AutoConfig, AutoModelForVision2Seq
 
 from llamafactory.data import get_template_and_fix_tokenizer
 from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
@@ -72,12 +73,17 @@ def test_base_collator():
 
 def test_multimodal_collator():
     model_args, data_args, *_ = get_infer_args(
-        {"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
+        {"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
     )
     tokenizer_module = load_tokenizer(model_args)
     template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    with torch.device("meta"):
+        model = AutoModelForVision2Seq.from_config(config)
+
     data_collator = MultiModalDataCollatorForSeq2Seq(
         template=template,
+        model=model,
         pad_to_multiple_of=4,
         label_pad_token_id=IGNORE_INDEX,
         **tokenizer_module,
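The updated test builds a weight-free Qwen2-VL model on the meta device so the collator can call `get_rope_index` without downloading checkpoints. A minimal sketch of that setup outside the test harness, assuming only `torch` and `transformers` are installed:

```python
import torch
from transformers import AutoConfig, AutoModelForVision2Seq

# Only the config is fetched; parameters are shape-only placeholders on the meta device.
config = AutoConfig.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
with torch.device("meta"):
    model = AutoModelForVision2Seq.from_config(config)

print(next(model.parameters()).device)  # meta
```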
@@ -107,8 +113,15 @@ def test_multimodal_collator():
         "labels": [
             [0, 1, 2, 3, q, q, q, q, q, q, q, q],
         ],
+        "position_ids": [
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+        ],
+        "rope_deltas": [[-8]],
         **tokenizer_module["processor"].image_processor(fake_image),
     }
+    assert batch_input.keys() == expected_input.keys()
     for k in batch_input.keys():
         assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
 
@@ -150,3 +163,7 @@ def test_4d_attention_mask():
     )
     assert list(attention_mask_computed.size()) == [2, 1, 6, 6]
     assert torch.all(attention_mask_computed == attention_mask_expected)
+
+
+if __name__ == "__main__":
+    test_multimodal_collator()
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.3.107
+0.9.3.108