[data] Optimize QwenVL video dataset preprocessing (#10404)

Co-authored-by: Kingsley <kingsleydodonow@gmail.com>
luca-888
2026-05-03 18:36:56 +08:00
committed by GitHub
parent 468723c5d9
commit 8752280dd7
4 changed files with 291 additions and 6 deletions


@@ -22,7 +22,8 @@ import re
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
-from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
+from types import SimpleNamespace
+from typing import TYPE_CHECKING, Any, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
import numpy as np
import torch
@@ -245,6 +246,14 @@ class MMPluginMixin:
        sample_frames = min(total_frames, video_maxlen, sample_frames)
        return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
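
For intuition, `_get_video_sample_indices` spreads the requested number of samples evenly across the clip. A minimal sketch with illustrative numbers (a ~9.72 s clip at 25 fps, so 243 frames, sampled at video_fps=2 as in the test at the end of this commit):

import numpy as np

# 19 requested samples spread evenly over frames 0..242
indices = np.linspace(0, 243 - 1, 19).astype(np.int32)
# -> array([  0,  13,  26, ..., 242], dtype=int32)
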
    def _get_video_token_metadata(
        self,
        videos: list["VideoInput"],
        processor: "MMProcessor",
    ) -> Optional[dict[str, Any]]:
        r"""Build metadata used to expand video tokens without decoding frames."""
        return None

    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
        r"""Regularize images to avoid error. Including reading and pre-processing."""
        results = []
@@ -1747,6 +1756,199 @@ class Qwen2VLPlugin(BasePlugin):
"frames_indices": frames_indices,
}
    def _get_qwen_video_size_after_regularization(
        self, width: int, height: int, image_max_pixels: int, image_min_pixels: int
    ) -> tuple[int, int]:
        r"""Compute the frame size produced by Qwen-VL image regularization."""
        if (width * height) > image_max_pixels:
            resize_factor = math.sqrt(image_max_pixels / (width * height))
            width, height = int(width * resize_factor), int(height * resize_factor)

        if (width * height) < image_min_pixels:
            resize_factor = math.sqrt(image_min_pixels / (width * height))
            width, height = int(width * resize_factor), int(height * resize_factor)

        if min(width, height) < 28:
            width, height = max(width, 28), max(height, 28)

        if width / height > 200:
            width, height = height * 180, height

        if height / width > 200:
            width, height = width, width * 180

        return width, height
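
A worked example of the clamping above, using the defaults read later in this hunk (video_max_pixels=256*256, video_min_pixels=16*16; numbers illustrative):

import math

# a 1920x1080 frame has 2,073,600 pixels > 65,536, so it is scaled by
# sqrt(65536 / 2073600) ~= 0.178; int() truncation gives roughly 341x192,
# which already passes the min-pixel, min-side (28), and aspect-ratio checks
width, height = 1920, 1080
resize_factor = math.sqrt((256 * 256) / (width * height))
width, height = int(width * resize_factor), int(height * resize_factor)  # ~(341, 192)
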
    def _get_qwen_video_stream_metadata(
        self,
        video: "VideoInput",
        video_fps: float,
        video_maxlen: int,
    ) -> Optional[dict[str, Any]]:
        if not is_pyav_available() or not isinstance(video, (str, os.PathLike)):
            return None

        try:
            container = av.open(video, "r")
        except (av.FFmpegError, OSError):
            return None

        try:
            video_stream = next((stream for stream in container.streams if stream.type == "video"), None)
            if video_stream is None:
                return None

            if video_stream.duration is None or video_stream.average_rate is None:
                return None

            average_fps = float(video_stream.average_rate)
            if average_fps <= 0:
                return None

            sample_indices = self._get_video_sample_indices(
                video_stream, video_fps=video_fps, video_maxlen=video_maxlen
            )
            return {
                "width": video_stream.width,
                "height": video_stream.height,
                "average_fps": average_fps,
                "sample_indices": sample_indices,
            }
        finally:
            container.close()
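
This is the decode-free core of the optimization: PyAV reads width, height, duration, and average frame rate from the container headers. A standalone sketch of the same calls (file name hypothetical):

import av  # PyAV

with av.open("demo.mp4", "r") as container:
    stream = next(s for s in container.streams if s.type == "video")
    # all of these come from the stream header; no frame is decoded
    print(stream.width, stream.height, float(stream.average_rate))
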
    def _get_qwen_video_resize(
        self,
        num_frames: int,
        height: int,
        width: int,
        patch_size: int,
        temporal_patch_size: int,
        merge_size: int,
        min_pixels: int,
        max_pixels: int,
    ) -> tuple[int, int]:
        from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize

        return smart_resize(
            height=height,
            width=width,
            factor=patch_size * merge_size,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
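
For reference, smart_resize snaps each side to a multiple of factor while keeping the area within [min_pixels, max_pixels]. A sketch with typical Qwen2-VL values (patch_size=14, merge_size=2, so factor=28; the pixel bounds are illustrative):

from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize

smart_resize(height=192, width=341, factor=28, min_pixels=56 * 56, max_pixels=1280 * 28 * 28)
# -> (196, 336): 192 rounds to 7 * 28 and 341 rounds to 12 * 28
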
    def _get_qwen_video_grid_metadata(
        self,
        videos: list["VideoInput"],
        processor: "MMProcessor",
    ) -> Optional[dict[str, Any]]:
        if len(videos) == 0:
            return {"video_grid_thw": torch.empty((0, 3), dtype=torch.long), "frames_indices": [], "fps": 2.0}

        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None) or image_processor
        if image_processor is None or video_processor is None:
            return None

        patch_size = getattr(video_processor, "patch_size", None)
        temporal_patch_size = getattr(video_processor, "temporal_patch_size", None)
        merge_size = getattr(video_processor, "merge_size", None)
        size = getattr(video_processor, "size", None)
        if patch_size is None or temporal_patch_size is None or merge_size is None or size is None:
            return None

        if isinstance(size, dict):
            min_pixels = size.get("shortest_edge")
            max_pixels = size.get("longest_edge")
        else:
            min_pixels = getattr(size, "shortest_edge", None)
            max_pixels = getattr(size, "longest_edge", None)

        if min_pixels is None or max_pixels is None:
            return None

        video_fps = getattr(processor, "video_fps", 2.0)
        video_maxlen = getattr(processor, "video_maxlen", 128)
        image_max_pixels = getattr(processor, "video_max_pixels", 256 * 256)
        image_min_pixels = getattr(processor, "video_min_pixels", 16 * 16)

        video_grid_thw = []
        frames_indices = []
        for video in videos:
            metadata = self._get_qwen_video_stream_metadata(video, video_fps, video_maxlen)
            if metadata is None:
                return None

            width, height = self._get_qwen_video_size_after_regularization(
                metadata["width"], metadata["height"], image_max_pixels, image_min_pixels
            )
            num_frames = len(metadata["sample_indices"])
            if num_frames % 2 != 0:
                num_frames += 1

            resized_size = self._get_qwen_video_resize(
                num_frames,
                height,
                width,
                patch_size,
                temporal_patch_size,
                merge_size,
                min_pixels,
                max_pixels,
            )
            resized_height, resized_width = resized_size
            video_grid_thw.append(
                [
                    math.ceil(num_frames / temporal_patch_size),
                    resized_height // patch_size,
                    resized_width // patch_size,
                ]
            )
            frames_indices.append([idx / metadata["average_fps"] * video_fps for idx in metadata["sample_indices"]])

        return {
            "video_grid_thw": torch.tensor(video_grid_thw, dtype=torch.long),
            "frames_indices": frames_indices,
            "fps": video_fps,
        }
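
Putting the pieces together for one video (illustrative numbers continuing the examples above): 19 sampled frames are padded to 20, and a 341x192 frame smart-resizes to 336x196:

import math

t = math.ceil(20 / 2)    # temporal_patch_size = 2 -> 10
h = 196 // 14            # resized_height // patch_size -> 14
w = 336 // 14            # resized_width // patch_size -> 24
(t, h, w)                # video_grid_thw row: [10, 14, 24]
(t * h * w) // (2 ** 2)  # merge_size = 2 -> 840 placeholder tokens after merging
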
    @override
    def _get_video_token_metadata(
        self,
        videos: list["VideoInput"],
        processor: "MMProcessor",
    ) -> Optional[dict[str, Any]]:
        video_metadata = self._get_qwen_video_grid_metadata(videos, processor)
        if video_metadata is None:
            return None

        return {"video_grid_thw": video_metadata["video_grid_thw"]}

    def _get_mm_token_metadata(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
    ) -> Optional[dict[str, Any]]:
        if len(audios) != 0:
            return None

        mm_inputs = {}
        if len(images) != 0:
            mm_inputs.update(self._get_mm_inputs(images, [], [], processor))

        if len(videos) != 0:
            video_inputs = self._get_video_token_metadata(videos, processor)
            if video_inputs is None:
                return None

            mm_inputs.update(video_inputs)

        return mm_inputs
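
The contract here is that None means "no fast path" (audio present, a non-path video input, or missing stream metadata), in which case the caller falls back to the full decode. A sketch of how process_messages (next hunk) consumes it, with a hypothetical `plugin`:

mm_inputs = plugin._get_mm_token_metadata(images, videos, audios, processor)
if mm_inputs is None:  # fast path unavailable: decode frames as before
    mm_inputs = plugin._get_mm_inputs(images, videos, audios, processor)
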
    @override
    def _get_mm_inputs(
        self,
@@ -1798,7 +2000,10 @@ class Qwen2VLPlugin(BasePlugin):
         merge_length: int = getattr(image_processor, "merge_size") ** 2
         if self.expand_mm_tokens:
-            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            mm_inputs = self._get_mm_token_metadata(images, videos, audios, processor)
+            if mm_inputs is None:
+                mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
             image_grid_thw = mm_inputs.get("image_grid_thw", [])
             video_grid_thw = mm_inputs.get("video_grid_thw", [])
         else:
@@ -1832,6 +2037,51 @@ class Qwen2VLPlugin(BasePlugin):
@dataclass
class Qwen3VLPlugin(Qwen2VLPlugin):
    @override
    def _get_qwen_video_resize(
        self,
        num_frames: int,
        height: int,
        width: int,
        patch_size: int,
        temporal_patch_size: int,
        merge_size: int,
        min_pixels: int,
        max_pixels: int,
    ) -> tuple[int, int]:
        from transformers.models.qwen3_vl.video_processing_qwen3_vl import smart_resize

        return smart_resize(
            num_frames=num_frames,
            height=height,
            width=width,
            temporal_factor=temporal_patch_size,
            factor=patch_size * merge_size,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
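
Unlike the Qwen2-VL variant above, this smart_resize also receives the padded frame count and temporal factor, so the resize can account for the video's length. A sketch mirroring the call above (all values illustrative, including the pixel bounds):

from transformers.models.qwen3_vl.video_processing_qwen3_vl import smart_resize

smart_resize(
    num_frames=20, height=192, width=341,
    temporal_factor=2, factor=28,
    min_pixels=56 * 56, max_pixels=1280 * 28 * 28,
)  # returns a (height, width) pair whose sides are multiples of 28
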
    @override
    def _get_video_token_metadata(
        self,
        videos: list["VideoInput"],
        processor: "MMProcessor",
    ) -> Optional[dict[str, Any]]:
        video_metadata = self._get_qwen_video_grid_metadata(videos, processor)
        if video_metadata is None:
            return None

        return {
            "video_grid_thw": video_metadata["video_grid_thw"],
            "video_metadata": [
                SimpleNamespace(
                    frames_indices=frames_indices,
                    fps=video_metadata["fps"],
                )
                for frames_indices in video_metadata["frames_indices"]
            ],
        }
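
The extra video_metadata entries carry the per-frame timing that Qwen3-VL token expansion needs, which the processor would otherwise recover by decoding. Shape of the fast-path return for a single video (illustrative numbers):

import torch
from types import SimpleNamespace

fast_path_output = {
    "video_grid_thw": torch.tensor([[10, 14, 24]], dtype=torch.long),
    "video_metadata": [SimpleNamespace(frames_indices=[0.0, 1.1, 2.2], fps=2.0)],
}
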
    @override
    def _get_mm_inputs(
        self,
@@ -1904,7 +2154,10 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
         image_merge_length: int = getattr(image_processor, "merge_size") ** 2
         video_merge_length: int = getattr(video_processor, "merge_size") ** 2
         if self.expand_mm_tokens:
-            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            mm_inputs = self._get_mm_token_metadata(images, videos, audios, processor)
+            if mm_inputs is None:
+                mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
             image_grid_thw = mm_inputs.get("image_grid_thw", [])
             video_grid_thw = mm_inputs.get("video_grid_thw", [])
             num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0  # hard code for now


@@ -186,7 +186,6 @@ def _verify_model_args(
        raise ValueError("Quantized model only accepts a single adapter. Merge them first.")


def _check_extra_dependencies(
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",


@@ -20,7 +20,6 @@ from peft import LoraConfig, LoraModel, OFTConfig, PeftModel, TaskType, get_peft
from transformers.integrations import is_deepspeed_zero3_enabled
from ..extras import logging
from ..extras.constants import EngineName
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model


@@ -21,7 +21,7 @@ import torch
from PIL import Image
from llamafactory.data.mm_plugin import get_mm_plugin
-from llamafactory.extras.packages import is_transformers_version_greater_than
+from llamafactory.extras.packages import is_pyav_available, is_transformers_version_greater_than
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
@@ -439,6 +439,40 @@ def test_qwen3_vl_plugin():
    _check_plugin(**check_inputs)


@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
@pytest.mark.skipif(not is_pyav_available(), reason="Requires pyav")
def test_qwen3_vl_plugin_video_path():
    video_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "..", "data", "mllm_demo_data", "1.mp4")
    video_path = os.path.abspath(video_path)
    if not os.path.exists(video_path):
        pytest.skip(f"Video file not found: {video_path}")

    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen3-VL-30B-A3B-Instruct")
    processor = tokenizer_module["processor"]
    qwen3_vl_plugin = get_mm_plugin(name="qwen3_vl", video_token="<|video_pad|>")
    videos = [video_path]

    # fast path: metadata only, no frame decoding
    fast_mm_inputs = qwen3_vl_plugin._get_mm_token_metadata([], videos, [], processor)
    assert fast_mm_inputs is not None, "_get_mm_token_metadata should succeed for a real video file"

    full_mm_inputs = qwen3_vl_plugin._get_mm_inputs([], videos, [], processor)
    # video_grid_thw must be identical between the two paths
    assert torch.equal(fast_mm_inputs["video_grid_thw"], full_mm_inputs["video_grid_thw"]), (
        f"video_grid_thw mismatch between fast path and full path: "
        f"fast={fast_mm_inputs['video_grid_thw']}, full={full_mm_inputs['video_grid_thw']}"
    )

    result = qwen3_vl_plugin.process_messages(VIDEO_MESSAGES, [], videos, [], processor)
    # the demo video is 9.72 s long; at video_fps=2 we sample 19 frames,
    # pad to 20, and temporal merging (factor 2) yields 10 video segments
    assert result[0]["content"].count("<|vision_start|>") == 10, (
        f"Expected 10 video segments, got {result[0]['content'].count('<|vision_start|>')}"
    )
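
The expected count follows from the same arithmetic the fast path uses. A sketch of the segment math the comment above describes:

import math

frames = 19           # sampled at video_fps=2 from the 9.72 s clip
frames += frames % 2  # pad odd counts for temporal_patch_size=2 -> 20
math.ceil(frames / 2) # -> 10 segments, matching the assertion
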
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
def test_video_llava_plugin():