[data] Fix Qwen3VL plugin (#9297)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Co-authored-by: kingsley <kingsleydodonow@gmail.com>
Xiaosu Zhu authored on 2025-10-26 16:07:04 +08:00; committed by GitHub
parent 9c0d033a15
commit 129e918106
2 changed files with 89 additions and 34 deletions

--- a/scripts/vllm_infer.py
+++ b/scripts/vllm_infer.py

@@ -16,6 +16,7 @@ import gc
 import json
 from typing import Optional
 
+import av
 import fire
 from tqdm import tqdm
 from transformers import Seq2SeqTrainingArguments
@@ -33,6 +34,14 @@ if is_vllm_available():
     from vllm.lora.request import LoRARequest
 
 
+def _need_video_kwargs(template):
+    NEEDED_TEMPLATE = ["qwen3_vl", "glm4v"]
+    if any(t in template for t in NEEDED_TEMPLATE):
+        return True
+
+    return False
+
+
 def vllm_infer(
     model_name_or_path: str,
     adapter_name_or_path: str = None,
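The helper above gates the extra vLLM video kwargs on the template name. Note that it is a substring match, so variant template names also qualify. A minimal standalone sketch of the same check:

def _need_video_kwargs(template: str) -> bool:
    # Substring match: any template whose name contains one of these
    # markers gets the extra video metadata kwargs.
    NEEDED_TEMPLATE = ["qwen3_vl", "glm4v"]
    return any(t in template for t in NEEDED_TEMPLATE)

assert _need_video_kwargs("qwen3_vl") is True
assert _need_video_kwargs("glm4v_moe") is True   # variants match via substring
assert _need_video_kwargs("qwen2_vl") is False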
@@ -132,6 +141,7 @@ def vllm_infer(
 
     # Store all results in these lists
     all_prompts, all_preds, all_labels = [], [], []
+    need_video_kwargs = _need_video_kwargs(template)
 
     # Add batch process to avoid the issue of too many files opened
     for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
@@ -147,6 +157,7 @@
                     )["images"]
                 }
             elif batch["videos"][j] is not None:
+                video_metadata, video_metadata_kwargs = None, None
                 video = batch["videos"][j]
                 multi_modal_data = {
                     "video": template_obj.mm_plugin._regularize_videos(
@@ -157,6 +168,25 @@
                         video_maxlen=video_maxlen,
                     )["videos"]
                 }
+                if need_video_kwargs:
+                    container = av.open(video[0], "r")
+                    video_stream = next(stream for stream in container.streams if stream.type == "video")
+                    sampling_indices = template_obj.mm_plugin._get_video_sample_indices(
+                        video_stream, video_fps, video_maxlen
+                    )
+                    total_frames = video_stream.frames
+                    video_metadata_kwargs = {
+                        "fps": getattr(tokenizer_module["processor"], "video_fps", 24.0),
+                        "do_sample_frames": False,
+                        "total_num_frames": total_frames,
+                    }
+                    video_metadata = dict(
+                        fps=video_fps,
+                        frames_indices=sampling_indices,
+                        total_num_frames=total_frames,
+                        video_backend="opencv",
+                    )
+                    multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
             elif batch["audios"][j] is not None:
                 audio = batch["audios"][j]
                 audio_data = template_obj.mm_plugin._regularize_audios(
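The added block re-opens the first video with PyAV to recover the stream's frame count and the exact indices the plugin sampled, then forwards both to vLLM with do_sample_frames=False so the HF processor will not re-sample. A self-contained sketch of that probe, assuming a local "sample.mp4" whose stream reports duration and frame metadata, and approximating _get_video_sample_indices with the linspace logic shown later in mm_plugin.py:

import av
import numpy as np

container = av.open("sample.mp4", "r")
video_stream = next(s for s in container.streams if s.type == "video")

total_frames = video_stream.frames                                    # container-reported frame count
duration_sec = float(video_stream.duration * video_stream.time_base)  # assumes duration is not None
video_fps, video_maxlen = 2.0, 128                                    # mirrors the script defaults
sample_frames = min(total_frames, video_maxlen, int(duration_sec * video_fps))
sampling_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)

# Passed alongside the raw frames; do_sample_frames=False tells the processor
# to trust frames_indices/total_num_frames instead of re-sampling.
video_metadata = {
    "fps": video_fps,
    "frames_indices": sampling_indices,
    "total_num_frames": total_frames,
    "video_backend": "opencv",
}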
@@ -167,7 +197,11 @@
             else:
                 multi_modal_data = None
 
-            vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
+            vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
+            if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
+                vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
+
+            vllm_inputs.append(vllm_input_data)
             prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
             labels.append(
                 tokenizer.decode(
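Downstream, each request dict goes straight into vLLM's generate call, and mm_processor_kwargs rides along per request (supported in recent vLLM releases) so the Qwen3-VL processor sees do_sample_frames=False. A hedged sketch of that consumption, with the model name as a placeholder:

from vllm import LLM, SamplingParams

# Placeholder checkpoint; any Qwen3-VL model served by vLLM would do.
llm = LLM(model="Qwen/Qwen3-VL-8B-Instruct", limit_mm_per_prompt={"video": 1})

# vllm_inputs is the list of dicts assembled in the loop above.
results = llm.generate(vllm_inputs, SamplingParams(temperature=0.7, max_tokens=128))
preds = [result.outputs[0].text for result in results]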

--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py

@@ -31,7 +31,7 @@ from transformers.models.mllama.processing_mllama import (
     convert_sparse_cross_attention_mask_to_dense,
     get_cross_attention_token_mask,
 )
-from typing_extensions import override
+from typing_extensions import NotRequired, override
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.packages import (
@@ -77,6 +77,18 @@ if TYPE_CHECKING:
     VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
     AudioInput = Union[str, BinaryIO, NDArray]
 
+    class RegularizedImageOutput(TypedDict):
+        images: list[ImageObject]
+
+    class RegularizedVideoOutput(TypedDict):
+        videos: list[list[ImageObject]]
+        durations: list[float]
+        fps_per_video: NotRequired[list[float]]
+
+    class RegularizedAudioOutput(TypedDict):
+        audios: list[NDArray]
+        sampling_rates: list[float]
+
     class MMProcessor(ProcessorMixin):
         patch_size: int
         image_seq_length: int
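The new TypedDicts document what each _regularize_* helper returns; NotRequired marks fps_per_video as optional because only the Qwen2-VL family adds it. A quick illustration of those semantics outside the TYPE_CHECKING block:

from typing_extensions import NotRequired, TypedDict

class RegularizedVideoOutput(TypedDict):
    videos: list[list]  # frame lists (PIL images in the real code)
    durations: list[float]
    fps_per_video: NotRequired[list[float]]  # present only for Qwen2-VL-style plugins

# Both dicts type-check: the NotRequired key may be omitted or supplied.
base_output: RegularizedVideoOutput = {"videos": [[]], "durations": [4.0]}
qwen_output: RegularizedVideoOutput = {"videos": [[]], "durations": [4.0], "fps_per_video": [2.0]}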
@@ -244,7 +256,7 @@ class MMPluginMixin:
         sample_frames = min(total_frames, video_maxlen, sample_frames)
         return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
 
-    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
+    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
         r"""Regularize images to avoid error. Including reading and pre-processing."""
         results = []
         for image in images:
@@ -265,9 +277,10 @@
 
         return {"images": results}
 
-    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
+    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
         r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
         results = []
+        durations = []
         for video in videos:
             frames: list[ImageObject] = []
             if _check_video_is_nested_images(video):
@@ -275,6 +288,7 @@
                     if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
                         raise ValueError("Invalid image found in video frames.")
                 frames = video
+                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
             else:
                 container = av.open(video, "r")
                 video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -284,14 +298,19 @@
                     if frame_idx in sample_indices:
                         frames.append(frame.to_image())
 
+                if video_stream.duration is None:
+                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
+                else:
+                    durations.append(float(video_stream.duration * video_stream.time_base))
+
             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
 
-        return {"videos": results}
+        return {"videos": results, "durations": durations}
 
     def _regularize_audios(
         self, audios: list["AudioInput"], sampling_rate: float, **kwargs
-    ) -> dict[str, Union[list["NDArray"], list[float]]]:
+    ) -> "RegularizedAudioOutput":
         r"""Regularizes audios to avoid error. Including reading and resampling."""
         results, sampling_rates = [], []
         for audio in audios:
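The key change: duration is now taken from the container when available. PyAV reports stream.duration in units of stream.time_base, so seconds = duration * time_base; the frame-count fallback only holds when frames were sampled at video_fps. A sketch with illustrative numbers:

from fractions import Fraction

# Container-reported path: duration is expressed in time_base units.
stream_duration = 120_000           # illustrative PyAV value for a 4-second clip
time_base = Fraction(1, 30_000)     # a common mp4 stream time base
assert float(stream_duration * time_base) == 4.0

# Fallback path from the diff: estimate seconds from sampled frames and fps.
num_sampled_frames, video_fps = 8, 2.0
assert num_sampled_frames / video_fps == 4.0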
@@ -1418,10 +1437,8 @@ class Qwen2VLPlugin(BasePlugin):
         return image
 
     @override
-    def _regularize_videos(
-        self, videos: list["VideoInput"], **kwargs
-    ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
-        results, fps_per_video = [], []
+    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
+        results, fps_per_video, durations = [], [], []
         for video in videos:
             frames: list[ImageObject] = []
             if _check_video_is_nested_images(video):
@@ -1431,6 +1448,7 @@ class Qwen2VLPlugin(BasePlugin):
 
                 frames = video
                 fps_per_video.append(kwargs.get("video_fps", 2.0))
+                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
             else:
                 container = av.open(video, "r")
                 video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -1442,8 +1460,10 @@
 
                 if video_stream.duration is None:
                     fps_per_video.append(kwargs.get("video_fps", 2.0))
+                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
                 else:
                     fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
+                    durations.append(float(video_stream.duration * video_stream.time_base))
 
             if len(frames) % 2 != 0:
                 frames.append(frames[-1])
@@ -1451,7 +1471,7 @@
             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
 
-        return {"videos": results, "fps_per_video": fps_per_video}
+        return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
 
     @override
     def _get_mm_inputs(
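Qwen2-VL's override tracks both the effective sampling rate (fps_per_video = sampled frames / real duration) and now the duration itself; the trailing parity check duplicates the last frame so the frame count stays even, presumably because the Qwen2-VL video processor stacks frames in temporal pairs (an assumption here, based on its temporal patch size of 2). A small numeric sketch:

# Effective fps after sampling: 9 frames kept from a 4.5-second stream.
sample_indices = list(range(9))
duration_seconds = 4.5
fps_per_video = len(sample_indices) / duration_seconds   # -> 2.0

# Even-frame padding: duplicate the last frame so temporal pairs line up.
frames = [f"frame_{i}" for i in sample_indices]
if len(frames) % 2 != 0:
    frames.append(frames[-1])
assert len(frames) == 10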
@@ -1565,8 +1585,8 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             video_metadata = [
-                {"fps": getattr(processor, "video_fps", 24.0), "duration": len(video), "total_num_frames": len(video)}
-                for video in videos["videos"]
+                {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
+                for video, duration in zip(videos["videos"], videos["durations"])
             ]
             mm_inputs.update(
                 video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
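This is the heart of the fix: the old metadata set duration to the frame count, which is only correct at exactly 1 fps, so Qwen3-VL's timestamp computation drifted for any other rate. With durations threaded through from _regularize_videos, the numbers line up. For example:

# A 40-second clip sampled down to 80 frames.
video, duration = ["<frame>"] * 80, 40.0

old_metadata = {"fps": 24.0, "duration": len(video), "total_num_frames": len(video)}  # duration wrongly 80
new_metadata = {"fps": 24.0, "duration": duration, "total_num_frames": len(video)}    # duration correctly 40.0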
@@ -1622,27 +1642,27 @@
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
-                metadata = video_metadata[idx]
-                timestamps = processor._calculate_timestamps(
-                    metadata.frames_indices,
-                    metadata.fps,
-                    video_processor.merge_size,
-                )
-                video_structure = ""
-                for frame_index in range(num_frames):
-                    video_seqlen = (
-                        video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
-                        if self.expand_mm_tokens
-                        else 1
-                    )
-                    timestamp_sec = timestamps[frame_index]
-                    frame_structure = (
-                        f"<{timestamp_sec:.1f} seconds>"
-                        f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
-                    )
-                    video_structure += frame_structure
-                if not self.expand_mm_tokens:
+                if self.expand_mm_tokens:
+                    metadata = video_metadata[idx]
+                    timestamps = processor._calculate_timestamps(
+                        metadata.frames_indices,
+                        metadata.fps,
+                        video_processor.merge_size,
+                    )
+                    video_structure = ""
+                    for frame_index in range(num_frames):
+                        video_seqlen = (
+                            video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
+                            if self.expand_mm_tokens
+                            else 1
+                        )
+                        timestamp_sec = timestamps[frame_index]
+                        frame_structure = (
+                            f"<{timestamp_sec:.1f} seconds>"
+                            f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
+                        )
+                        video_structure += frame_structure
+                else:
                     video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
 
                 content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
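When expand_mm_tokens is set, each frame contributes a timestamp tag followed by a vision span, and the whole structure replaces one VIDEO_PLACEHOLDER; the restructured branch now skips the metadata and timestamp work entirely when tokens are not expanded. A sketch of the rendered string, using Qwen2-VL-style token names as placeholders:

vision_bos, vision_eos, video_token = "<|vision_start|>", "<|vision_end|>", "<|video_pad|>"
timestamps, video_seqlen = [0.5, 1.5], 4   # e.g. produced by _calculate_timestamps

video_structure = ""
for timestamp_sec in timestamps:
    video_structure += f"<{timestamp_sec:.1f} seconds>{vision_bos}{video_token * video_seqlen}{vision_eos}"

# -> "<0.5 seconds><|vision_start|>" + 4 video tokens + "<|vision_end|>" +
#    "<1.5 seconds><|vision_start|>" + 4 video tokens + "<|vision_end|>"
print(video_structure)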
@@ -1684,7 +1704,8 @@ class GLM4VPlugin(Qwen2VLPlugin):
             )
             # prepare video metadata
             video_metadata = [
-                {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
+                {"fps": 2, "duration": duration, "total_frames": len(video)}
+                for video, duration in zip(video_data["videos"], video_data["durations"])
             ]
 
             mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))
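GLM-4V gets the same duration fix; the only wrinkle is the key name, total_frames here versus total_num_frames for Qwen3-VL. A final check with illustrative values:

video_data = {"videos": [["<frame>"] * 16], "durations": [8.0]}

video_metadata = [
    {"fps": 2, "duration": duration, "total_frames": len(video)}
    for video, duration in zip(video_data["videos"], video_data["durations"])
]
assert video_metadata == [{"fps": 2, "duration": 8.0, "total_frames": 16}]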