mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[model] add Qwen2.5-Omni model (#7537)
* preserve image_sizes * preserve image_sizes * init plugin * support audio-text2text lora * nit * support image/video-text2text, audio-text2text * remove args * remove lines * add docs && nit * remove some comments * fix && add merge part script * add license
This commit is contained in:
		
							parent
							
								
									0f8296626a
								
							
						
					
					
						commit
						7eed496336
					
				@ -261,6 +261,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | 
			
		||||
| [Pixtral](https://huggingface.co/mistralai)                       | 12B                              | pixtral             |
 | 
			
		||||
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen)   | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen                |
 | 
			
		||||
| [Qwen2-Audio](https://huggingface.co/Qwen)                        | 7B                               | qwen2_audio         |
 | 
			
		||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen)                       | 7B                               | qwen2_omni          |
 | 
			
		||||
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen)            | 2B/3B/7B/32B/72B                 | qwen2_vl            |
 | 
			
		||||
| [Skywork o1](https://huggingface.co/Skywork)                      | 8B                               | skywork_o1          |
 | 
			
		||||
| [StarCoder 2](https://huggingface.co/bigcode)                     | 3B/7B/15B                        | -                   |
 | 
			
		||||
 | 
			
		||||
@ -263,6 +263,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | 
			
		||||
| [Pixtral](https://huggingface.co/mistralai)                       | 12B                              | pixtral             |
 | 
			
		||||
| [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen)   | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen                |
 | 
			
		||||
| [Qwen2-Audio](https://huggingface.co/Qwen)                        | 7B                               | qwen2_audio         |
 | 
			
		||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen)                       | 7B                               | qwen2_omni          |
 | 
			
		||||
| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen)            | 2B/3B/7B/32B/72B                 | qwen2_vl            |
 | 
			
		||||
| [Skywork o1](https://huggingface.co/Skywork)                      | 8B                               | skywork_o1          |
 | 
			
		||||
| [StarCoder 2](https://huggingface.co/bigcode)                     | 3B/7B/15B                        | -                   |
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										90
									
								
								scripts/lora_part_merge.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										90
									
								
								scripts/lora_part_merge.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,90 @@
 | 
			
		||||
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 | 
			
		||||
#
 | 
			
		||||
# This code is based on the HuggingFace's PEFT library.
 | 
			
		||||
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
 | 
			
		||||
#
 | 
			
		||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
# you may not use this file except in compliance with the License.
 | 
			
		||||
# You may obtain a copy of the License at
 | 
			
		||||
#
 | 
			
		||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
#
 | 
			
		||||
# Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
import os
 | 
			
		||||
import shutil
 | 
			
		||||
 | 
			
		||||
import fire
 | 
			
		||||
from peft import PeftModel
 | 
			
		||||
from transformers import AutoModel, AutoProcessor, AutoTokenizer
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def merge_lora(
 | 
			
		||||
    base_model_path: str,
 | 
			
		||||
    lora_checkpoint_path: str,
 | 
			
		||||
    extra_file: str = "spk_dict.pt",
 | 
			
		||||
    submodule_name: str = "thinker",
 | 
			
		||||
    save_path: str = "./merged_model_checkpoint",
 | 
			
		||||
):
 | 
			
		||||
    """Load the original model, tokenizer, and processor configuration, merge the LoRA weights.
 | 
			
		||||
 | 
			
		||||
    for a specified submodule, and save the final merged model along with its configurations.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        base_model_path (str): Path to the original model directory.
 | 
			
		||||
        lora_checkpoint_path (str): Path to the directory containing LoRA weights.
 | 
			
		||||
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
 | 
			
		||||
        submodule_name (str): Name of the submodule to merge (default: "thinker").
 | 
			
		||||
        save_path (str): Directory where the merged model and configurations will be saved.
 | 
			
		||||
    """
 | 
			
		||||
    # 1. Load the original model, tokenizer, and processor
 | 
			
		||||
    model = AutoModel.from_pretrained(base_model_path)
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        processor = AutoProcessor.from_pretrained(base_model_path)
 | 
			
		||||
    except Exception:
 | 
			
		||||
        print("Processor configuration not found, skipping processor load.")
 | 
			
		||||
        processor = None
 | 
			
		||||
 | 
			
		||||
    print("Successfully loaded the original model, tokenizer, and processor (if available).")
 | 
			
		||||
 | 
			
		||||
    # 2. Extract the submodule to be merged (e.g., model.thinker)
 | 
			
		||||
    if not hasattr(model, submodule_name):
 | 
			
		||||
        raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.")
 | 
			
		||||
    base_submodule = getattr(model, submodule_name)
 | 
			
		||||
    print(f"Successfully extracted submodule: {submodule_name}.")
 | 
			
		||||
 | 
			
		||||
    # 3. Load the LoRA weights onto the extracted submodule
 | 
			
		||||
    lora_model = PeftModel.from_pretrained(base_submodule, lora_checkpoint_path)
 | 
			
		||||
    print("LoRA weights loaded successfully.")
 | 
			
		||||
 | 
			
		||||
    # 4. Merge the LoRA weights into the submodule and unload the LoRA modules
 | 
			
		||||
    merged_submodule = lora_model.merge_and_unload()
 | 
			
		||||
    print("LoRA weights merged successfully.")
 | 
			
		||||
 | 
			
		||||
    # 5. Replace the original submodule with the merged submodule in the model
 | 
			
		||||
    setattr(model, submodule_name, merged_submodule)
 | 
			
		||||
 | 
			
		||||
    # 6. Save the final merged model along with the tokenizer and processor configuration
 | 
			
		||||
    model.save_pretrained(save_path)
 | 
			
		||||
    tokenizer.save_pretrained(save_path)
 | 
			
		||||
    if processor is not None:
 | 
			
		||||
        processor.save_pretrained(save_path)
 | 
			
		||||
 | 
			
		||||
    print(f"Merged model and configuration saved to {save_path}.")
 | 
			
		||||
 | 
			
		||||
    source_file = os.path.join(base_model_path, extra_file)
 | 
			
		||||
    target_file = os.path.join(save_path, extra_file)
 | 
			
		||||
    if os.path.exists(source_file):
 | 
			
		||||
        shutil.copy(source_file, target_file)
 | 
			
		||||
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
 | 
			
		||||
    else:
 | 
			
		||||
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    fire.Fire(merge_lora)
 | 
			
		||||
@ -190,10 +190,27 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 | 
			
		||||
                "video_grid_thw": mm_inputs.get("video_grid_thw"),
 | 
			
		||||
                "attention_mask": features["attention_mask"],
 | 
			
		||||
            }
 | 
			
		||||
            if "second_per_grid_ts" in mm_inputs:
 | 
			
		||||
            if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
 | 
			
		||||
                rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
 | 
			
		||||
 | 
			
		||||
            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
 | 
			
		||||
            if getattr(self.model.config, "model_type", None) == "qwen2_5_omni":  # for qwen2omni
 | 
			
		||||
                feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
 | 
			
		||||
                if feature_attention_mask is not None:
 | 
			
		||||
                    audio_feature_lengths = torch.sum(
 | 
			
		||||
                        feature_attention_mask, dim=1
 | 
			
		||||
                    )  # FIXME need to get video image lengths
 | 
			
		||||
                    rope_index_kwargs["audio_seqlens"] = audio_feature_lengths  # prepare for input
 | 
			
		||||
 | 
			
		||||
                delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)
 | 
			
		||||
                # avoid conflict
 | 
			
		||||
                rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid", None)
 | 
			
		||||
                new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
 | 
			
		||||
                features["position_ids"], features["rope_deltas"] = (
 | 
			
		||||
                    new_position_ids.clone(),
 | 
			
		||||
                    rope_deltas - delta0,
 | 
			
		||||
                )  # avoid inplace operation FIXME
 | 
			
		||||
            else:  # for qwen2vl
 | 
			
		||||
                features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
 | 
			
		||||
 | 
			
		||||
        if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
 | 
			
		||||
            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
 | 
			
		||||
 | 
			
		||||
@ -146,6 +146,12 @@ class MMPluginMixin:
 | 
			
		||||
        video_processor: BaseImageProcessor = getattr(
 | 
			
		||||
            processor, "video_processor", getattr(processor, "image_processor", None)
 | 
			
		||||
        )
 | 
			
		||||
        if image_processor is None and video_processor is None:  # hack for qwen2_5_omni
 | 
			
		||||
            image_processor, video_processor = (
 | 
			
		||||
                getattr(processor, "omni_processor", None),
 | 
			
		||||
                getattr(processor, "omni_processor", None),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
 | 
			
		||||
        if len(images) != 0 and self.image_token is None:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
@ -1104,6 +1110,186 @@ class Qwen2AudioPlugin(BasePlugin):
 | 
			
		||||
        return self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Qwen2OmniPlugin(BasePlugin):
 | 
			
		||||
    @override
 | 
			
		||||
    def _get_mm_inputs(
 | 
			
		||||
        self,
 | 
			
		||||
        images: list["ImageInput"],
 | 
			
		||||
        videos: list["VideoInput"],
 | 
			
		||||
        audios: list["AudioInput"],
 | 
			
		||||
        processor: "MMProcessor",
 | 
			
		||||
        imglens: Optional[list[int]] = None,
 | 
			
		||||
    ) -> dict[str, "torch.Tensor"]:
 | 
			
		||||
        mm_inputs = {}
 | 
			
		||||
        if len(images) != 0:
 | 
			
		||||
            image_processor: BaseImageProcessor = getattr(processor, "omni_processor", None)  # FIXME
 | 
			
		||||
            images = self._regularize_images(
 | 
			
		||||
                images,
 | 
			
		||||
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
 | 
			
		||||
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
 | 
			
		||||
            )
 | 
			
		||||
            if imglens is not None:
 | 
			
		||||
                images = _make_batched_images(images, imglens)
 | 
			
		||||
 | 
			
		||||
            image_processor_kwargs = {}
 | 
			
		||||
            mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))
 | 
			
		||||
 | 
			
		||||
        if len(videos) != 0:
 | 
			
		||||
            video_processor: BaseImageProcessor = getattr(
 | 
			
		||||
                processor, "video_processor", getattr(processor, "omni_processor", None)
 | 
			
		||||
            )
 | 
			
		||||
            videos = self._regularize_videos(
 | 
			
		||||
                videos,
 | 
			
		||||
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
 | 
			
		||||
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
 | 
			
		||||
                video_fps=getattr(processor, "video_fps", 2.0),
 | 
			
		||||
                video_maxlen=getattr(processor, "video_maxlen", 128),
 | 
			
		||||
            )
 | 
			
		||||
            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
 | 
			
		||||
                mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
 | 
			
		||||
                fps = [2.0] * len(videos)  # FIXME hardcode
 | 
			
		||||
                video_second_per_grid = [fps[i] / video_processor.temporal_patch_size for i in range(len(fps))]
 | 
			
		||||
                mm_inputs["video_second_per_grid"] = torch.tensor(video_second_per_grid)
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
        if len(audios) != 0:
 | 
			
		||||
            feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
 | 
			
		||||
            audios = self._regularize_audios(
 | 
			
		||||
                audios,
 | 
			
		||||
                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
 | 
			
		||||
            )
 | 
			
		||||
            mm_inputs.update(
 | 
			
		||||
                feature_extractor(
 | 
			
		||||
                    audios,
 | 
			
		||||
                    sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
 | 
			
		||||
                    return_attention_mask=True,
 | 
			
		||||
                    padding="max_length",
 | 
			
		||||
                    return_tensors="pt",
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts
 | 
			
		||||
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def process_messages(
 | 
			
		||||
        self,
 | 
			
		||||
        messages: list[dict[str, str]],
 | 
			
		||||
        images: list["ImageInput"],
 | 
			
		||||
        videos: list["VideoInput"],
 | 
			
		||||
        audios: list["AudioInput"],
 | 
			
		||||
        processor: Optional["MMProcessor"],
 | 
			
		||||
    ) -> list[dict[str, str]]:
 | 
			
		||||
        self._validate_input(processor, images, videos, audios)
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        if self.expand_mm_tokens:
 | 
			
		||||
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
        num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
 | 
			
		||||
        use_audio_in_video = getattr(processor, "use_audio_in_video", False)
 | 
			
		||||
 | 
			
		||||
        # get length or size from mm_inputs
 | 
			
		||||
        if "feature_attention_mask" in mm_inputs:
 | 
			
		||||
            input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
 | 
			
		||||
            audio_lengths = (input_lengths - 2) // 2 + 1
 | 
			
		||||
        if mm_inputs.get("image_grid_thw", None) is not None:
 | 
			
		||||
            image_grid_thw = mm_inputs["image_grid_thw"]
 | 
			
		||||
            merge_length = processor.omni_processor.merge_size**2
 | 
			
		||||
        if mm_inputs.get("video_grid_thw", None) is not None:
 | 
			
		||||
            video_grid_thw = mm_inputs["video_grid_thw"]
 | 
			
		||||
            merge_length = processor.omni_processor.merge_size**2
 | 
			
		||||
 | 
			
		||||
        if use_audio_in_video:
 | 
			
		||||
            assert audio_lengths is not None, "audio_lengths should be exist when use_audio_in_video is `True`"
 | 
			
		||||
            assert mm_inputs.get("video_grid_thw", None) is not None, (
 | 
			
		||||
                "video_grid_thw should be exist when use_audio_in_video is `True`"
 | 
			
		||||
            )
 | 
			
		||||
            positions_list = []
 | 
			
		||||
            for i, message in enumerate(messages):  # get multimodal index when use_audio
 | 
			
		||||
                positions = []
 | 
			
		||||
                for special_token in [self.audio_token, self.image_token, self.video_token]:
 | 
			
		||||
                    start = 0
 | 
			
		||||
                    while True:
 | 
			
		||||
                        pos = message[i].find(special_token, start)
 | 
			
		||||
                        if pos == -1:
 | 
			
		||||
                            break
 | 
			
		||||
                        positions.append((pos, special_token))
 | 
			
		||||
                        start = pos + len(special_token)
 | 
			
		||||
                positions_list.append(positions.sort(key=lambda x: x[0]))
 | 
			
		||||
 | 
			
		||||
        for message in messages:
 | 
			
		||||
            content = message["content"]
 | 
			
		||||
            # separate with audio-video
 | 
			
		||||
            while IMAGE_PLACEHOLDER in content:
 | 
			
		||||
                image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
 | 
			
		||||
                content = content.replace(
 | 
			
		||||
                    IMAGE_PLACEHOLDER,
 | 
			
		||||
                    f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
 | 
			
		||||
                    1,
 | 
			
		||||
                )
 | 
			
		||||
                num_image_tokens += 1
 | 
			
		||||
 | 
			
		||||
            if not use_audio_in_video:
 | 
			
		||||
                while AUDIO_PLACEHOLDER in content:
 | 
			
		||||
                    audio_token_replace_length = audio_lengths[num_audio_tokens]
 | 
			
		||||
                    content = content.replace(
 | 
			
		||||
                        AUDIO_PLACEHOLDER,
 | 
			
		||||
                        f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
 | 
			
		||||
                        1,
 | 
			
		||||
                    )
 | 
			
		||||
                    num_audio_tokens += 1
 | 
			
		||||
                # TODO handle video_input and use_audio_in_video
 | 
			
		||||
                while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                    video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
 | 
			
		||||
                    content = content.replace(
 | 
			
		||||
                        VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
 | 
			
		||||
                    )
 | 
			
		||||
                    num_video_tokens += 1
 | 
			
		||||
            else:  # if use the audio of video # deal video token and audio token togather
 | 
			
		||||
                while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                    audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
 | 
			
		||||
                    video_t_index = (
 | 
			
		||||
                        torch.arange(video_grid_thw[num_video_tokens][0])
 | 
			
		||||
                        .view(-1, 1, 1)
 | 
			
		||||
                        .expand(
 | 
			
		||||
                            -1,
 | 
			
		||||
                            video_grid_thw[num_video_tokens][1] // self.omni_processor.merge_size,
 | 
			
		||||
                            video_grid_thw[num_video_tokens][2] // self.omni_processor.merge_size,
 | 
			
		||||
                        )
 | 
			
		||||
                        .flatten()
 | 
			
		||||
                        * mm_inputs["video_second_per_grid"][num_video_tokens]
 | 
			
		||||
                        * 25  # FIXME hardcode of position_id_per_seconds=25
 | 
			
		||||
                    ).long()
 | 
			
		||||
                    t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
 | 
			
		||||
                    video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
 | 
			
		||||
                    audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
 | 
			
		||||
                    placeholder_string = ""
 | 
			
		||||
                    for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
 | 
			
		||||
                        video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
 | 
			
		||||
                        audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
 | 
			
		||||
                        placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
 | 
			
		||||
                        if video_chunk_index is not None:
 | 
			
		||||
                            placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
 | 
			
		||||
                        if audio_chunk_index is not None:
 | 
			
		||||
                            placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
 | 
			
		||||
                        placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
 | 
			
		||||
                    content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
 | 
			
		||||
                    content = content.replace(AUDIO_PLACEHOLDER, "", 1)
 | 
			
		||||
                    num_audio_tokens += 1
 | 
			
		||||
                    num_video_tokens += 1
 | 
			
		||||
            message["content"] = content
 | 
			
		||||
 | 
			
		||||
        if len(audios) != num_audio_tokens:
 | 
			
		||||
            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
 | 
			
		||||
        if len(images) != num_image_tokens:
 | 
			
		||||
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 | 
			
		||||
        if len(videos) != num_video_tokens:
 | 
			
		||||
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
        return messages
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class Qwen2VLPlugin(BasePlugin):
 | 
			
		||||
    @override
 | 
			
		||||
@ -1328,6 +1514,7 @@ PLUGINS = {
 | 
			
		||||
    "paligemma": PaliGemmaPlugin,
 | 
			
		||||
    "pixtral": PixtralPlugin,
 | 
			
		||||
    "qwen2_audio": Qwen2AudioPlugin,
 | 
			
		||||
    "qwen2_omni": Qwen2OmniPlugin,
 | 
			
		||||
    "qwen2_vl": Qwen2VLPlugin,
 | 
			
		||||
    "video_llava": VideoLlavaPlugin,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -1367,6 +1367,24 @@ register_template(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# copied from qwen template
 | 
			
		||||
register_template(
 | 
			
		||||
    name="qwen2_omni",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
 | 
			
		||||
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 | 
			
		||||
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
 | 
			
		||||
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
 | 
			
		||||
    format_observation=StringFormatter(
 | 
			
		||||
        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
 | 
			
		||||
    ),
 | 
			
		||||
    format_tools=ToolFormatter(tool_format="qwen"),
 | 
			
		||||
    default_system="You are a helpful assistant.",
 | 
			
		||||
    stop_words=["<|im_end|>"],
 | 
			
		||||
    mm_plugin=get_mm_plugin(
 | 
			
		||||
        name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
 | 
			
		||||
    ),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
# copied from qwen template
 | 
			
		||||
register_template(
 | 
			
		||||
    name="qwen2_vl",
 | 
			
		||||
 | 
			
		||||
@ -2270,6 +2270,18 @@ register_model_group(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "Qwen2.5-Omni-7B": {
 | 
			
		||||
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
 | 
			
		||||
            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
 | 
			
		||||
        }
 | 
			
		||||
    },
 | 
			
		||||
    template="qwen2_omni",
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "Qwen2-VL-2B": {
 | 
			
		||||
 | 
			
		||||
@ -222,6 +222,10 @@ class ProcessorArguments:
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "Use pan and scan to process image for gemma3."},
 | 
			
		||||
    )
 | 
			
		||||
    use_audio_in_video: bool = field(
 | 
			
		||||
        default=False,
 | 
			
		||||
        metadata={"help": "Whether or not to use audio in video inputs."},
 | 
			
		||||
    )
 | 
			
		||||
    video_max_pixels: int = field(
 | 
			
		||||
        default=256 * 256,
 | 
			
		||||
        metadata={"help": "The maximum number of pixels of video inputs."},
 | 
			
		||||
 | 
			
		||||
@ -21,6 +21,7 @@ from transformers import (
 | 
			
		||||
    AutoModelForCausalLM,
 | 
			
		||||
    AutoModelForImageTextToText,
 | 
			
		||||
    AutoModelForSeq2SeqLM,
 | 
			
		||||
    AutoModelForTextToWaveform,
 | 
			
		||||
    AutoModelForVision2Seq,
 | 
			
		||||
    AutoProcessor,
 | 
			
		||||
    AutoTokenizer,
 | 
			
		||||
@ -147,6 +148,8 @@ def load_model(
 | 
			
		||||
                load_class = AutoModelForImageTextToText
 | 
			
		||||
            elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
 | 
			
		||||
                load_class = AutoModelForSeq2SeqLM
 | 
			
		||||
            elif type(config) in AutoModelForTextToWaveform._model_mapping.keys():  # audio hack for qwen2_5_omni
 | 
			
		||||
                load_class = AutoModelForTextToWaveform
 | 
			
		||||
            else:
 | 
			
		||||
                load_class = AutoModelForCausalLM
 | 
			
		||||
 | 
			
		||||
@ -154,6 +157,8 @@ def load_model(
 | 
			
		||||
                model = load_class.from_config(config, trust_remote_code=model_args.trust_remote_code)
 | 
			
		||||
            else:
 | 
			
		||||
                model = load_class.from_pretrained(**init_kwargs)
 | 
			
		||||
                if load_class is AutoModelForTextToWaveform:
 | 
			
		||||
                    model = model.thinker  # use part of Omni model
 | 
			
		||||
 | 
			
		||||
        if model_args.mixture_of_depths == "convert":
 | 
			
		||||
            model = convert_pretrained_model_to_mod(model, config, model_args)
 | 
			
		||||
 | 
			
		||||
@ -257,6 +257,17 @@ _register_composite_model(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_register_composite_model(
 | 
			
		||||
    model_type="qwen2_5_omni_thinker",
 | 
			
		||||
    projector_key="visual.merger",
 | 
			
		||||
    vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
 | 
			
		||||
    language_model_keys=["model", "lm_head"],
 | 
			
		||||
    lora_conflict_keys=[
 | 
			
		||||
        "patch_embed",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_register_composite_model(
 | 
			
		||||
    model_type="qwen2_vl",
 | 
			
		||||
    projector_key="visual.merger",
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user