[data] Optimize QwenVL video dataset preprocessing (#10404)

Co-authored-by: Kingsley <kingsleydodonow@gmail.com>
This commit is contained in:
luca-888
2026-05-03 18:36:56 +08:00
committed by GitHub
parent 468723c5d9
commit 8752280dd7
4 changed files with 291 additions and 6 deletions

View File

@@ -21,7 +21,7 @@ import torch
from PIL import Image
from llamafactory.data.mm_plugin import get_mm_plugin
from llamafactory.extras.packages import is_transformers_version_greater_than
from llamafactory.extras.packages import is_pyav_available, is_transformers_version_greater_than
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer
@@ -439,6 +439,40 @@ def test_qwen3_vl_plugin():
_check_plugin(**check_inputs)
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
@pytest.mark.skipif(not is_pyav_available(), reason="Requires pyav")
def test_qwen3_vl_plugin_video_path():
    """Check that the metadata-only fast path agrees with full frame decoding for a real video file."""
    tests_dir = os.path.dirname(os.path.dirname(__file__))
    video_path = os.path.abspath(os.path.join(tests_dir, "..", "data", "mllm_demo_data", "1.mp4"))
    if not os.path.exists(video_path):
        pytest.skip(f"Video file not found: {video_path}")

    processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen3-VL-30B-A3B-Instruct")["processor"]
    plugin = get_mm_plugin(name="qwen3_vl", video_token="<|video_pad|>")
    video_inputs = [video_path]

    # Fast path: reads container metadata only, no frames are decoded.
    fast_mm_inputs = plugin._get_mm_token_metadata([], video_inputs, [], processor)
    assert fast_mm_inputs is not None, "_get_mm_token_metadata should succeed for a real video file"

    # Full path decodes frames; both paths must produce the same video grid layout.
    full_mm_inputs = plugin._get_mm_inputs([], video_inputs, [], processor)
    assert torch.equal(fast_mm_inputs["video_grid_thw"], full_mm_inputs["video_grid_thw"]), (
        f"video_grid_thw mismatch between fast path and full path: "
        f"fast={fast_mm_inputs['video_grid_thw']}, full={full_mm_inputs['video_grid_thw']}"
    )

    result = plugin.process_messages(VIDEO_MESSAGES, [], video_inputs, [], processor)
    # The demo clip runs 9.72 s; at video_fps=2 that extracts 19 frames,
    # and 19 + 1 frames collapse under temporal compression into 10 video sequences.
    assert result[0]["content"].count("<|vision_start|>") == 10, (
        f"Expected 10 video tokens, got {result[0]['content'].count('<|vision_start|>')}"
    )
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
def test_video_llava_plugin():