From af2f75e6883b2aef1855e56cb5e0fb1e0d48d721 Mon Sep 17 00:00:00 2001
From: Yaowei Zheng
Date: Tue, 17 Jun 2025 00:48:54 +0800
Subject: [PATCH] [data] fix qwen2vl pos ids (#8387)

---
 src/llamafactory/data/collator.py            | 43 +++++++++++++-------
 src/llamafactory/extras/constants.py         | 12 +++---
 src/llamafactory/hparams/parser.py           | 13 ++++--
 src/llamafactory/model/loader.py             | 23 ++++++++---
 src/llamafactory/model/model_utils/visual.py |  8 +++-
 tests/data/test_collator.py                  | 19 ++++++++-
 tests/version.txt                            |  2 +-
 7 files changed, 85 insertions(+), 35 deletions(-)

diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 3fb08f4b..b749aaef 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional
 import numpy as np
 import torch
 import torch.nn.functional as F
+from peft import PeftModel
 from transformers import DataCollatorForSeq2Seq
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
@@ -94,6 +95,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         if self.template is None:
             raise ValueError("Template is required for MultiModalDataCollator.")
 
+        if isinstance(self.model, PeftModel):
+            self.model = self.model.base_model.model
+
+        if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
+            self.get_rope_func = self.model.get_rope_index  # transformers < 4.52.0 or qwen2.5 omni
+        elif self.model is not None and hasattr(self.model, "model") and hasattr(self.model.model, "get_rope_index"):
+            self.get_rope_func = self.model.model.get_rope_index  # transformers >= 4.52.0
+        else:
+            self.get_rope_func = None
+
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         batch_images, batch_videos, batch_audios = [], [], []
         batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
@@ -171,7 +182,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 
         features: dict[str, torch.Tensor] = super().__call__(features)
 
-        if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
+        if self.get_rope_func is not None:
             rope_index_kwargs = {
                 "input_ids": features["input_ids"],
                 "image_grid_thw": mm_inputs.get("image_grid_thw"),
@@ -180,27 +191,29 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             }
             if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                 rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
-            if "video_second_per_grid" in mm_inputs:  # for qwen2omni
+            elif "video_second_per_grid" in mm_inputs:  # for qwen2.5 omni
                 rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
 
-            if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
+            if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2.5 omni
                 rope_index_kwargs["use_audio_in_video"] = getattr(self.processor, "use_audio_in_video", False)
                 feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
-                if feature_attention_mask is not None:
-                    audio_feature_lengths = torch.sum(
-                        feature_attention_mask, dim=1
-                    )  # FIXME need to get video image lengths
+                if feature_attention_mask is not None:  # FIXME: need to get video image lengths
+                    audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
                     rope_index_kwargs["audio_seqlens"] = audio_feature_lengths  # prepare for input
 
-                delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
-                # avoid conflict
-                new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
-                features["position_ids"], features["rope_deltas"] = (
-                    new_position_ids.clone(),
-                    rope_deltas - delta0,
-                )  # avoid inplace operation FIXME
+                features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
+                features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
+                    dim=-1
+                ).unsqueeze(-1)
             else:  # for qwen2vl
-                features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
+                features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
+
+        if (
+            self.model is not None
+            and getattr(self.model.config, "model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen2_5_omni_thinker"]
+            and ("position_ids" not in features or features["position_ids"].dim() != 3)
+        ):
+            raise ValueError("Qwen2-VL/Qwen2.5-Omni model requires 3D position ids for mrope.")
 
         if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
             cross_attention_mask = mm_inputs.pop("cross_attention_mask")
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index da19ebb7..f582d1f0 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -1609,13 +1609,13 @@ register_model_group(
 
 register_model_group(
     models={
-        "Mistral-Small-24B-Base-2503": {
-            DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Base-2503",
-            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Base-2503",
+        "Mistral-Small-3.1-24B-Base": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-3.1-24B-Base-2503",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-3.1-24B-Base-2503",
         },
-        "Mistral-Small-24B-Instruct-2503": {
-            DownloadSource.DEFAULT: "mistralai/Mistral-Small-24B-Instruct-2503",
-            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-24B-Instruct-2503",
+        "Mistral-Small-3.1-24B-Instruct": {
+            DownloadSource.DEFAULT: "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
+            DownloadSource.MODELSCOPE: "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
         },
     },
     template="mistral_small",
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 7b0c0476..8f210b7c 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -169,10 +169,15 @@ def _check_extra_dependencies(
     if finetuning_args.plot_loss:
         check_version("matplotlib", mandatory=True)
 
-    if training_args is not None and training_args.predict_with_generate:
-        check_version("jieba", mandatory=True)
-        check_version("nltk", mandatory=True)
-        check_version("rouge_chinese", mandatory=True)
+    if training_args is not None:
+        if training_args.deepspeed:
+            # pin deepspeed version < 0.17 because of https://github.com/deepspeedai/DeepSpeed/issues/7347
+            check_version("deepspeed>=0.10.0,<=0.16.9", mandatory=True)
+
+        if training_args.predict_with_generate:
+            check_version("jieba", mandatory=True)
+            check_version("nltk", mandatory=True)
+            check_version("rouge_chinese", mandatory=True)
 
 
 def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index cbcc6b28..7ed4230a 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -86,10 +86,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
             padding_side="right",
             **init_kwargs,
         )
-    except ValueError:  # try the fast one
+    except ValueError:  # try another one
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
-            use_fast=True,
+            use_fast=not model_args.use_fast_tokenizer,
             padding_side="right",
             **init_kwargs,
         )
@@ -97,12 +97,23 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
         raise OSError("Failed to load tokenizer.") from e
 
     patch_tokenizer(tokenizer, model_args)
+
     try:
-        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-        patch_processor(processor, tokenizer, model_args)
+        processor = AutoProcessor.from_pretrained(
+            model_args.model_name_or_path,
+            use_fast=model_args.use_fast_tokenizer,
+            **init_kwargs,
+        )
+    except ValueError:  # try another one
+        processor = AutoProcessor.from_pretrained(
+            model_args.model_name_or_path,
+            use_fast=not model_args.use_fast_tokenizer,
+            **init_kwargs,
+        )
     except Exception as e:
-        logger.info_rank0(f"Failed to load processor: {e}.")
-        processor = None
+        raise OSError("Failed to load processor.") from e
+
+    patch_processor(processor, tokenizer, model_args)
 
     # Avoid load tokenizer, see:
     # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 9d4e535a..a8010579 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -287,7 +287,9 @@ _register_composite_model(
     model_type="qwen2_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
+    language_model_keys=["language_model", "lm_head"]
+    if is_transformers_version_greater_than("4.52.0")
+    else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
 
@@ -296,6 +298,8 @@ _register_composite_model(
     model_type="qwen2_5_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
+    language_model_keys=["language_model", "lm_head"]
+    if is_transformers_version_greater_than("4.52.0")
+    else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
diff --git a/tests/data/test_collator.py b/tests/data/test_collator.py
index a263c0f8..657f280d 100644
--- a/tests/data/test_collator.py
+++ b/tests/data/test_collator.py
@@ -16,6 +16,7 @@ import os
 
 import torch
 from PIL import Image
+from transformers import AutoConfig, AutoModelForVision2Seq
 
 from llamafactory.data import get_template_and_fix_tokenizer
 from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
@@ -72,12 +73,17 @@ def test_base_collator():
 
 def test_multimodal_collator():
     model_args, data_args, *_ = get_infer_args(
-        {"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
+        {"model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct", "template": "qwen2_vl"}
     )
     tokenizer_module = load_tokenizer(model_args)
     template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    with torch.device("meta"):
+        model = AutoModelForVision2Seq.from_config(config)
+
     data_collator = MultiModalDataCollatorForSeq2Seq(
         template=template,
+        model=model,
         pad_to_multiple_of=4,
         label_pad_token_id=IGNORE_INDEX,
         **tokenizer_module,
@@ -107,8 +113,15 @@ def test_multimodal_collator():
         "labels": [
             [0, 1, 2, 3, q, q, q, q, q, q, q, q],
         ],
+        "position_ids": [
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
+        ],
+        "rope_deltas": [[-8]],
         **tokenizer_module["processor"].image_processor(fake_image),
     }
+    assert batch_input.keys() == expected_input.keys()
 
     for k in batch_input.keys():
         assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
@@ -150,3 +163,7 @@ def test_4d_attention_mask():
     )
     assert list(attention_mask_computed.size()) == [2, 1, 6, 6]
     assert torch.all(attention_mask_computed == attention_mask_expected)
+
+
+if __name__ == "__main__":
+    test_multimodal_collator()
diff --git a/tests/version.txt b/tests/version.txt
index dae5ebba..0f1383aa 100644
--- a/tests/version.txt
+++ b/tests/version.txt
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.3.107
+0.9.3.108
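
Reviewer note: the heart of the collator change is (a) resolving `get_rope_index` once in `__post_init__`, unwrapping `PeftModel` and falling back to `model.model.get_rope_index` for transformers >= 4.52.0, and (b) shifting `rope_deltas` by each sample's left-padding count (the sum of `1 - attention_mask`). The snippet below is a minimal, standalone sketch of that padding correction only; the tensor values are made up for illustration and are not taken from the patch or the tests.

import torch

# Hypothetical batch of two samples; the second one is left-padded with three tokens.
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1, 1],
    [0, 0, 0, 1, 1, 1],
])

# rope_deltas as a get_rope_index-style function would return them (illustrative values).
rope_deltas = torch.tensor([[-8], [-8]])

# The same correction the collator now applies:
# rope_deltas - (1 - attention_mask).sum(dim=-1).unsqueeze(-1)
pad_lengths = (1 - attention_mask).sum(dim=-1).unsqueeze(-1)  # [[0], [3]]
corrected_rope_deltas = rope_deltas - pad_lengths

print(corrected_rope_deltas.tolist())  # [[-8], [-11]]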