[data] Fix Qwen3VL plugin (#9297)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Co-authored-by: kingsley <kingsleydodonow@gmail.com>
Author: Xiaosu Zhu, committed via GitHub on 2025-10-26 16:07:04 +08:00
commit 129e918106 (parent 9c0d033a15)
2 changed files with 89 additions and 34 deletions

File: vllm_infer.py

@@ -16,6 +16,7 @@ import gc
 import json
 from typing import Optional
+import av
 import fire
 from tqdm import tqdm
 from transformers import Seq2SeqTrainingArguments
@@ -33,6 +34,14 @@ if is_vllm_available():
     from vllm.lora.request import LoRARequest
 
 
+def _need_video_kwargs(template):
+    NEEDED_TEMPLATE = ["qwen3_vl", "glm4v"]
+    if any(t in template for t in NEEDED_TEMPLATE):
+        return True
+    return False
+
+
 def vllm_infer(
     model_name_or_path: str,
     adapter_name_or_path: str = None,
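
A quick check of the matching rule above: since any(...) tests substring containment, template variants that embed these names also qualify. The template strings here are illustrative only:

    # Matching is by substring, per NEEDED_TEMPLATE above; names are examples.
    assert _need_video_kwargs("qwen3_vl") is True
    assert _need_video_kwargs("glm4v") is True
    assert _need_video_kwargs("qwen2_vl") is False  # falls through to the default path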
@@ -132,6 +141,7 @@ def vllm_infer(
     # Store all results in these lists
     all_prompts, all_preds, all_labels = [], [], []
+    need_video_kwargs = _need_video_kwargs(template)
 
     # Add batch process to avoid the issue of too many files opened
     for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
@@ -147,6 +157,7 @@ def vllm_infer(
                     )["images"]
                 }
             elif batch["videos"][j] is not None:
+                video_metadata, video_metadata_kwargs = None, None
                 video = batch["videos"][j]
                 multi_modal_data = {
                     "video": template_obj.mm_plugin._regularize_videos(
@@ -157,6 +168,25 @@ def vllm_infer(
                         video_maxlen=video_maxlen,
                     )["videos"]
                 }
+                if need_video_kwargs:
+                    container = av.open(video[0], "r")
+                    video_stream = next(stream for stream in container.streams if stream.type == "video")
+                    sampling_indices = template_obj.mm_plugin._get_video_sample_indices(
+                        video_stream, video_fps, video_maxlen
+                    )
+                    total_frames = video_stream.frames
+                    video_metadata_kwargs = {
+                        "fps": getattr(tokenizer_module["processor"], "video_fps", 24.0),
+                        "do_sample_frames": False,
+                        "total_num_frames": total_frames,
+                    }
+                    video_metadata = dict(
+                        fps=video_fps,
+                        frames_indices=sampling_indices,
+                        total_num_frames=total_frames,
+                        video_backend="opencv",
+                    )
+                    multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
             elif batch["audios"][j] is not None:
                 audio = batch["audios"][j]
                 audio_data = template_obj.mm_plugin._regularize_audios(
@@ -167,7 +197,11 @@ def vllm_infer(
             else:
                 multi_modal_data = None
 
-            vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
+            vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
+            if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
+                vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
+            vllm_inputs.append(vllm_input_data)
+
             prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
             labels.append(
                 tokenizer.decode(
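
For orientation, each request the patched loop appends now has the shape sketched below. build_video_input is a hypothetical helper written only for this note (not part of the diff); the keys and the (frames, metadata) tuple mirror the code above.

    def build_video_input(prompt_token_ids, frames, frame_indices, total_frames,
                          sample_fps, processor_fps=24.0):
        # Frame-level metadata travels together with the frames as a tuple.
        video_metadata = {
            "fps": sample_fps,
            "frames_indices": frame_indices,
            "total_num_frames": total_frames,
            "video_backend": "opencv",
        }
        return {
            "prompt_token_ids": prompt_token_ids,
            "multi_modal_data": {"video": (frames, video_metadata)},
            # Processor-level kwargs tell the processor the frames are
            # pre-sampled, so it must not sample again.
            "mm_processor_kwargs": {
                "fps": processor_fps,
                "do_sample_frames": False,
                "total_num_frames": total_frames,
            },
        }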

File: mm_plugin.py

@@ -31,7 +31,7 @@ from transformers.models.mllama.processing_mllama import (
     convert_sparse_cross_attention_mask_to_dense,
     get_cross_attention_token_mask,
 )
-from typing_extensions import override
+from typing_extensions import NotRequired, override
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.packages import (
@@ -77,6 +77,18 @@ if TYPE_CHECKING:
     VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
     AudioInput = Union[str, BinaryIO, NDArray]
 
+    class RegularizedImageOutput(TypedDict):
+        images: list[ImageObject]
+
+    class RegularizedVideoOutput(TypedDict):
+        videos: list[list[ImageObject]]
+        durations: list[float]
+        fps_per_video: NotRequired[list[float]]
+
+    class RegularizedAudioOutput(TypedDict):
+        audios: list[NDArray]
+        sampling_rates: list[float]
+
     class MMProcessor(ProcessorMixin):
         patch_size: int
         image_seq_length: int
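
The NotRequired import earns its keep here: the base mixin's _regularize_videos never emits fps_per_video, while the Qwen2-VL override does, and both can still share RegularizedVideoOutput. A standalone illustration with a stub type (not the real class):

    from typing_extensions import NotRequired, TypedDict

    class VideoOut(TypedDict):  # stand-in for RegularizedVideoOutput
        videos: list
        durations: list[float]
        fps_per_video: NotRequired[list[float]]

    base_out: VideoOut = {"videos": [[]], "durations": [3.0]}  # mixin: key omitted
    qwen_out: VideoOut = {"videos": [[]], "durations": [3.0], "fps_per_video": [2.0]}  # override: key present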
@@ -244,7 +256,7 @@ class MMPluginMixin:
         sample_frames = min(total_frames, video_maxlen, sample_frames)
         return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
 
-    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
+    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
         r"""Regularize images to avoid error. Including reading and pre-processing."""
         results = []
         for image in images:
@@ -265,9 +277,10 @@ class MMPluginMixin:
 
         return {"images": results}
 
-    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
+    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
         r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
         results = []
+        durations = []
         for video in videos:
             frames: list[ImageObject] = []
             if _check_video_is_nested_images(video):
@@ -275,6 +288,7 @@ class MMPluginMixin:
                     if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
                         raise ValueError("Invalid image found in video frames.")
                 frames = video
+                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
             else:
                 container = av.open(video, "r")
                 video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -284,14 +298,19 @@ class MMPluginMixin:
                     if frame_idx in sample_indices:
                         frames.append(frame.to_image())
 
+                if video_stream.duration is None:
+                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
+                else:
+                    durations.append(float(video_stream.duration * video_stream.time_base))
+
             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
 
-        return {"videos": results}
+        return {"videos": results, "durations": durations}
 
     def _regularize_audios(
         self, audios: list["AudioInput"], sampling_rate: float, **kwargs
-    ) -> dict[str, Union[list["NDArray"], list[float]]]:
+    ) -> "RegularizedAudioOutput":
         r"""Regularizes audios to avoid error. Including reading and resampling."""
         results, sampling_rates = [], []
         for audio in audios:
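
The duration logic added above boils down to a small probe: trust the stream's declared duration (stored in time_base units) when it exists, otherwise approximate from the number of sampled frames and the sampling fps. A self-contained PyAV sketch (probe_duration is hypothetical, not part of the diff):

    import av

    def probe_duration(path: str, num_sampled_frames: int, sample_fps: float = 2.0) -> float:
        with av.open(path, "r") as container:
            video_stream = next(s for s in container.streams if s.type == "video")
            if video_stream.duration is None:  # some containers omit it
                return num_sampled_frames / sample_fps
            return float(video_stream.duration * video_stream.time_base)  # exact seconds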
@@ -1418,10 +1437,8 @@ class Qwen2VLPlugin(BasePlugin):
         return image
 
     @override
-    def _regularize_videos(
-        self, videos: list["VideoInput"], **kwargs
-    ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
-        results, fps_per_video = [], []
+    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
+        results, fps_per_video, durations = [], [], []
         for video in videos:
             frames: list[ImageObject] = []
             if _check_video_is_nested_images(video):
@@ -1431,6 +1448,7 @@ class Qwen2VLPlugin(BasePlugin):
                 frames = video
                 fps_per_video.append(kwargs.get("video_fps", 2.0))
+                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
             else:
                 container = av.open(video, "r")
                 video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -1442,8 +1460,10 @@ class Qwen2VLPlugin(BasePlugin):
                 if video_stream.duration is None:
                     fps_per_video.append(kwargs.get("video_fps", 2.0))
+                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
                 else:
                     fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
+                    durations.append(float(video_stream.duration * video_stream.time_base))
 
             if len(frames) % 2 != 0:
                 frames.append(frames[-1])
@@ -1451,7 +1471,7 @@ class Qwen2VLPlugin(BasePlugin):
             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
 
-        return {"videos": results, "fps_per_video": fps_per_video}
+        return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
 
     @override
     def _get_mm_inputs(
@@ -1565,8 +1585,8 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             video_metadata = [
-                {"fps": getattr(processor, "video_fps", 24.0), "duration": len(video), "total_num_frames": len(video)}
-                for video in videos["videos"]
+                {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
+                for video, duration in zip(videos["videos"], videos["durations"])
             ]
             mm_inputs.update(
                 video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
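
Illustrative numbers show why the zip matters: a clip sampled into 48 frames over 24 seconds should report a duration of 24.0, but the old code reported len(video) == 48.

    frames = list(range(48))  # stand-ins for 48 sampled frames
    duration = 24.0           # true clip length in seconds, now supplied by _regularize_videos

    old = {"fps": 24.0, "duration": len(frames), "total_num_frames": 48}  # duration wrongly 48
    new = {"fps": 24.0, "duration": duration, "total_num_frames": 48}     # duration correctly 24.0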
@@ -1622,27 +1642,27 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
-                metadata = video_metadata[idx]
-                timestamps = processor._calculate_timestamps(
-                    metadata.frames_indices,
-                    metadata.fps,
-                    video_processor.merge_size,
-                )
-                video_structure = ""
-                for frame_index in range(num_frames):
-                    video_seqlen = (
-                        video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
-                        if self.expand_mm_tokens
-                        else 1
-                    )
-                    timestamp_sec = timestamps[frame_index]
-                    frame_structure = (
-                        f"<{timestamp_sec:.1f} seconds>"
-                        f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
-                    )
-                    video_structure += frame_structure
-
-                if not self.expand_mm_tokens:
+                if self.expand_mm_tokens:
+                    metadata = video_metadata[idx]
+                    timestamps = processor._calculate_timestamps(
+                        metadata.frames_indices,
+                        metadata.fps,
+                        video_processor.merge_size,
+                    )
+                    video_structure = ""
+                    for frame_index in range(num_frames):
+                        video_seqlen = (
+                            video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
+                            if self.expand_mm_tokens
+                            else 1
+                        )
+                        timestamp_sec = timestamps[frame_index]
+                        frame_structure = (
+                            f"<{timestamp_sec:.1f} seconds>"
+                            f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
+                        )
+                        video_structure += frame_structure
+                else:
                     video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
 
                 content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
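
With expand_mm_tokens enabled, the placeholder expands into one timestamped segment per frame; disabled, it collapses to a single bos/token/eos triple and the metadata lookup is skipped entirely, which is the point of the restructuring above. A sketch of the expanded string with assumed token literals and sizes (the real values come from the tokenizer and video_grid_thw):

    vision_bos, vision_eos, video_token = "<|vision_start|>", "<|vision_end|>", "<|video_pad|>"
    timestamps, video_seqlen = [0.0, 0.5], 4  # assumed: two frames, four video tokens each

    video_structure = "".join(
        f"<{t:.1f} seconds>{vision_bos}{video_token * video_seqlen}{vision_eos}" for t in timestamps
    )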
@@ -1684,7 +1704,8 @@ class GLM4VPlugin(Qwen2VLPlugin):
             )
             # prepare video metadata
             video_metadata = [
-                {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
+                {"fps": 2, "duration": duration, "total_frames": len(video)}
+                for video, duration in zip(video_data["videos"], video_data["durations"])
             ]
             mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))