[model] add qwen3-vl/qwen3-omni (#9196)

Co-authored-by: kingsley <kingsleydodonow@gmail.com>
Author: xvxuopop
Date: 2025-09-27 01:21:47 +08:00
Committed by: GitHub
Parent: abc3b1e1c4
Commit: 0761a4448f
5 changed files with 268 additions and 2 deletions

@@ -134,7 +134,7 @@ def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list
 def _check_video_is_nested_images(video: "VideoInput") -> bool:
     r"""Check if the video is nested images."""
-    return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video)
+    return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict, ImageObject)) for frame in video)


 @dataclass
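With the widened `isinstance` check, a video supplied as a nested list of frames may now contain already-decoded PIL images (`ImageObject` is the PIL image class imported under that alias), not only paths, file-like objects, or dicts. A minimal sketch of the two input shapes; the file names and sizes are illustrative:

```python
from PIL import Image

# Frames given as file paths have always passed the nested-images check.
video_as_paths = ["frames/frame_0.jpg", "frames/frame_1.jpg"]

# Frames given as in-memory PIL images only pass after this change.
video_as_pil = [Image.new("RGB", (224, 224)) for _ in range(2)]
```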
@@ -1531,6 +1531,119 @@ class Qwen2VLPlugin(BasePlugin):
        return messages


@dataclass
class Qwen3VLPlugin(Qwen2VLPlugin):
    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        video_processor: BaseImageProcessor = getattr(processor, "video_processor", None)
        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]
            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            # Use the same fallback fps as the sampling step above so the metadata
            # agrees with how the frames were actually extracted.
            video_metadata = [
                {"fps": getattr(processor, "video_fps", 2.0), "duration": len(video), "total_num_frames": len(video)}
                for video in videos["videos"]
            ]
            mm_inputs.update(
                video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
            )
            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
            if "second_per_grid_ts" in processor.model_input_names:
                mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in videos["fps_per_video"]]

        return mm_inputs
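The `second_per_grid_ts` entry records how much wall-clock time one temporal grid step covers: the processor merges `temporal_patch_size` sampled frames into one step, so the value is simply `temporal_patch_size / fps` per video. A quick check of that arithmetic with the defaults used above:

```python
temporal_patch_size = 2      # sampled frames merged into one temporal grid step
fps_per_video = [2.0, 4.0]   # per-video sampling rates from _regularize_videos

second_per_grid_ts = [temporal_patch_size / fps for fps in fps_per_video]
print(second_per_grid_ts)    # [1.0, 0.5] -> seconds covered by each grid step
```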
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        video_processor: BaseImageProcessor = getattr(processor, "video_processor")
        image_merge_length: int = getattr(image_processor, "merge_size") ** 2
        video_merge_length: int = getattr(video_processor, "merge_size") ** 2
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            image_grid_thw = mm_inputs.get("image_grid_thw", [])
            video_grid_thw = mm_inputs.get("video_grid_thw", [])
            # Hardcoded for now: assumes every video in the batch has the same frame count.
            num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0
            video_metadata = mm_inputs.get("video_metadata", [])
        else:
            image_grid_thw = [None] * len(images)
            video_grid_thw = [None] * len(videos)
            num_frames = 0

        timestamps = [0]
        for idx, message in enumerate(messages):
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                image_seqlen = (
                    image_grid_thw[num_image_tokens].prod() // image_merge_length if self.expand_mm_tokens else 1
                )
                content = content.replace(
                    IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1
                )
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    # Index by the running video counter, not the message index:
                    # a single message may contain several videos, or none.
                    metadata = video_metadata[num_video_tokens]
                    timestamps = processor._calculate_timestamps(
                        metadata.frames_indices,
                        metadata.fps,
                        video_processor.merge_size,
                    )

                video_structure = ""
                for frame_index in range(num_frames):
                    video_seqlen = (
                        video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
                        if self.expand_mm_tokens
                        else 1
                    )
                    timestamp_sec = timestamps[frame_index]
                    frame_structure = (
                        f"<{timestamp_sec:.1f} seconds>"
                        f"{self.start_token}{self.video_token * video_seqlen}{self.end_token}"
                    )
                    video_structure += frame_structure

                if not self.expand_mm_tokens:
                    video_structure = f"{self.start_token}{self.video_token}{self.end_token}"

                content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
                num_video_tokens += 1

            message["content"] = content

        return messages
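When `expand_mm_tokens` is enabled, each `VIDEO_PLACEHOLDER` expands into one timestamped segment per frame. A rough sketch of the resulting string, assuming two frames half a second apart, four video tokens per frame, and Qwen-style delimiter tokens (the concrete token strings are assumptions, not shown in this diff):

```python
start_token, end_token = "<|vision_start|>", "<|vision_end|>"
video_token = "<|video_pad|>"
timestamps = [0.0, 0.5]  # illustrative output of processor._calculate_timestamps
video_seqlen = 4         # illustrative per-frame count (h * w // merge_size**2)

video_structure = "".join(
    f"<{t:.1f} seconds>{start_token}{video_token * video_seqlen}{end_token}" for t in timestamps
)
# -> "<0.0 seconds><|vision_start|><|video_pad|>...<|vision_end|><0.5 seconds>..."
```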
@dataclass
class GLM4VPlugin(Qwen2VLPlugin):
    @override
@@ -1893,6 +2006,7 @@ PLUGINS = {
     "qwen2_audio": Qwen2AudioPlugin,
     "qwen2_omni": Qwen2OmniPlugin,
     "qwen2_vl": Qwen2VLPlugin,
+    "qwen3_vl": Qwen3VLPlugin,
     "video_llava": VideoLlavaPlugin,
 }
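With the registry entry in place, the plugin is reachable through the existing lookup path. A minimal usage sketch, assuming the module's usual `get_mm_plugin` helper and Qwen2-VL-family token names (carried over here as assumptions rather than taken from this diff):

```python
from llamafactory.data.mm_plugin import get_mm_plugin

# Token names follow the Qwen2-VL convention; verify them against the
# Qwen3-VL tokenizer before relying on this snippet.
plugin = get_mm_plugin(
    name="qwen3_vl",
    image_token="<|image_pad|>",
    video_token="<|video_pad|>",
)
```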