From 129e918106f96836f5b1795a322647aa644f9408 Mon Sep 17 00:00:00 2001
From: Xiaosu Zhu
Date: Sun, 26 Oct 2025 16:07:04 +0800
Subject: [PATCH] [data] Fix Qwen3VL plugin (#9297)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng
Co-authored-by: kingsley
---
 scripts/vllm_infer.py              | 36 ++++++++++++-
 src/llamafactory/data/mm_plugin.py | 87 ++++++++++++++++++------------
 2 files changed, 89 insertions(+), 34 deletions(-)

diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py
index 1a080ad5..4d6f0586 100644
--- a/scripts/vllm_infer.py
+++ b/scripts/vllm_infer.py
@@ -16,6 +16,7 @@ import gc
 import json
 from typing import Optional
 
+import av
 import fire
 from tqdm import tqdm
 from transformers import Seq2SeqTrainingArguments
@@ -33,6 +34,14 @@ if is_vllm_available():
     from vllm.lora.request import LoRARequest
 
 
+def _need_video_kwargs(template):
+    NEEDED_TEMPLATE = ["qwen3_vl", "glm4v"]
+    if any(t in template for t in NEEDED_TEMPLATE):
+        return True
+
+    return False
+
+
 def vllm_infer(
     model_name_or_path: str,
     adapter_name_or_path: str = None,
@@ -132,6 +141,7 @@ def vllm_infer(
 
     # Store all results in these lists
     all_prompts, all_preds, all_labels = [], [], []
+    need_video_kwargs = _need_video_kwargs(template)
 
     # Add batch process to avoid the issue of too many files opened
    for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
@@ -147,6 +157,7 @@
                     )["images"]
                 }
             elif batch["videos"][j] is not None:
+                video_metadata, video_metadata_kwargs = None, None
                 video = batch["videos"][j]
                 multi_modal_data = {
                     "video": template_obj.mm_plugin._regularize_videos(
@@ -157,6 +168,25 @@
                         video_maxlen=video_maxlen,
                     )["videos"]
                 }
+                if need_video_kwargs:
+                    container = av.open(video[0], "r")
+                    video_stream = next(stream for stream in container.streams if stream.type == "video")
+                    sampling_indices = template_obj.mm_plugin._get_video_sample_indices(
+                        video_stream, video_fps, video_maxlen
+                    )
+                    total_frames = video_stream.frames
+                    video_metadata_kwargs = {
+                        "fps": getattr(tokenizer_module["processor"], "video_fps", 24.0),
+                        "do_sample_frames": False,
+                        "total_num_frames": total_frames,
+                    }
+                    video_metadata = dict(
+                        fps=video_fps,
+                        frames_indices=sampling_indices,
+                        total_num_frames=total_frames,
+                        video_backend="opencv",
+                    )
+                    multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
             elif batch["audios"][j] is not None:
                 audio = batch["audios"][j]
                 audio_data = template_obj.mm_plugin._regularize_audios(
@@ -167,7 +197,11 @@
             else:
                 multi_modal_data = None
 
-            vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
+            vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
+            if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
+                vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
+
+            vllm_inputs.append(vllm_input_data)
             prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
             labels.append(
                 tokenizer.decode(
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 6916d962..242811d3 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -31,7 +31,7 @@ from transformers.models.mllama.processing_mllama import (
     convert_sparse_cross_attention_mask_to_dense,
     get_cross_attention_token_mask,
 )
-from typing_extensions import override
+from typing_extensions import NotRequired, override
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.packages import (
@@ -77,6 +77,18 @@ if TYPE_CHECKING:
     VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
     AudioInput = Union[str, BinaryIO, NDArray]
 
+    class RegularizedImageOutput(TypedDict):
+        images: list[ImageObject]
+
+    class RegularizedVideoOutput(TypedDict):
+        videos: list[list[ImageObject]]
+        durations: list[float]
+        fps_per_video: NotRequired[list[float]]
+
+    class RegularizedAudioOutput(TypedDict):
+        audios: list[NDArray]
+        sampling_rates: list[float]
+
     class MMProcessor(ProcessorMixin):
         patch_size: int
         image_seq_length: int
@@ -244,7 +256,7 @@ class MMPluginMixin:
         sample_frames = min(total_frames, video_maxlen, sample_frames)
         return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
 
-    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
+    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
         r"""Regularize images to avoid error. Including reading and pre-processing."""
         results = []
         for image in images:
@@ -265,9 +277,10 @@
 
         return {"images": results}
 
-    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
+    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
         r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
         results = []
+        durations = []
         for video in videos:
             frames: list[ImageObject] = []
             if _check_video_is_nested_images(video):
@@ -275,6 +288,7 @@
                     if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
                         raise ValueError("Invalid image found in video frames.")
                 frames = video
+                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
             else:
                 container = av.open(video, "r")
                 video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -284,14 +298,19 @@
                     if frame_idx in sample_indices:
                         frames.append(frame.to_image())
 
+                if video_stream.duration is None:
+                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
+                else:
+                    durations.append(float(video_stream.duration * video_stream.time_base))
+
             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
 
-        return {"videos": results}
+        return {"videos": results, "durations": durations}
 
     def _regularize_audios(
         self, audios: list["AudioInput"], sampling_rate: float, **kwargs
-    ) -> dict[str, Union[list["NDArray"], list[float]]]:
+    ) -> "RegularizedAudioOutput":
         r"""Regularizes audios to avoid error. Including reading and resampling."""
         results, sampling_rates = [], []
         for audio in audios:
@@ -1418,10 +1437,8 @@ class Qwen2VLPlugin(BasePlugin):
         return image
 
     @override
-    def _regularize_videos(
-        self, videos: list["VideoInput"], **kwargs
-    ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
-        results, fps_per_video = [], []
+    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
+        results, fps_per_video, durations = [], [], []
         for video in videos:
             frames: list[ImageObject] = []
             if _check_video_is_nested_images(video):
@@ -1431,6 +1448,7 @@
 
                 frames = video
                 fps_per_video.append(kwargs.get("video_fps", 2.0))
+                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
             else:
                 container = av.open(video, "r")
                 video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -1442,8 +1460,10 @@
 
                 if video_stream.duration is None:
                     fps_per_video.append(kwargs.get("video_fps", 2.0))
+                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
                 else:
                     fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
+                    durations.append(float(video_stream.duration * video_stream.time_base))
 
             if len(frames) % 2 != 0:
                 frames.append(frames[-1])
@@ -1451,7 +1471,7 @@
             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
 
-        return {"videos": results, "fps_per_video": fps_per_video}
+        return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
 
     @override
     def _get_mm_inputs(
@@ -1565,8 +1585,8 @@
             video_maxlen=getattr(processor, "video_maxlen", 128),
         )
         video_metadata = [
-            {"fps": getattr(processor, "video_fps", 24.0), "duration": len(video), "total_num_frames": len(video)}
-            for video in videos["videos"]
+            {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
+            for video, duration in zip(videos["videos"], videos["durations"])
         ]
         mm_inputs.update(
             video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
@@ -1622,27 +1642,27 @@
                     num_image_tokens += 1
 
                 while VIDEO_PLACEHOLDER in content:
-                    metadata = video_metadata[idx]
-                    timestamps = processor._calculate_timestamps(
-                        metadata.frames_indices,
-                        metadata.fps,
-                        video_processor.merge_size,
-                    )
-                    video_structure = ""
-                    for frame_index in range(num_frames):
-                        video_seqlen = (
-                            video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
-                            if self.expand_mm_tokens
-                            else 1
+                    if self.expand_mm_tokens:
+                        metadata = video_metadata[idx]
+                        timestamps = processor._calculate_timestamps(
+                            metadata.frames_indices,
+                            metadata.fps,
+                            video_processor.merge_size,
                         )
-                        timestamp_sec = timestamps[frame_index]
-                        frame_structure = (
-                            f"<{timestamp_sec:.1f} seconds>"
-                            f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
-                        )
-                        video_structure += frame_structure
-
-                    if not self.expand_mm_tokens:
+                        video_structure = ""
+                        for frame_index in range(num_frames):
+                            video_seqlen = (
+                                video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
+                                if self.expand_mm_tokens
+                                else 1
+                            )
+                            timestamp_sec = timestamps[frame_index]
+                            frame_structure = (
+                                f"<{timestamp_sec:.1f} seconds>"
+                                f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
+                            )
+                            video_structure += frame_structure
+                    else:
+                        video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
 
                     content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
@@ -1684,7 +1704,8 @@ class GLM4VPlugin(Qwen2VLPlugin):
         )
         # prepare video metadata
         video_metadata = [
-            {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
+            {"fps": 2, "duration": duration, "total_frames": len(video)}
+            for video, duration in zip(video_data["videos"], video_data["durations"])
        ]
         mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))