mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-05 07:38:55 +08:00
[data] Optimize QwenVL video dataset preprocessing (#10404)
Co-authored-by: Kingsley <kingsleydodonow@gmail.com>
This commit is contained in:
@@ -22,7 +22,8 @@ import re
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
|
||||
from types import SimpleNamespace
|
||||
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -245,6 +246,14 @@ class MMPluginMixin:
|
||||
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
||||
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||
|
||||
def _get_video_token_metadata(
|
||||
self,
|
||||
videos: list["VideoInput"],
|
||||
processor: "MMProcessor",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
r"""Build metadata used to expand video tokens without decoding frames."""
|
||||
return None
|
||||
|
||||
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
|
||||
r"""Regularize images to avoid error. Including reading and pre-processing."""
|
||||
results = []
|
||||
@@ -1747,6 +1756,199 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
"frames_indices": frames_indices,
|
||||
}
|
||||
|
||||
def _get_qwen_video_size_after_regularization(
|
||||
self, width: int, height: int, image_max_pixels: int, image_min_pixels: int
|
||||
) -> tuple[int, int]:
|
||||
r"""Compute the frame size produced by Qwen-VL image regularization."""
|
||||
if (width * height) > image_max_pixels:
|
||||
resize_factor = math.sqrt(image_max_pixels / (width * height))
|
||||
width, height = int(width * resize_factor), int(height * resize_factor)
|
||||
|
||||
if (width * height) < image_min_pixels:
|
||||
resize_factor = math.sqrt(image_min_pixels / (width * height))
|
||||
width, height = int(width * resize_factor), int(height * resize_factor)
|
||||
|
||||
if min(width, height) < 28:
|
||||
width, height = max(width, 28), max(height, 28)
|
||||
|
||||
if width / height > 200:
|
||||
width, height = height * 180, height
|
||||
|
||||
if height / width > 200:
|
||||
width, height = width, width * 180
|
||||
|
||||
return width, height
|
||||
|
||||
def _get_qwen_video_stream_metadata(
|
||||
self,
|
||||
video: "VideoInput",
|
||||
video_fps: float,
|
||||
video_maxlen: int,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
if not is_pyav_available() or not isinstance(video, (str, os.PathLike)):
|
||||
return None
|
||||
|
||||
try:
|
||||
container = av.open(video, "r")
|
||||
except (av.FFmpegError, OSError):
|
||||
return None
|
||||
|
||||
try:
|
||||
video_stream = next((stream for stream in container.streams if stream.type == "video"), None)
|
||||
if video_stream is None:
|
||||
return None
|
||||
|
||||
if video_stream.duration is None or video_stream.average_rate is None:
|
||||
return None
|
||||
|
||||
average_fps = float(video_stream.average_rate)
|
||||
if average_fps <= 0:
|
||||
return None
|
||||
|
||||
sample_indices = self._get_video_sample_indices(
|
||||
video_stream, video_fps=video_fps, video_maxlen=video_maxlen
|
||||
)
|
||||
return {
|
||||
"width": video_stream.width,
|
||||
"height": video_stream.height,
|
||||
"average_fps": average_fps,
|
||||
"sample_indices": sample_indices,
|
||||
}
|
||||
finally:
|
||||
container.close()
|
||||
|
||||
def _get_qwen_video_resize(
|
||||
self,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
patch_size: int,
|
||||
temporal_patch_size: int,
|
||||
merge_size: int,
|
||||
min_pixels: int,
|
||||
max_pixels: int,
|
||||
) -> tuple[int, int]:
|
||||
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
|
||||
|
||||
return smart_resize(
|
||||
height=height,
|
||||
width=width,
|
||||
factor=patch_size * merge_size,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
|
||||
def _get_qwen_video_grid_metadata(
|
||||
self,
|
||||
videos: list["VideoInput"],
|
||||
processor: "MMProcessor",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
if len(videos) == 0:
|
||||
return {"video_grid_thw": torch.empty((0, 3), dtype=torch.long), "frames_indices": [], "fps": 2.0}
|
||||
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None) or image_processor
|
||||
if image_processor is None or video_processor is None:
|
||||
return None
|
||||
|
||||
patch_size = getattr(video_processor, "patch_size", None)
|
||||
temporal_patch_size = getattr(video_processor, "temporal_patch_size", None)
|
||||
merge_size = getattr(video_processor, "merge_size", None)
|
||||
size = getattr(video_processor, "size", None)
|
||||
if patch_size is None or temporal_patch_size is None or merge_size is None or size is None:
|
||||
return None
|
||||
|
||||
if isinstance(size, dict):
|
||||
min_pixels = size.get("shortest_edge")
|
||||
max_pixels = size.get("longest_edge")
|
||||
else:
|
||||
min_pixels = getattr(size, "shortest_edge", None)
|
||||
max_pixels = getattr(size, "longest_edge", None)
|
||||
|
||||
if min_pixels is None or max_pixels is None:
|
||||
return None
|
||||
|
||||
video_fps = getattr(processor, "video_fps", 2.0)
|
||||
video_maxlen = getattr(processor, "video_maxlen", 128)
|
||||
image_max_pixels = getattr(processor, "video_max_pixels", 256 * 256)
|
||||
image_min_pixels = getattr(processor, "video_min_pixels", 16 * 16)
|
||||
|
||||
video_grid_thw = []
|
||||
frames_indices = []
|
||||
for video in videos:
|
||||
metadata = self._get_qwen_video_stream_metadata(video, video_fps, video_maxlen)
|
||||
if metadata is None:
|
||||
return None
|
||||
|
||||
width, height = self._get_qwen_video_size_after_regularization(
|
||||
metadata["width"], metadata["height"], image_max_pixels, image_min_pixels
|
||||
)
|
||||
num_frames = len(metadata["sample_indices"])
|
||||
if num_frames % 2 != 0:
|
||||
num_frames += 1
|
||||
|
||||
resized_size = self._get_qwen_video_resize(
|
||||
num_frames,
|
||||
height,
|
||||
width,
|
||||
patch_size,
|
||||
temporal_patch_size,
|
||||
merge_size,
|
||||
min_pixels,
|
||||
max_pixels,
|
||||
)
|
||||
|
||||
resized_height, resized_width = resized_size
|
||||
video_grid_thw.append(
|
||||
[
|
||||
math.ceil(num_frames / temporal_patch_size),
|
||||
resized_height // patch_size,
|
||||
resized_width // patch_size,
|
||||
]
|
||||
)
|
||||
frames_indices.append([idx / metadata["average_fps"] * video_fps for idx in metadata["sample_indices"]])
|
||||
|
||||
return {
|
||||
"video_grid_thw": torch.tensor(video_grid_thw, dtype=torch.long),
|
||||
"frames_indices": frames_indices,
|
||||
"fps": video_fps,
|
||||
}
|
||||
|
||||
@override
|
||||
def _get_video_token_metadata(
|
||||
self,
|
||||
videos: list["VideoInput"],
|
||||
processor: "MMProcessor",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
video_metadata = self._get_qwen_video_grid_metadata(videos, processor)
|
||||
if video_metadata is None:
|
||||
return None
|
||||
|
||||
return {"video_grid_thw": video_metadata["video_grid_thw"]}
|
||||
|
||||
def _get_mm_token_metadata(
|
||||
self,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: "MMProcessor",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
if len(audios) != 0:
|
||||
return None
|
||||
|
||||
mm_inputs = {}
|
||||
if len(images) != 0:
|
||||
mm_inputs.update(self._get_mm_inputs(images, [], [], processor))
|
||||
|
||||
if len(videos) != 0:
|
||||
video_inputs = self._get_video_token_metadata(videos, processor)
|
||||
if video_inputs is None:
|
||||
return None
|
||||
|
||||
mm_inputs.update(video_inputs)
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
@@ -1798,7 +2000,10 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
|
||||
merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
mm_inputs = self._get_mm_token_metadata(images, videos, audios, processor)
|
||||
if mm_inputs is None:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||
else:
|
||||
@@ -1832,6 +2037,51 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
|
||||
@dataclass
|
||||
class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
@override
|
||||
def _get_qwen_video_resize(
|
||||
self,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
patch_size: int,
|
||||
temporal_patch_size: int,
|
||||
merge_size: int,
|
||||
min_pixels: int,
|
||||
max_pixels: int,
|
||||
) -> tuple[int, int]:
|
||||
from transformers.models.qwen3_vl.video_processing_qwen3_vl import smart_resize
|
||||
|
||||
return smart_resize(
|
||||
num_frames=num_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
temporal_factor=temporal_patch_size,
|
||||
factor=patch_size * merge_size,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
|
||||
@override
|
||||
def _get_video_token_metadata(
|
||||
self,
|
||||
videos: list["VideoInput"],
|
||||
processor: "MMProcessor",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
video_metadata = self._get_qwen_video_grid_metadata(videos, processor)
|
||||
if video_metadata is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"video_grid_thw": video_metadata["video_grid_thw"],
|
||||
"video_metadata": [
|
||||
SimpleNamespace(
|
||||
frames_indices=frames_indices,
|
||||
fps=video_metadata["fps"],
|
||||
)
|
||||
for frames_indices in video_metadata["frames_indices"]
|
||||
],
|
||||
}
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
@@ -1904,7 +2154,10 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
image_merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||
video_merge_length: int = getattr(video_processor, "merge_size") ** 2
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
mm_inputs = self._get_mm_token_metadata(images, videos, audios, processor)
|
||||
if mm_inputs is None:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||
num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now
|
||||
|
||||
@@ -186,7 +186,6 @@ def _verify_model_args(
|
||||
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
||||
|
||||
|
||||
|
||||
def _check_extra_dependencies(
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
|
||||
@@ -20,7 +20,6 @@ from peft import LoraConfig, LoraModel, OFTConfig, PeftModel, TaskType, get_peft
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.constants import EngineName
|
||||
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
|
||||
from .model_utils.quantization import QuantizationMethod
|
||||
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from llamafactory.data.mm_plugin import get_mm_plugin
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.extras.packages import is_pyav_available, is_transformers_version_greater_than
|
||||
from llamafactory.hparams import get_infer_args
|
||||
from llamafactory.model import load_tokenizer
|
||||
|
||||
@@ -439,6 +439,40 @@ def test_qwen3_vl_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
|
||||
@pytest.mark.skipif(not is_pyav_available(), reason="Requires pyav")
|
||||
def test_qwen3_vl_plugin_video_path():
|
||||
video_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "..", "data", "mllm_demo_data", "1.mp4")
|
||||
video_path = os.path.abspath(video_path)
|
||||
if not os.path.exists(video_path):
|
||||
pytest.skip(f"Video file not found: {video_path}")
|
||||
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen3-VL-30B-A3B-Instruct")
|
||||
processor = tokenizer_module["processor"]
|
||||
qwen3_vl_plugin = get_mm_plugin(name="qwen3_vl", video_token="<|video_pad|>")
|
||||
|
||||
videos = [video_path]
|
||||
|
||||
# fast path: metadata-only, no frame decoding
|
||||
fast_mm_inputs = qwen3_vl_plugin._get_mm_token_metadata([], videos, [], processor)
|
||||
assert fast_mm_inputs is not None, "_get_mm_token_metadata should succeed for a real video file"
|
||||
|
||||
full_mm_inputs = qwen3_vl_plugin._get_mm_inputs([], videos, [], processor)
|
||||
|
||||
# video_grid_thw must be identical between the two paths
|
||||
assert torch.equal(fast_mm_inputs["video_grid_thw"], full_mm_inputs["video_grid_thw"]), (
|
||||
f"video_grid_thw mismatch between fast path and full path: "
|
||||
f"fast={fast_mm_inputs['video_grid_thw']}, full={full_mm_inputs['video_grid_thw']}"
|
||||
)
|
||||
result = qwen3_vl_plugin.process_messages(VIDEO_MESSAGES, [], videos, [], processor)
|
||||
# This demo video duration is 9.72s, with video_fps=2, we extract 19 frames
|
||||
# 19 + 1 => temperoal compress => 10 video_sequence
|
||||
assert result[0]["content"].count("<|vision_start|>") == 10, (
|
||||
f"Expected 10 video tokens, got {result[0]['content'].count('<|vision_start|>')}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
|
||||
def test_video_llava_plugin():
|
||||
|
||||
Reference in New Issue
Block a user