From 8752280dd7d58e24ef1eb1a776210bc835e232c1 Mon Sep 17 00:00:00 2001
From: luca-888 <55954511+luca-888@users.noreply.github.com>
Date: Sun, 3 May 2026 18:36:56 +0800
Subject: [PATCH] [data] Optimize QwenVL video dataset preprocessing (#10404)

Co-authored-by: Kingsley
---
 src/llamafactory/data/mm_plugin.py | 259 ++++++++++++++++++++++++++++-
 src/llamafactory/hparams/parser.py |   1 -
 src/llamafactory/model/adapter.py  |   1 -
 tests/data/test_mm_plugin.py       |  36 +++-
 4 files changed, 291 insertions(+), 6 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 1e4b6db8e..6827ea400 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -22,7 +22,8 @@ import re
 from copy import deepcopy
 from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
+from types import SimpleNamespace
+from typing import TYPE_CHECKING, Any, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
 
 import numpy as np
 import torch
@@ -245,6 +246,14 @@
         sample_frames = min(total_frames, video_maxlen, sample_frames)
         return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
 
+    def _get_video_token_metadata(
+        self,
+        videos: list["VideoInput"],
+        processor: "MMProcessor",
+    ) -> Optional[dict[str, Any]]:
+        r"""Build metadata used to expand video tokens without decoding frames."""
+        return None
+
     def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
         r"""Regularize images to avoid error. Including reading and pre-processing."""
         results = []
@@ -1747,6 +1756,199 @@
             "frames_indices": frames_indices,
         }
 
+    def _get_qwen_video_size_after_regularization(
+        self, width: int, height: int, image_max_pixels: int, image_min_pixels: int
+    ) -> tuple[int, int]:
+        r"""Compute the frame size produced by Qwen-VL image regularization."""
+        if (width * height) > image_max_pixels:
+            resize_factor = math.sqrt(image_max_pixels / (width * height))
+            width, height = int(width * resize_factor), int(height * resize_factor)
+
+        if (width * height) < image_min_pixels:
+            resize_factor = math.sqrt(image_min_pixels / (width * height))
+            width, height = int(width * resize_factor), int(height * resize_factor)
+
+        if min(width, height) < 28:
+            width, height = max(width, 28), max(height, 28)
+
+        if width / height > 200:
+            width, height = height * 180, height
+
+        if height / width > 200:
+            width, height = width, width * 180
+
+        return width, height
+
+    def _get_qwen_video_stream_metadata(
+        self,
+        video: "VideoInput",
+        video_fps: float,
+        video_maxlen: int,
+    ) -> Optional[dict[str, Any]]:
+        if not is_pyav_available() or not isinstance(video, (str, os.PathLike)):
+            return None
+
+        try:
+            container = av.open(video, "r")
+        except (av.FFmpegError, OSError):
+            return None
+
+        try:
+            video_stream = next((stream for stream in container.streams if stream.type == "video"), None)
+            if video_stream is None:
+                return None
+
+            if video_stream.duration is None or video_stream.average_rate is None:
+                return None
+
+            average_fps = float(video_stream.average_rate)
+            if average_fps <= 0:
+                return None
+
+            sample_indices = self._get_video_sample_indices(
+                video_stream, video_fps=video_fps, video_maxlen=video_maxlen
+            )
+            return {
+                "width": video_stream.width,
+                "height": video_stream.height,
+                "average_fps": average_fps,
+                "sample_indices": sample_indices,
+            }
+        finally:
+            container.close()
+
+    def _get_qwen_video_resize(
+        self,
+        num_frames: int,
+        height: int,
+        width: int,
+        patch_size: int,
+        temporal_patch_size: int,
+        merge_size: int,
+        min_pixels: int,
+        max_pixels: int,
+    ) -> tuple[int, int]:
+        from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+
+        return smart_resize(
+            height=height,
+            width=width,
+            factor=patch_size * merge_size,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+        )
+
+    def _get_qwen_video_grid_metadata(
+        self,
+        videos: list["VideoInput"],
+        processor: "MMProcessor",
+    ) -> Optional[dict[str, Any]]:
+        if len(videos) == 0:
+            return {"video_grid_thw": torch.empty((0, 3), dtype=torch.long), "frames_indices": [], "fps": 2.0}
+
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None) or image_processor
+        if image_processor is None or video_processor is None:
+            return None
+
+        patch_size = getattr(video_processor, "patch_size", None)
+        temporal_patch_size = getattr(video_processor, "temporal_patch_size", None)
+        merge_size = getattr(video_processor, "merge_size", None)
+        size = getattr(video_processor, "size", None)
+        if patch_size is None or temporal_patch_size is None or merge_size is None or size is None:
+            return None
+
+        if isinstance(size, dict):
+            min_pixels = size.get("shortest_edge")
+            max_pixels = size.get("longest_edge")
+        else:
+            min_pixels = getattr(size, "shortest_edge", None)
+            max_pixels = getattr(size, "longest_edge", None)
+
+        if min_pixels is None or max_pixels is None:
+            return None
+
+        video_fps = getattr(processor, "video_fps", 2.0)
+        video_maxlen = getattr(processor, "video_maxlen", 128)
+        image_max_pixels = getattr(processor, "video_max_pixels", 256 * 256)
+        image_min_pixels = getattr(processor, "video_min_pixels", 16 * 16)
+
+        video_grid_thw = []
+        frames_indices = []
+        for video in videos:
+            metadata = self._get_qwen_video_stream_metadata(video, video_fps, video_maxlen)
+            if metadata is None:
+                return None
+
+            width, height = self._get_qwen_video_size_after_regularization(
+                metadata["width"], metadata["height"], image_max_pixels, image_min_pixels
+            )
+            num_frames = len(metadata["sample_indices"])
+            if num_frames % 2 != 0:
+                num_frames += 1
+
+            resized_size = self._get_qwen_video_resize(
+                num_frames,
+                height,
+                width,
+                patch_size,
+                temporal_patch_size,
+                merge_size,
+                min_pixels,
+                max_pixels,
+            )
+
+            resized_height, resized_width = resized_size
+            video_grid_thw.append(
+                [
+                    math.ceil(num_frames / temporal_patch_size),
+                    resized_height // patch_size,
+                    resized_width // patch_size,
+                ]
+            )
+            frames_indices.append([idx / metadata["average_fps"] * video_fps for idx in metadata["sample_indices"]])
+
+        return {
+            "video_grid_thw": torch.tensor(video_grid_thw, dtype=torch.long),
+            "frames_indices": frames_indices,
+            "fps": video_fps,
+        }
+
+    @override
+    def _get_video_token_metadata(
+        self,
+        videos: list["VideoInput"],
+        processor: "MMProcessor",
+    ) -> Optional[dict[str, Any]]:
+        video_metadata = self._get_qwen_video_grid_metadata(videos, processor)
+        if video_metadata is None:
+            return None
+
+        return {"video_grid_thw": video_metadata["video_grid_thw"]}
+
+    def _get_mm_token_metadata(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: "MMProcessor",
+    ) -> Optional[dict[str, Any]]:
+        if len(audios) != 0:
+            return None
+
+        mm_inputs = {}
+        if len(images) != 0:
+            mm_inputs.update(self._get_mm_inputs(images, [], [], processor))
+
+        if len(videos) != 0:
+            video_inputs = self._get_video_token_metadata(videos, processor)
+            if video_inputs is None:
+                return None
+
+            mm_inputs.update(video_inputs)
+
+        return mm_inputs
+
     @override
     def _get_mm_inputs(
         self,
@@ -1798,7 +2000,10 @@
         merge_length: int = getattr(image_processor, "merge_size") ** 2
         if self.expand_mm_tokens:
-            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            mm_inputs = self._get_mm_token_metadata(images, videos, audios, processor)
+            if mm_inputs is None:
+                mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+
             image_grid_thw = mm_inputs.get("image_grid_thw", [])
             video_grid_thw = mm_inputs.get("video_grid_thw", [])
         else:
@@ -1832,6 +2037,51 @@
 
 @dataclass
 class Qwen3VLPlugin(Qwen2VLPlugin):
+    @override
+    def _get_qwen_video_resize(
+        self,
+        num_frames: int,
+        height: int,
+        width: int,
+        patch_size: int,
+        temporal_patch_size: int,
+        merge_size: int,
+        min_pixels: int,
+        max_pixels: int,
+    ) -> tuple[int, int]:
+        from transformers.models.qwen3_vl.video_processing_qwen3_vl import smart_resize
+
+        return smart_resize(
+            num_frames=num_frames,
+            height=height,
+            width=width,
+            temporal_factor=temporal_patch_size,
+            factor=patch_size * merge_size,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
+        )
+
+    @override
+    def _get_video_token_metadata(
+        self,
+        videos: list["VideoInput"],
+        processor: "MMProcessor",
+    ) -> Optional[dict[str, Any]]:
+        video_metadata = self._get_qwen_video_grid_metadata(videos, processor)
+        if video_metadata is None:
+            return None
+
+        return {
+            "video_grid_thw": video_metadata["video_grid_thw"],
+            "video_metadata": [
+                SimpleNamespace(
+                    frames_indices=frames_indices,
+                    fps=video_metadata["fps"],
+                )
+                for frames_indices in video_metadata["frames_indices"]
+            ],
+        }
+
     @override
     def _get_mm_inputs(
         self,
@@ -1904,7 +2154,10 @@
         image_merge_length: int = getattr(image_processor, "merge_size") ** 2
         video_merge_length: int = getattr(video_processor, "merge_size") ** 2
         if self.expand_mm_tokens:
-            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            mm_inputs = self._get_mm_token_metadata(images, videos, audios, processor)
+            if mm_inputs is None:
+                mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+
             image_grid_thw = mm_inputs.get("image_grid_thw", [])
             video_grid_thw = mm_inputs.get("video_grid_thw", [])
             num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0  # hard code for now
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index f8bfe4868..c946cceb9 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -186,7 +186,6 @@ def _verify_model_args(
         raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
 
 
-
 def _check_extra_dependencies(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index 70efb6acb..b8c3900d7 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -20,7 +20,6 @@ from peft import LoraConfig, LoraModel, OFTConfig, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
 
 from ..extras import logging
-from ..extras.constants import EngineName
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py
index 7dd792e06..e63c866db 100644
--- a/tests/data/test_mm_plugin.py
+++ b/tests/data/test_mm_plugin.py
@@ -21,7 +21,7 @@
 import torch
 from PIL import Image
 
 from llamafactory.data.mm_plugin import get_mm_plugin
-from llamafactory.extras.packages import is_transformers_version_greater_than
+from llamafactory.extras.packages import is_pyav_available, is_transformers_version_greater_than
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
@@ -439,6 +439,40 @@ def test_qwen3_vl_plugin():
     _check_plugin(**check_inputs)
 
 
+@pytest.mark.runs_on(["cpu", "mps"])
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
+@pytest.mark.skipif(not is_pyav_available(), reason="Requires pyav")
+def test_qwen3_vl_plugin_video_path():
+    video_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "..", "data", "mllm_demo_data", "1.mp4")
+    video_path = os.path.abspath(video_path)
+    if not os.path.exists(video_path):
+        pytest.skip(f"Video file not found: {video_path}")
+
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen3-VL-30B-A3B-Instruct")
+    processor = tokenizer_module["processor"]
+    qwen3_vl_plugin = get_mm_plugin(name="qwen3_vl", video_token="<|video_pad|>")
+
+    videos = [video_path]
+
+    # fast path: metadata only, no frame decoding
+    fast_mm_inputs = qwen3_vl_plugin._get_mm_token_metadata([], videos, [], processor)
+    assert fast_mm_inputs is not None, "_get_mm_token_metadata should succeed for a real video file"
+
+    full_mm_inputs = qwen3_vl_plugin._get_mm_inputs([], videos, [], processor)
+
+    # video_grid_thw must be identical between the two paths
+    assert torch.equal(fast_mm_inputs["video_grid_thw"], full_mm_inputs["video_grid_thw"]), (
+        f"video_grid_thw mismatch between fast path and full path: "
+        f"fast={fast_mm_inputs['video_grid_thw']}, full={full_mm_inputs['video_grid_thw']}"
+    )
+    result = qwen3_vl_plugin.process_messages(VIDEO_MESSAGES, [], videos, [], processor)
+    # The demo video is 9.72s long; at video_fps=2 we sample 19 frames.
+    # 19 frames + 1 pad frame => temporal merge (factor 2) => 10 video segments.
+    assert result[0]["content"].count("<|vision_start|>") == 10, (
+        f"Expected 10 video segments, got {result[0]['content'].count('<|vision_start|>')}"
+    )
+
+
 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
 def test_video_llava_plugin():
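
The optimization rests on the fact that video_grid_thw is fully determined by
container metadata, so no frame ever needs to be decoded. Below is a minimal
standalone sketch of that arithmetic, not part of the patch. It assumes PyAV,
the Qwen2-VL smart_resize from transformers, the processor defaults
patch_size=14, merge_size=2, temporal_patch_size=2, a file whose header
reports a frame count, and a hypothetical "sample.mp4"; it also approximates
the duration-based sampling rule and skips the pre-resize pixel clamping that
the plugin performs, so treat it as an illustration, not the plugin API:

    import math

    import av
    from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize

    def predict_video_grid_thw(path: str, video_fps: float = 2.0, video_maxlen: int = 128) -> tuple[int, int, int]:
        # Read size, frame rate, and frame count from the container header only.
        with av.open(path, "r") as container:
            stream = next(s for s in container.streams if s.type == "video")
            total_frames = stream.frames
            average_fps = float(stream.average_rate)
            width, height = stream.width, stream.height

        # Approximates _get_video_sample_indices: sample at video_fps, capped by
        # the real frame count and video_maxlen, then pad to an even frame count.
        num_frames = min(total_frames, video_maxlen, math.floor(total_frames / average_fps * video_fps))
        num_frames += num_frames % 2

        # smart_resize yields the post-resize spatial size; the grid is counted in
        # patch_size x patch_size patches after temporal merging by a factor of 2.
        resized_height, resized_width = smart_resize(height=height, width=width, factor=14 * 2)
        return (math.ceil(num_frames / 2), resized_height // 14, resized_width // 14)

    # Hypothetical usage: predict_video_grid_thw("sample.mp4") -> (t, h, w)

Up to the clamping caveat above, the returned triple is the same quantity that
test_qwen3_vl_plugin_video_path asserts is equal between the metadata path and
the decode path.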