diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 607fe325..b0069212 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -134,7 +134,7 @@ def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> lis
 
 def _check_video_is_nested_images(video: "VideoInput") -> bool:
     r"""Check if the video is nested images."""
-    return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict)) for frame in video)
+    return isinstance(video, list) and all(isinstance(frame, (str, BinaryIO, dict, ImageObject)) for frame in video)
 
 
 @dataclass
@@ -1531,6 +1531,119 @@ class Qwen2VLPlugin(BasePlugin):
         return messages
 
 
+@dataclass
+class Qwen3VLPlugin(Qwen2VLPlugin):
+    @override
+    def _get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: "MMProcessor",
+    ) -> dict[str, "torch.Tensor"]:
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        video_processor: BaseImageProcessor = getattr(processor, "video_processor", None)
+        mm_inputs = {}
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )["images"]
+            mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+        if len(videos) != 0:
+            videos = self._regularize_videos(
+                videos,
+                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
+            )
+            video_metadata = [
+                {"fps": getattr(processor, "video_fps", 24.0), "duration": len(video), "total_num_frames": len(video)}
+                for video in videos["videos"]
+            ]
+            mm_inputs.update(
+                video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
+            )
+            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+            if "second_per_grid_ts" in processor.model_input_names:
+                mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in videos["fps_per_video"]]
+
+        return mm_inputs
+
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        num_image_tokens, num_video_tokens = 0, 0
+        messages = deepcopy(messages)
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+        video_processor: BaseImageProcessor = getattr(processor, "video_processor")
+
+        image_merge_length: int = getattr(image_processor, "merge_size") ** 2
+        video_merge_length: int = getattr(video_processor, "merge_size") ** 2
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_thw = mm_inputs.get("image_grid_thw", [])
+            video_grid_thw = mm_inputs.get("video_grid_thw", [])
+            num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0  # hard code for now
+            video_metadata = mm_inputs.get("video_metadata", {})
+
+        else:
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)
+            num_frames = 0
+            timestamps = [0]
+
+        for idx, message in enumerate(messages):
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                image_seqlen = (
+                    image_grid_thw[num_image_tokens].prod() // image_merge_length if self.expand_mm_tokens else 1
+                )
+                content = content.replace(
+                    IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1
+                )
+                num_image_tokens += 1
+
+            while VIDEO_PLACEHOLDER in content:
+                metadata = video_metadata[idx]
+                timestamps = processor._calculate_timestamps(
+                    metadata.frames_indices,
+                    metadata.fps,
+                    video_processor.merge_size,
+                )
+                video_structure = ""
+                for frame_index in range(num_frames):
+                    video_seqlen = (
+                        video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
+                        if self.expand_mm_tokens
+                        else 1
+                    )
+                    timestamp_sec = timestamps[frame_index]
+                    frame_structure = f"<{timestamp_sec:.1f} seconds>{self.start_token}{self.video_token * video_seqlen}{self.end_token}"
+                    video_structure += frame_structure
+
+                if not self.expand_mm_tokens:
+                    video_structure = f"{self.start_token}{self.video_token}{self.end_token}"
+
+                content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
+                num_video_tokens += 1
+
+            message["content"] = content
+
+        return messages
+
+
 @dataclass
 class GLM4VPlugin(Qwen2VLPlugin):
     @override
@@ -1893,6 +2006,7 @@ PLUGINS = {
     "qwen2_audio": Qwen2AudioPlugin,
     "qwen2_omni": Qwen2OmniPlugin,
     "qwen2_vl": Qwen2VLPlugin,
+    "qwen3_vl": Qwen3VLPlugin,
     "video_llava": VideoLlavaPlugin,
 }
 
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index d47b2b83..330ff50c 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -1866,6 +1866,44 @@ register_template(
     ),
 )
 
+
+register_template(
+    name="qwen3_omni",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(
+        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
+    ),
+    template_class=ReasoningTemplate,
+)
+
+
+register_template(
+    name="qwen3_omni_nothink",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(
+        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
+    ),
+)
+
+
 # copied from qwen template
 register_template(
     name="qwen2_vl",
@@ -1884,6 +1922,41 @@ register_template(
 )
 
 
+# copied from qwen template
+register_template(
+    name="qwen3_vl",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+    template_class=ReasoningTemplate,
+)
+
+
+# copied from qwen template
+register_template(
+    name="qwen3_vl_nothink",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen"),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+)
+
+
 register_template(
     name="sailor",
     format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 134e3ce2..e3f5d708 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -3060,6 +3060,31 @@ register_model_group(
     multimodal=True,
 )
 
+register_model_group(
+    models={
+        "Qwen/Qwen3-Omni-30B-A3B-Captioner": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Captioner",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Captioner",
+        },
+        "Qwen/Qwen3-Omni-30B-A3B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+        },
+    },
+    template="qwen3_omni_nothink",
+    multimodal=True,
+)
+
+register_model_group(
+    models={
+        "Qwen/Qwen3-Omni-30B-A3B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-Omni-30B-A3B-Thinking",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-Omni-30B-A3B-Thinking",
+        },
+    },
+    template="qwen3_omni",
+    multimodal=True,
+)
 
 register_model_group(
     models={
@@ -3163,6 +3188,30 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Qwen/Qwen3-VL-235B-A22B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Thinking",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Thinking",
+        },
+    },
+    template="qwen3_vl",
+    multimodal=True,
+)
+
+
+register_model_group(
+    models={
+        "Qwen/Qwen3-VL-235B-A22B-Instruct": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3-VL-235B-A22B-Instruct",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3-VL-235B-A22B-Instruct",
+        },
+    },
+    template="qwen3_vl_nothink",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "Seed-Coder-8B-Base": {
diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py
index 2cd47d1e..a99c0b93 100644
--- a/src/llamafactory/model/model_utils/moe.py
+++ b/src/llamafactory/model/model_utils/moe.py
@@ -105,7 +105,7 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 
         _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])
 
-    if model_type == "qwen3_moe" or text_architectures == "Qwen3MoeForCausalLM":  # for internvl_3_5
+    if model_type == "qwen3_moe" or text_architectures == "Qwen3MoeForCausalLM":
         from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
 
         _set_z3_leaf_modules(model, [Qwen3MoeSparseMoeBlock])
diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py
index 7f05e563..406307d0 100644
--- a/tests/data/test_mm_plugin.py
+++ b/tests/data/test_mm_plugin.py
@@ -56,10 +56,17 @@ TEXT_MESSAGES = [
     {"role": "assistant", "content": "I am fine!"},
 ]
 
+VIDEO_MESSAGES = [
+    {"role": "user", "content": "