mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-05 07:38:55 +08:00
[data] Optimize QwenVL video dataset preprocessing (#10404)
Co-authored-by: Kingsley <kingsleydodonow@gmail.com>
This commit is contained in:
@@ -21,7 +21,7 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from llamafactory.data.mm_plugin import get_mm_plugin
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.extras.packages import is_pyav_available, is_transformers_version_greater_than
|
||||
from llamafactory.hparams import get_infer_args
|
||||
from llamafactory.model import load_tokenizer
|
||||
|
||||
@@ -439,6 +439,40 @@ def test_qwen3_vl_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
@pytest.mark.skipif(not is_pyav_available(), reason="Requires pyav")
def test_qwen3_vl_plugin_video_path():
    """Check that the metadata-only fast path agrees with full video decoding for qwen3_vl."""
    # Locate the bundled demo clip relative to this test file; skip if absent.
    demo_video = os.path.abspath(
        os.path.join(os.path.dirname(os.path.dirname(__file__)), "..", "data", "mllm_demo_data", "1.mp4")
    )
    if not os.path.exists(demo_video):
        pytest.skip(f"Video file not found: {demo_video}")

    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen3-VL-30B-A3B-Instruct")
    processor = tokenizer_module["processor"]
    plugin = get_mm_plugin(name="qwen3_vl", video_token="<|video_pad|>")

    videos = [demo_video]

    # Fast path: reads metadata only, no frame decoding.
    metadata_inputs = plugin._get_mm_token_metadata([], videos, [], processor)
    assert metadata_inputs is not None, "_get_mm_token_metadata should succeed for a real video file"

    # Full path: actually decodes frames through the processor.
    decoded_inputs = plugin._get_mm_inputs([], videos, [], processor)

    # Both paths must agree on the video grid layout.
    assert torch.equal(metadata_inputs["video_grid_thw"], decoded_inputs["video_grid_thw"]), (
        f"video_grid_thw mismatch between fast path and full path: "
        f"fast={metadata_inputs['video_grid_thw']}, full={decoded_inputs['video_grid_thw']}"
    )
    result = plugin.process_messages(VIDEO_MESSAGES, [], videos, [], processor)
    # This demo video duration is 9.72s; with video_fps=2 we extract 19 frames.
    # 19 + 1 => temporal compression => 10 entries in the video sequence.
    assert result[0]["content"].count("<|vision_start|>") == 10, (
        f"Expected 10 video tokens, got {result[0]['content'].count('<|vision_start|>')}"
    )
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
|
||||
def test_video_llava_plugin():
|
||||
|
||||
Reference in New Issue
Block a user