[model] add qwen3-vl/qwen3-omni (#9196)

Co-authored-by: kingsley <kingsleydodonow@gmail.com>
xvxuopop 2025-09-27 01:21:47 +08:00 committed by GitHub
parent abc3b1e1c4
commit 0761a4448f
5 changed files with 268 additions and 2 deletions


@@ -134,7 +134,7 @@ def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> lis
def _check_video_is_nested_images(video: "VideoInput") -> bool:
    r"""Check if the video is nested images."""
-    return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video)
+    return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict, ImageObject)) for frame in video)


@dataclass
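The widened check now also accepts a video given as a list of PIL images (nested images), which the new VIDEOS fixture in the tests relies on. A minimal sketch of the effect (assuming ImageObject is PIL.Image.Image, as in the plugin's imports):

# Sketch: a "video" supplied as a list of PIL images now passes the nested-images check.
from PIL import Image

video = [Image.new("RGB", (32, 32), (255, 255, 255))] * 4  # same shape as the VIDEOS test fixture below
print(isinstance(video, list) and all(isinstance(frame, Image.Image) for frame in video))  # True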
@@ -1531,6 +1531,119 @@ class Qwen2VLPlugin(BasePlugin):
        return messages


@dataclass
class Qwen3VLPlugin(Qwen2VLPlugin):
    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        video_processor: BaseImageProcessor = getattr(processor, "video_processor", None)
        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]
            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            video_metadata = [
                {"fps": getattr(processor, "video_fps", 24.0), "duration": len(video), "total_num_frames": len(video)}
                for video in videos["videos"]
            ]
            mm_inputs.update(
                video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
            )
            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
            if "second_per_grid_ts" in processor.model_input_names:
                mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in videos["fps_per_video"]]

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        video_processor: BaseImageProcessor = getattr(processor, "video_processor")
        image_merge_length: int = getattr(image_processor, "merge_size") ** 2
        video_merge_length: int = getattr(video_processor, "merge_size") ** 2
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            image_grid_thw = mm_inputs.get("image_grid_thw", [])
            video_grid_thw = mm_inputs.get("video_grid_thw", [])
            num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0  # hard code for now
            video_metadata = mm_inputs.get("video_metadata", {})
        else:
            image_grid_thw = [None] * len(images)
            video_grid_thw = [None] * len(videos)
            num_frames = 0
            timestamps = [0]

        for idx, message in enumerate(messages):
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                image_seqlen = (
                    image_grid_thw[num_image_tokens].prod() // image_merge_length if self.expand_mm_tokens else 1
                )
                content = content.replace(
                    IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1
                )
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                metadata = video_metadata[idx]
                timestamps = processor._calculate_timestamps(
                    metadata.frames_indices,
                    metadata.fps,
                    video_processor.merge_size,
                )
                video_structure = ""
                for frame_index in range(num_frames):
                    video_seqlen = (
                        video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
                        if self.expand_mm_tokens
                        else 1
                    )
                    timestamp_sec = timestamps[frame_index]
                    frame_structure = f"<{timestamp_sec:.1f} seconds>{self.start_token}{self.video_token * video_seqlen}{self.end_token}"
                    video_structure += frame_structure

                if not self.expand_mm_tokens:
                    video_structure = f"{self.start_token}{self.video_token}{self.end_token}"

                content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
                num_video_tokens += 1

            message["content"] = content

        return messages
@dataclass
class GLM4VPlugin(Qwen2VLPlugin):
    @override
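For intuition, the plugin above expands a single video placeholder into one timestamped, delimited segment per frame. A standalone sketch with illustrative values (the real timestamps, grid shapes, and tokens come from the processor; the output mirrors the expectation in the new test further below):

# Sketch of the Qwen3-VL frame expansion (illustrative values only).
start_token, end_token, video_token = "<|vision_start|>", "<|vision_end|>", "<|video_pad|>"
timestamps = [0.2, 1.2]  # would come from processor._calculate_timestamps(...)
video_seqlen = 1         # would be grid_thw[1:].prod() // merge_size**2

video_structure = ""
for timestamp_sec in timestamps:
    video_structure += f"<{timestamp_sec:.1f} seconds>{start_token}{video_token * video_seqlen}{end_token}"

print("<video>What is in this video?".replace("<video>", video_structure, 1))
# <0.2 seconds><|vision_start|><|video_pad|><|vision_end|><1.2 seconds><|vision_start|><|video_pad|><|vision_end|>What is in this video?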
@@ -1893,6 +2006,7 @@ PLUGINS = {
    "qwen2_audio": Qwen2AudioPlugin,
    "qwen2_omni": Qwen2OmniPlugin,
    "qwen2_vl": Qwen2VLPlugin,
    "qwen3_vl": Qwen3VLPlugin,
    "video_llava": VideoLlavaPlugin,
}
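With this entry, the plugin becomes constructible by name through get_mm_plugin, which is how the template registrations below obtain it. A hedged sketch (the import path is an assumption based on the repo layout; the call signature matches the registrations and the test):

# Assumed import path; verify against the actual module location.
from llamafactory.data.mm_plugin import get_mm_plugin

plugin = get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>")
print(type(plugin).__name__)  # Qwen3VLPlugin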


@@ -1866,6 +1866,44 @@ register_template(
    ),
)


register_template(
    name="qwen3_omni",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(
        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
    ),
    template_class=ReasoningTemplate,
)
register_template(
    name="qwen3_omni_nothink",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(
        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
    ),
)
# copied from qwen template
register_template(
    name="qwen2_vl",
@@ -1884,6 +1922,41 @@ register_template(
)
# copied from qwen template
register_template(
    name="qwen3_vl",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
    template_class=ReasoningTemplate,
)
# copied from qwen template
register_template(
    name="qwen3_vl_nothink",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
    format_observation=StringFormatter(
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
    ),
    format_tools=ToolFormatter(tool_format="qwen"),
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
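The `_nothink` variants differ from the thinking templates only in omitting `template_class=ReasoningTemplate`. For reference, the format slots above assemble a turn in the standard Qwen chat layout; a minimal sketch of the resulting string (plain concatenation for illustration, with the image placeholder already expanded by the plugin into delimited pad tokens):

# Sketch: one user/assistant turn under the qwen3_vl format slots above.
user_turn = (
    "<|im_start|>user\n"
    "<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|>\n"
    "<|im_start|>assistant\n"
)
assistant_turn = "It is a white square.<|im_end|>\n"
print(user_turn + assistant_turn)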
register_template(
    name="sailor",
    format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),


@@ -3060,6 +3060,31 @@ register_model_group(
    multimodal=True,
)


register_model_group(
    models={
        "Qwen/Qwen3-Omni-30B-A3B-Captioner": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Captioner",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Captioner",
        },
        "Qwen/Qwen3-Omni-30B-A3B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Instruct",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Instruct",
        },
    },
    template="qwen3_omni_nothink",
    multimodal=True,
)


register_model_group(
    models={
        "Qwen/Qwen3-Omni-30B-A3B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Thinking",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Thinking",
        },
    },
    template="qwen3_omni",
    multimodal=True,
)


register_model_group(
    models={
@@ -3163,6 +3188,30 @@ register_model_group(
)


register_model_group(
    models={
        "Qwen/Qwen3-VL-235B-A22B-Thinking": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Thinking",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Thinking",
        },
    },
    template="qwen3_vl",
    multimodal=True,
)


register_model_group(
    models={
        "Qwen/Qwen3-VL-235B-A22B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Instruct",
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Instruct",
        },
    },
    template="qwen3_vl_nothink",
    multimodal=True,
)


register_model_group(
    models={
        "Seed-Coder-8B-Base": {


@@ -105,7 +105,7 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
        _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])

-    if model_type == "qwen3_moe" or text_architectures == "Qwen3MoeForCausalLM":  # for internvl_3_5
+    if model_type == "qwen3_moe" or text_architectures == "Qwen3MoeForCausalLM":
        from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock

        _set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])


@@ -56,10 +56,17 @@ TEXT_MESSAGES = [
    {"role": "assistant", "content": "I am fine!"},
]

VIDEO_MESSAGES = [
    {"role": "user", "content": "<video>What is in this video?"},
    {"role": "assistant", "content": "A cat."},
]

AUDIOS = [np.zeros(1600)]

IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]

VIDEOS = [[Image.new("RGB", (32, 32), (255, 255, 255))] * 4]

NO_IMAGES = []

NO_VIDEOS = []
@@ -145,6 +152,8 @@ def _check_plugin(
            plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
            expected_mm_inputs,
        )
    elif plugin.__class__.__name__ == "Qwen3VLPlugin":  # only check replacement
        assert plugin.process_messages(VIDEO_MESSAGES, NO_IMAGES, VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
    elif plugin.__class__.__name__ != "BasePlugin":  # test mm_messages
        assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
        assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
@@ -357,6 +366,27 @@ def test_qwen2_vl_plugin():
    _check_plugin(**check_inputs)


@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
def test_qwen3_vl_plugin():
    frame_seqlen = 1
    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen3-VL-235B-A22B-Instruct")
    qwen3_vl_plugin = get_mm_plugin(name="qwen3_vl", video_token="<|video_pad|>")
    check_inputs = {"plugin": qwen3_vl_plugin, **tokenizer_module}
    check_inputs["expected_mm_messages"] = [
        {
            key: value.replace(
                "<video>",  # slightly different from the original processor because our repo defaults to `fps=2`
                "<0.2 seconds><|vision_start|>{}<|vision_end|><1.2 seconds><|vision_start|>{}<|vision_end|>".format(
                    "<|video_pad|>" * frame_seqlen, "<|video_pad|>" * frame_seqlen
                ),
            )
            for key, value in message.items()
        }
        for message in VIDEO_MESSAGES
    ]
    _check_plugin(**check_inputs)
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
def test_video_llava_plugin():
    image_seqlen = 256
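To run just the new plugin check locally, something along these lines should work (the test file path is an assumption about the repo layout; the test itself requires transformers>=4.57.0 and downloads the Qwen3-VL tokenizer):

import pytest

# Run only the Qwen3-VL plugin test; adjust the path if the test lives elsewhere.
pytest.main(["tests/data/test_mm_plugin.py", "-k", "qwen3_vl_plugin"])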