mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 09:52:14 +08:00 
			
		
		
		
	[model] add qwen3-vl/qwen3-omni (#9196)
Co-authored-by: kingsley <kingsleydodonow@gmail.com>
This commit is contained in:
		
							parent
							
								
									abc3b1e1c4
								
							
						
					
					
						commit
						0761a4448f
					
				@ -134,7 +134,7 @@ def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> lis
 | 
			
		||||
 | 
			
		||||
def _check_video_is_nested_images(video: "VideoInput") -> bool:
 | 
			
		||||
    r"""Check if the video is nested images."""
 | 
			
		||||
    return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video)
 | 
			
		||||
    return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict, ImageObject)) for frame in video)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
@ -1531,6 +1531,119 @@ class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
        return messages
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class Qwen3VLPlugin(Qwen2VLPlugin):
 | 
			
		||||
    @override
 | 
			
		||||
    def _get_mm_inputs(
 | 
			
		||||
        self,
 | 
			
		||||
        images: list["ImageInput"],
 | 
			
		||||
        videos: list["VideoInput"],
 | 
			
		||||
        audios: list["AudioInput"],
 | 
			
		||||
        processor: "MMProcessor",
 | 
			
		||||
    ) -> dict[str, "torch.Tensor"]:
 | 
			
		||||
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
 | 
			
		||||
        video_processor: BaseImageProcessor = getattr(processor, "video_processor", None)
 | 
			
		||||
        mm_inputs = {}
 | 
			
		||||
        if len(images) != 0:
 | 
			
		||||
            images = self._regularize_images(
 | 
			
		||||
                images,
 | 
			
		||||
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
 | 
			
		||||
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
 | 
			
		||||
            )["images"]
 | 
			
		||||
            mm_inputs.update(image_processor(images, return_tensors="pt"))
 | 
			
		||||
 | 
			
		||||
        if len(videos) != 0:
 | 
			
		||||
            videos = self._regularize_videos(
 | 
			
		||||
                videos,
 | 
			
		||||
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
 | 
			
		||||
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
 | 
			
		||||
                video_fps=getattr(processor, "video_fps", 2.0),
 | 
			
		||||
                video_maxlen=getattr(processor, "video_maxlen", 128),
 | 
			
		||||
            )
 | 
			
		||||
            video_metadata = [
 | 
			
		||||
                {"fps": getattr(processor, "video_fps", 24.0), "duration": len(video), "total_num_frames": len(video)}
 | 
			
		||||
                for video in videos["videos"]
 | 
			
		||||
            ]
 | 
			
		||||
            mm_inputs.update(
 | 
			
		||||
                video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
 | 
			
		||||
            )
 | 
			
		||||
            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
 | 
			
		||||
            if "second_per_grid_ts" in processor.model_input_names:
 | 
			
		||||
                mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in videos["fps_per_video"]]
 | 
			
		||||
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def process_messages(
 | 
			
		||||
        self,
 | 
			
		||||
        messages: list[dict[str, str]],
 | 
			
		||||
        images: list["ImageInput"],
 | 
			
		||||
        videos: list["VideoInput"],
 | 
			
		||||
        audios: list["AudioInput"],
 | 
			
		||||
        processor: Optional["MMProcessor"],
 | 
			
		||||
    ) -> list[dict[str, str]]:
 | 
			
		||||
        self._validate_input(processor, images, videos, audios)
 | 
			
		||||
        self._validate_messages(messages, images, videos, audios)
 | 
			
		||||
        num_image_tokens, num_video_tokens = 0, 0
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
 | 
			
		||||
        video_processor: BaseImageProcessor = getattr(processor, "video_processor")
 | 
			
		||||
 | 
			
		||||
        image_merge_length: int = getattr(image_processor, "merge_size") ** 2
 | 
			
		||||
        video_merge_length: int = getattr(video_processor, "merge_size") ** 2
 | 
			
		||||
        if self.expand_mm_tokens:
 | 
			
		||||
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
            image_grid_thw = mm_inputs.get("image_grid_thw", [])
 | 
			
		||||
            video_grid_thw = mm_inputs.get("video_grid_thw", [])
 | 
			
		||||
            num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0  # hard code for now
 | 
			
		||||
            video_metadata = mm_inputs.get("video_metadata", {})
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            image_grid_thw = [None] * len(images)
 | 
			
		||||
            video_grid_thw = [None] * len(videos)
 | 
			
		||||
            num_frames = 0
 | 
			
		||||
            timestamps = [0]
 | 
			
		||||
 | 
			
		||||
        for idx, message in enumerate(messages):
 | 
			
		||||
            content = message["content"]
 | 
			
		||||
            while IMAGE_PLACEHOLDER in content:
 | 
			
		||||
                image_seqlen = (
 | 
			
		||||
                    image_grid_thw[num_image_tokens].prod() // image_merge_length if self.expand_mm_tokens else 1
 | 
			
		||||
                )
 | 
			
		||||
                content = content.replace(
 | 
			
		||||
                    IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1
 | 
			
		||||
                )
 | 
			
		||||
                num_image_tokens += 1
 | 
			
		||||
 | 
			
		||||
            while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                metadata = video_metadata[idx]
 | 
			
		||||
                timestamps = processor._calculate_timestamps(
 | 
			
		||||
                    metadata.frames_indices,
 | 
			
		||||
                    metadata.fps,
 | 
			
		||||
                    video_processor.merge_size,
 | 
			
		||||
                )
 | 
			
		||||
                video_structure = ""
 | 
			
		||||
                for frame_index in range(num_frames):
 | 
			
		||||
                    video_seqlen = (
 | 
			
		||||
                        video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
 | 
			
		||||
                        if self.expand_mm_tokens
 | 
			
		||||
                        else 1
 | 
			
		||||
                    )
 | 
			
		||||
                    timestamp_sec = timestamps[frame_index]
 | 
			
		||||
                    frame_structure = f"<{timestamp_sec:.1f} seconds>{self.start_token}{self.video_token * video_seqlen}{self.end_token}"
 | 
			
		||||
                    video_structure += frame_structure
 | 
			
		||||
 | 
			
		||||
                if not self.expand_mm_tokens:
 | 
			
		||||
                    video_structure = f"{self.start_token}{self.video_token}{self.end_token}"
 | 
			
		||||
 | 
			
		||||
                content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
 | 
			
		||||
                num_video_tokens += 1
 | 
			
		||||
 | 
			
		||||
            message["content"] = content
 | 
			
		||||
 | 
			
		||||
        return messages
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class GLM4VPlugin(Qwen2VLPlugin):
 | 
			
		||||
    @override
 | 
			
		||||
