mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-05 07:38:55 +08:00
[data] Optimize QwenVL video dataset preprocessing (#10404)
Co-authored-by: Kingsley <kingsleydodonow@gmail.com>
This commit is contained in:
@@ -22,7 +22,8 @@ import re
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
|
from types import SimpleNamespace
|
||||||
|
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -245,6 +246,14 @@ class MMPluginMixin:
|
|||||||
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
||||||
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||||
|
|
||||||
|
def _get_video_token_metadata(
|
||||||
|
self,
|
||||||
|
videos: list["VideoInput"],
|
||||||
|
processor: "MMProcessor",
|
||||||
|
) -> Optional[dict[str, Any]]:
|
||||||
|
r"""Build metadata used to expand video tokens without decoding frames."""
|
||||||
|
return None
|
||||||
|
|
||||||
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
|
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
|
||||||
r"""Regularize images to avoid error. Including reading and pre-processing."""
|
r"""Regularize images to avoid error. Including reading and pre-processing."""
|
||||||
results = []
|
results = []
|
||||||
@@ -1747,6 +1756,199 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
"frames_indices": frames_indices,
|
"frames_indices": frames_indices,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _get_qwen_video_size_after_regularization(
|
||||||
|
self, width: int, height: int, image_max_pixels: int, image_min_pixels: int
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
r"""Compute the frame size produced by Qwen-VL image regularization."""
|
||||||
|
if (width * height) > image_max_pixels:
|
||||||
|
resize_factor = math.sqrt(image_max_pixels / (width * height))
|
||||||
|
width, height = int(width * resize_factor), int(height * resize_factor)
|
||||||
|
|
||||||
|
if (width * height) < image_min_pixels:
|
||||||
|
resize_factor = math.sqrt(image_min_pixels / (width * height))
|
||||||
|
width, height = int(width * resize_factor), int(height * resize_factor)
|
||||||
|
|
||||||
|
if min(width, height) < 28:
|
||||||
|
width, height = max(width, 28), max(height, 28)
|
||||||
|
|
||||||
|
if width / height > 200:
|
||||||
|
width, height = height * 180, height
|
||||||
|
|
||||||
|
if height / width > 200:
|
||||||
|
width, height = width, width * 180
|
||||||
|
|
||||||
|
return width, height
|
||||||
|
|
||||||
|
def _get_qwen_video_stream_metadata(
|
||||||
|
self,
|
||||||
|
video: "VideoInput",
|
||||||
|
video_fps: float,
|
||||||
|
video_maxlen: int,
|
||||||
|
) -> Optional[dict[str, Any]]:
|
||||||
|
if not is_pyav_available() or not isinstance(video, (str, os.PathLike)):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
container = av.open(video, "r")
|
||||||
|
except (av.FFmpegError, OSError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
video_stream = next((stream for stream in container.streams if stream.type == "video"), None)
|
||||||
|
if video_stream is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if video_stream.duration is None or video_stream.average_rate is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
average_fps = float(video_stream.average_rate)
|
||||||
|
if average_fps <= 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
sample_indices = self._get_video_sample_indices(
|
||||||
|
video_stream, video_fps=video_fps, video_maxlen=video_maxlen
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"width": video_stream.width,
|
||||||
|
"height": video_stream.height,
|
||||||
|
"average_fps": average_fps,
|
||||||
|
"sample_indices": sample_indices,
|
||||||
|
}
|
||||||
|
finally:
|
||||||
|
container.close()
|
||||||
|
|
||||||
|
def _get_qwen_video_resize(
|
||||||
|
self,
|
||||||
|
num_frames: int,
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
patch_size: int,
|
||||||
|
temporal_patch_size: int,
|
||||||
|
merge_size: int,
|
||||||
|
min_pixels: int,
|
||||||
|
max_pixels: int,
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
|
||||||
|
|
||||||
|
return smart_resize(
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
factor=patch_size * merge_size,
|
||||||
|
min_pixels=min_pixels,
|
||||||
|
max_pixels=max_pixels,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_qwen_video_grid_metadata(
|
||||||
|
self,
|
||||||
|
videos: list["VideoInput"],
|
||||||
|
processor: "MMProcessor",
|
||||||
|
) -> Optional[dict[str, Any]]:
|
||||||
|
if len(videos) == 0:
|
||||||
|
return {"video_grid_thw": torch.empty((0, 3), dtype=torch.long), "frames_indices": [], "fps": 2.0}
|
||||||
|
|
||||||
|
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||||
|
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None) or image_processor
|
||||||
|
if image_processor is None or video_processor is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
patch_size = getattr(video_processor, "patch_size", None)
|
||||||
|
temporal_patch_size = getattr(video_processor, "temporal_patch_size", None)
|
||||||
|
merge_size = getattr(video_processor, "merge_size", None)
|
||||||
|
size = getattr(video_processor, "size", None)
|
||||||
|
if patch_size is None or temporal_patch_size is None or merge_size is None or size is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if isinstance(size, dict):
|
||||||
|
min_pixels = size.get("shortest_edge")
|
||||||
|
max_pixels = size.get("longest_edge")
|
||||||
|
else:
|
||||||
|
min_pixels = getattr(size, "shortest_edge", None)
|
||||||
|
max_pixels = getattr(size, "longest_edge", None)
|
||||||
|
|
||||||
|
if min_pixels is None or max_pixels is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
video_fps = getattr(processor, "video_fps", 2.0)
|
||||||
|
video_maxlen = getattr(processor, "video_maxlen", 128)
|
||||||
|
image_max_pixels = getattr(processor, "video_max_pixels", 256 * 256)
|
||||||
|
image_min_pixels = getattr(processor, "video_min_pixels", 16 * 16)
|
||||||
|
|
||||||
|
video_grid_thw = []
|
||||||
|
frames_indices = []
|
||||||
|
for video in videos:
|
||||||
|
metadata = self._get_qwen_video_stream_metadata(video, video_fps, video_maxlen)
|
||||||
|
if metadata is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
width, height = self._get_qwen_video_size_after_regularization(
|
||||||
|
metadata["width"], metadata["height"], image_max_pixels, image_min_pixels
|
||||||
|
)
|
||||||
|
num_frames = len(metadata["sample_indices"])
|
||||||
|
if num_frames % 2 != 0:
|
||||||
|
num_frames += 1
|
||||||
|
|
||||||
|
resized_size = self._get_qwen_video_resize(
|
||||||
|
num_frames,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
patch_size,
|
||||||
|
temporal_patch_size,
|
||||||
|
merge_size,
|
||||||
|
min_pixels,
|
||||||
|
max_pixels,
|
||||||
|
)
|
||||||
|
|
||||||
|
resized_height, resized_width = resized_size
|
||||||
|
video_grid_thw.append(
|
||||||
|
[
|
||||||
|
math.ceil(num_frames / temporal_patch_size),
|
||||||
|
resized_height // patch_size,
|
||||||
|
resized_width // patch_size,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
frames_indices.append([idx / metadata["average_fps"] * video_fps for idx in metadata["sample_indices"]])
|
||||||
|
|
||||||
|
return {
|
||||||
|
"video_grid_thw": torch.tensor(video_grid_thw, dtype=torch.long),
|
||||||
|
"frames_indices": frames_indices,
|
||||||
|
"fps": video_fps,
|
||||||
|
}
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _get_video_token_metadata(
|
||||||
|
self,
|
||||||
|
videos: list["VideoInput"],
|
||||||
|
processor: "MMProcessor",
|
||||||
|
) -> Optional[dict[str, Any]]:
|
||||||
|
video_metadata = self._get_qwen_video_grid_metadata(videos, processor)
|
||||||
|
if video_metadata is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {"video_grid_thw": video_metadata["video_grid_thw"]}
|
||||||
|
|
||||||
|
def _get_mm_token_metadata(
|
||||||
|
self,
|
||||||
|
images: list["ImageInput"],
|
||||||
|
videos: list["VideoInput"],
|
||||||
|
audios: list["AudioInput"],
|
||||||
|
processor: "MMProcessor",
|
||||||
|
) -> Optional[dict[str, Any]]:
|
||||||
|
if len(audios) != 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
mm_inputs = {}
|
||||||
|
if len(images) != 0:
|
||||||
|
mm_inputs.update(self._get_mm_inputs(images, [], [], processor))
|
||||||
|
|
||||||
|
if len(videos) != 0:
|
||||||
|
video_inputs = self._get_video_token_metadata(videos, processor)
|
||||||
|
if video_inputs is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
mm_inputs.update(video_inputs)
|
||||||
|
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _get_mm_inputs(
|
def _get_mm_inputs(
|
||||||
self,
|
self,
|
||||||
@@ -1798,7 +2000,10 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
|
|
||||||
merge_length: int = getattr(image_processor, "merge_size") ** 2
|
merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||||
if self.expand_mm_tokens:
|
if self.expand_mm_tokens:
|
||||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
mm_inputs = self._get_mm_token_metadata(images, videos, audios, processor)
|
||||||
|
if mm_inputs is None:
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||||
|
|
||||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||||
else:
|
else:
|
||||||
@@ -1832,6 +2037,51 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Qwen3VLPlugin(Qwen2VLPlugin):
|
class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||||
|
@override
|
||||||
|
def _get_qwen_video_resize(
|
||||||
|
self,
|
||||||
|
num_frames: int,
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
patch_size: int,
|
||||||
|
temporal_patch_size: int,
|
||||||
|
merge_size: int,
|
||||||
|
min_pixels: int,
|
||||||
|
max_pixels: int,
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
from transformers.models.qwen3_vl.video_processing_qwen3_vl import smart_resize
|
||||||
|
|
||||||
|
return smart_resize(
|
||||||
|
num_frames=num_frames,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
temporal_factor=temporal_patch_size,
|
||||||
|
factor=patch_size * merge_size,
|
||||||
|
min_pixels=min_pixels,
|
||||||
|
max_pixels=max_pixels,
|
||||||
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _get_video_token_metadata(
|
||||||
|
self,
|
||||||
|
videos: list["VideoInput"],
|
||||||
|
processor: "MMProcessor",
|
||||||
|
) -> Optional[dict[str, Any]]:
|
||||||
|
video_metadata = self._get_qwen_video_grid_metadata(videos, processor)
|
||||||
|
if video_metadata is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"video_grid_thw": video_metadata["video_grid_thw"],
|
||||||
|
"video_metadata": [
|
||||||
|
SimpleNamespace(
|
||||||
|
frames_indices=frames_indices,
|
||||||
|
fps=video_metadata["fps"],
|
||||||
|
)
|
||||||
|
for frames_indices in video_metadata["frames_indices"]
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _get_mm_inputs(
|
def _get_mm_inputs(
|
||||||
self,
|
self,
|
||||||
@@ -1904,7 +2154,10 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
|||||||
image_merge_length: int = getattr(image_processor, "merge_size") ** 2
|
image_merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||||
video_merge_length: int = getattr(video_processor, "merge_size") ** 2
|
video_merge_length: int = getattr(video_processor, "merge_size") ** 2
|
||||||
if self.expand_mm_tokens:
|
if self.expand_mm_tokens:
|
||||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
mm_inputs = self._get_mm_token_metadata(images, videos, audios, processor)
|
||||||
|
if mm_inputs is None:
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||||
|
|
||||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||||
num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now
|
num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now
|
||||||
|
|||||||
@@ -186,7 +186,6 @@ def _verify_model_args(
|
|||||||
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _check_extra_dependencies(
|
def _check_extra_dependencies(
|
||||||
model_args: "ModelArguments",
|
model_args: "ModelArguments",
|
||||||
finetuning_args: "FinetuningArguments",
|
finetuning_args: "FinetuningArguments",
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from peft import LoraConfig, LoraModel, OFTConfig, PeftModel, TaskType, get_peft
|
|||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.constants import EngineName
|
|
||||||
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
|
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
|
||||||
from .model_utils.quantization import QuantizationMethod
|
from .model_utils.quantization import QuantizationMethod
|
||||||
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
|
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from llamafactory.data.mm_plugin import get_mm_plugin
|
from llamafactory.data.mm_plugin import get_mm_plugin
|
||||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
from llamafactory.extras.packages import is_pyav_available, is_transformers_version_greater_than
|
||||||
from llamafactory.hparams import get_infer_args
|
from llamafactory.hparams import get_infer_args
|
||||||
from llamafactory.model import load_tokenizer
|
from llamafactory.model import load_tokenizer
|
||||||
|
|
||||||
@@ -439,6 +439,40 @@ def test_qwen3_vl_plugin():
|
|||||||
_check_plugin(**check_inputs)
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cpu", "mps"])
|
||||||
|
@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
|
||||||
|
@pytest.mark.skipif(not is_pyav_available(), reason="Requires pyav")
|
||||||
|
def test_qwen3_vl_plugin_video_path():
|
||||||
|
video_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "..", "data", "mllm_demo_data", "1.mp4")
|
||||||
|
video_path = os.path.abspath(video_path)
|
||||||
|
if not os.path.exists(video_path):
|
||||||
|
pytest.skip(f"Video file not found: {video_path}")
|
||||||
|
|
||||||
|
tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen3-VL-30B-A3B-Instruct")
|
||||||
|
processor = tokenizer_module["processor"]
|
||||||
|
qwen3_vl_plugin = get_mm_plugin(name="qwen3_vl", video_token="<|video_pad|>")
|
||||||
|
|
||||||
|
videos = [video_path]
|
||||||
|
|
||||||
|
# fast path: metadata-only, no frame decoding
|
||||||
|
fast_mm_inputs = qwen3_vl_plugin._get_mm_token_metadata([], videos, [], processor)
|
||||||
|
assert fast_mm_inputs is not None, "_get_mm_token_metadata should succeed for a real video file"
|
||||||
|
|
||||||
|
full_mm_inputs = qwen3_vl_plugin._get_mm_inputs([], videos, [], processor)
|
||||||
|
|
||||||
|
# video_grid_thw must be identical between the two paths
|
||||||
|
assert torch.equal(fast_mm_inputs["video_grid_thw"], full_mm_inputs["video_grid_thw"]), (
|
||||||
|
f"video_grid_thw mismatch between fast path and full path: "
|
||||||
|
f"fast={fast_mm_inputs['video_grid_thw']}, full={full_mm_inputs['video_grid_thw']}"
|
||||||
|
)
|
||||||
|
result = qwen3_vl_plugin.process_messages(VIDEO_MESSAGES, [], videos, [], processor)
|
||||||
|
# This demo video duration is 9.72s, with video_fps=2, we extract 19 frames
|
||||||
|
# 19 + 1 => temperoal compress => 10 video_sequence
|
||||||
|
assert result[0]["content"].count("<|vision_start|>") == 10, (
|
||||||
|
f"Expected 10 video tokens, got {result[0]['content'].count('<|vision_start|>')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.runs_on(["cpu", "mps"])
|
@pytest.mark.runs_on(["cpu", "mps"])
|
||||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
|
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
|
||||||
def test_video_llava_plugin():
|
def test_video_llava_plugin():
|
||||||
|
|||||||
Reference in New Issue
Block a user