@ -1893,6 +2006,7 @@ PLUGINS = {
 | 
			
		||||
    "qwen2_audio": Qwen2AudioPlugin,
 | 
			
		||||
    "qwen2_omni": Qwen2OmniPlugin,
 | 
			
		||||
    "qwen2_vl": Qwen2VLPlugin,
 | 
			
		||||
    "qwen3_vl": Qwen3VLPlugin,
 | 
			
		||||
    "video_llava": VideoLlavaPlugin,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1866,6 +1866,44 @@ register_template(
 | 
			
		||||
    ),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_template(
 | 
			
		||||
    name="qwen3_omni",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
 | 
			
		||||
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 | 
			
		||||
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
 | 
			
		||||
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
 | 
			
		||||
    format_observation=StringFormatter(
 | 
			
		||||
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
 | 
			
		||||
    ),
 | 
			
		||||
    format_tools=ToolFormatter(tool_format="qwen"),
 | 
			
		||||
    stop_words=["<|im_end|>"],
 | 
			
		||||
    replace_eos=True,
 | 
			
		||||
    mm_plugin=get_mm_plugin(
 | 
			
		||||
        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
 | 
			
		||||
    ),
 | 
			
		||||
    template_class=ReasoningTemplate,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_template(
 | 
			
		||||
    name="qwen3_omni_nothink",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
 | 
			
		||||
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 | 
			
		||||
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
 | 
			
		||||
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
 | 
			
		||||
    format_observation=StringFormatter(
 | 
			
		||||
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
 | 
			
		||||
    ),
 | 
			
		||||
    format_tools=ToolFormatter(tool_format="qwen"),
 | 
			
		||||
    stop_words=["<|im_end|>"],
 | 
			
		||||
    replace_eos=True,
 | 
			
		||||
    mm_plugin=get_mm_plugin(
 | 
			
		||||
        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
 | 
			
		||||
    ),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# copied from qwen template
 | 
			
		||||
register_template(
 | 
			
		||||
    name="qwen2_vl",
 | 
			
		||||
@ -1884,6 +1922,41 @@ register_template(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# copied from qwen template
 | 
			
		||||
register_template(
 | 
			
		||||
    name="qwen3_vl",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
 | 
			
		||||
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 | 
			
		||||
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
 | 
			
		||||
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
 | 
			
		||||
    format_observation=StringFormatter(
 | 
			
		||||
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
 | 
			
		||||
    ),
 | 
			
		||||
    format_tools=ToolFormatter(tool_format="qwen"),
 | 
			
		||||
    stop_words=["<|im_end|>"],
 | 
			
		||||
    replace_eos=True,
 | 
			
		||||
    mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
 | 
			
		||||
    template_class=ReasoningTemplate,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# copied from qwen template
 | 
			
		||||
register_template(
 | 
			
		||||
    name="qwen3_vl_nothink",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
 | 
			
		||||
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 | 
			
		||||
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
 | 
			
		||||
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
 | 
			
		||||
    format_observation=StringFormatter(
 | 
			
		||||
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
 | 
			
		||||
    ),
 | 
			
		||||
    format_tools=ToolFormatter(tool_format="qwen"),
 | 
			
		||||
    stop_words=["<|im_end|>"],
 | 
			
		||||
    replace_eos=True,
 | 
			
		||||
    mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_template(
 | 
			
		||||
    name="sailor",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
 | 
			
		||||
 | 
			
		||||
@ -3060,6 +3060,31 @@ register_model_group(
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "Qwen/Qwen3-Omni-30B-A3B-Captioner": {
 | 
			
		||||
            DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Captioner",
 | 
			
		||||
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Captioner",
 | 
			
		||||
        },
 | 
			
		||||
        "Qwen/Qwen3-Omni-30B-A3B-Instruct": {
 | 
			
		||||
            DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Instruct",
 | 
			
		||||
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Instruct",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="qwen3_omni_nothink",
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "Qwen/Qwen3-Omni-30B-A3B-Thinking": {
 | 
			
		||||
            DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Thinking",
 | 
			
		||||
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Thinking",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="qwen3_omni",
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
@ -3163,6 +3188,30 @@ register_model_group(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "Qwen/Qwen3-VL-235B-A22B-Thinking": {
 | 
			
		||||
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Thinking",
 | 
			
		||||
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Thinking",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="qwen3_vl",
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "Qwen/Qwen3-VL-235B-A22B-Instruct": {
 | 
			
		||||
            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Instruct",
 | 
			
		||||
            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Instruct",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="qwen3_vl_nothink",
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "Seed-Coder-8B-Base": {
 | 
			
		||||
 | 
			
		||||
@ -105,7 +105,7 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 | 
			
		||||
 | 
			
		||||
        _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
 | 
			
		||||
 | 
			
		||||
    if model_type == "qwen3_moe" or text_architectures == "Qwen3MoeForCausalLM":  # for internvl_3_5
 | 
			
		||||
    if model_type == "qwen3_moe" or text_architectures == "Qwen3MoeForCausalLM":
 | 
			
		||||
        from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
 | 
			
		||||
 | 
			
		||||
        _set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])
 | 
			
		||||
 | 
			
		||||
@ -56,10 +56,17 @@ TEXT_MESSAGES = [
 | 
			
		||||
    {"role": "assistant", "content": "I am fine!"},
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
VIDEO_MESSAGES = [
 | 
			
		||||
    {"role": "user", "content": "<video>What is in this viode?"},
 | 
			
		||||
    {"role": "assistant", "content": "A cat."},
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
AUDIOS = [np.zeros(1600)]
 | 
			
		||||
 | 
			
		||||
IMAGES = [Image.new("RGB", (32, 32), (255, 255, 255))]
 | 
			
		||||
 | 
			
		||||
VIDEOS = [[Image.new("RGB", (32, 32), (255, 255, 255))] * 4]
 | 
			
		||||
 | 
			
		||||
NO_IMAGES = []
 | 
			
		||||
 | 
			
		||||
NO_VIDEOS = []
 | 
			
		||||
@ -145,6 +152,8 @@ def _check_plugin(
 | 
			
		||||
            plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
 | 
			
		||||
            expected_mm_inputs,
 | 
			
		||||
        )
 | 
			
		||||
    elif plugin.__class__.__name__ == "Qwen3VLPlugin":  # only check replacement
 | 
			
		||||
        assert plugin.process_messages(VIDEO_MESSAGES, NO_IMAGES, VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
 | 
			
		||||
    elif plugin.__class__.__name__ != "BasePlugin":  # test mm_messages
 | 
			
		||||
        assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
 | 
			
		||||
        assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
 | 
			
		||||
@ -357,6 +366,27 @@ def test_qwen2_vl_plugin():
 | 
			
		||||
    _check_plugin(**check_inputs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.57.0"), reason="Requires transformers>=4.57.0")
 | 
			
		||||
def test_qwen3_vl_plugin():
 | 
			
		||||
    frame_seqlen = 1
 | 
			
		||||
    tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen3-VL-235B-A22B-Instruct")
 | 
			
		||||
    qwen3_vl_plugin = get_mm_plugin(name="qwen3_vl", video_token="<|video_pad|>")
 | 
			
		||||
    check_inputs = {"plugin": qwen3_vl_plugin, **tokenizer_module}
 | 
			
		||||
    check_inputs["expected_mm_messages"] = [
 | 
			
		||||
        {
 | 
			
		||||
            key: value.replace(
 | 
			
		||||
                "<video>",  # little different with original processor for default `fps=2` in our repo
 | 
			
		||||
                "<0.2 seconds><|vision_start|>{}<|vision_end|><1.2 seconds><|vision_start|>{}<|vision_end|>".format(
 | 
			
		||||
                    "<|video_pad|>" * frame_seqlen, "<|video_pad|>" * frame_seqlen
 | 
			
		||||
                ),
 | 
			
		||||
            )
 | 
			
		||||
            for key, value in message.items()
 | 
			
		||||
        }
 | 
			
		||||
        for message in VIDEO_MESSAGES
 | 
			
		||||
    ]
 | 
			
		||||
    _check_plugin(**check_inputs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
 | 
			
		||||
def test_video_llava_plugin():
 | 
			
		||||
    image_seqlen = 256
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user