Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

[model] support gemma3 (#7273)

This commit is contained in:
    parent e6159ad730
    commit 4b9d8da5a4

README.md (13 changed lines)
@@ -84,10 +84,10 @@ Choose your path:

### Day-N Support for Fine-Tuning Cutting-Edge Models

| Support Date | Model Name                                                 |
| ------------ | ---------------------------------------------------------- |
| Day 0        | Qwen2.5 / Qwen2-VL / QwQ / QvQ / InternLM3 / MiniCPM-o-2.6 |
| Day 1        | Llama 3 / GLM-4 / Mistral Small / PaliGemma2               |
| Support Date | Model Name                                                   |
| ------------ | ------------------------------------------------------------ |
| Day 0        | Qwen2.5 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6  |
| Day 1        | Llama 3 / GLM-4 / Mistral Small / PaliGemma2                 |

## Benchmark

@@ -106,6 +106,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

## Changelog

[25/03/12] We supported fine-tuning the **[Gemma-3](https://huggingface.co/blog/gemma3)** model.

[25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training.

[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage.

@@ -120,7 +122,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

[25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR.

[25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
[25/01/14] We supported fine-tuning the **[InternLM 3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.

[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.

@@ -229,6 +231,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai)       | 1.5B/7B/8B/14B/32B/70B/671B      | deepseek3           |
| [Falcon](https://huggingface.co/tiiuae)                           | 7B/11B/40B/180B                  | falcon              |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google)          | 2B/7B/9B/27B                     | gemma               |
| [Gemma 3](https://huggingface.co/google)                          | 1B/4B/12B/27B                    | gemma3              |
| [GLM-4](https://huggingface.co/THUDM)                             | 9B                               | glm4                |
| [GPT-2](https://huggingface.co/openai-community)                  | 0.1B/0.4B/0.8B/1.5B              | -                   |
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite)             | 1B/2B/3B/8B                      | granite3            |

README_zh.md (13 changed lines)
@@ -86,10 +86,10 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

### 最新模型的 Day-N 微调适配

| 适配时间      | 模型名称                                                    |
| ------------ | ---------------------------------------------------------- |
| Day 0        | Qwen2.5 / Qwen2-VL / QwQ / QvQ / InternLM3 / MiniCPM-o-2.6 |
| Day 1        | Llama 3 / GLM-4 / Mistral Small / PaliGemma2               |
| 适配时间      | 模型名称                                                       |
| ------------ | ------------------------------------------------------------ |
| Day 0        | Qwen2.5 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6  |
| Day 1        | Llama 3 / GLM-4 / Mistral Small / PaliGemma2                 |

## 性能指标

@@ -108,6 +108,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

## 更新日志

[25/03/12] 我们支持了 **[Gemma-3](https://huggingface.co/blog/gemma3)** 模型的微调。

[25/02/24] 我们宣布开源 **[EasyR1](https://github.com/hiyouga/EasyR1)**,一个高效可扩展的多模态强化学习框架,支持高效的 GRPO 训练。

[25/02/11] 我们支持了在导出模型时保存 **[Ollama](https://github.com/ollama/ollama)** 配置文件。详细用法请参照 [examples](examples/README_zh.md)。

@@ -122,7 +124,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

[25/01/14] 我们支持了 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 和 **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** 模型的微调。 感谢 [@BUAADreamer](https://github.com/BUAADreamer) 的 PR.

[25/01/14] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。
[25/01/14] 我们支持了 **[InternLM 3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。

[25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。

@@ -231,6 +233,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai)       | 1.5B/7B/8B/14B/32B/70B/671B      | deepseek3           |
| [Falcon](https://huggingface.co/tiiuae)                           | 7B/11B/40B/180B                  | falcon              |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google)          | 2B/7B/9B/27B                     | gemma               |
| [Gemma 3](https://huggingface.co/google)                          | 1B/4B/12B/27B                    | gemma3              |
| [GLM-4](https://huggingface.co/THUDM)                             | 9B                               | glm4                |
| [GPT-2](https://huggingface.co/openai-community)                  | 0.1B/0.4B/0.8B/1.5B              | -                   |
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite)             | 1B/2B/3B/8B                      | granite3            |
pyproject.toml (ruff lint configuration)

@@ -27,6 +27,7 @@ indent-width = 4
ignore = [
    "C408", # collection
    "C901", # complex
    "E501", # line too long
    "E731", # lambda function
    "E741", # ambiguous var name
    "D100", # no doc public module
src/llamafactory/data/mm_plugin.py

@@ -1,3 +1,20 @@
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
import re
@@ -5,7 +22,7 @@ from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Literal, Optional, TypedDict, Union

import numpy as np
import torch
@@ -56,24 +73,63 @@ if TYPE_CHECKING:
    VideoInput = str
    AudioInput = Union[str, NDArray]

    class MMProcessor(ProcessorMixin):
        patch_size: int
        image_seq_length: int
        num_additional_image_tokens: int
        vision_feature_select_strategy: Literal["default", "full"]

        def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
            pass


def _get_paligemma_token_type_ids(
    imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
    imglens: Sequence[int], seqlens: Sequence[int], processor: "MMProcessor"
) -> list[list[int]]:
    r"""Get paligemma token type ids for computing loss.

    It is slightly different with the original token type ids where the prompt part is 0.

    Returns:
        batch_token_type_ids: shape (batch_size, sequence_length)
        batch_token_type_ids: shape (batch_size, seq_length)

    """
    batch_token_type_ids = []
    for imglen, seqlen in zip(imglens, seqlens):
        image_seqlen = imglen * getattr(processor, "image_seqlen")
        image_seqlen = imglen * processor.image_seq_length
        batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))

    return batch_token_type_ids


def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"):
    r"""Get gemma3 token type ids for computing loss.

    Returns:
        batch_token_type_ids: shape (batch_size, seq_length)

    """
    image_token_id: int = getattr(processor, "image_token_id")
    batch_token_type_ids = []
    for token_ids in batch_ids:
        token_ids = np.array(token_ids)
        token_type_ids = np.zeros_like(token_ids)
        token_type_ids[token_ids == image_token_id] = 1
        batch_token_type_ids.append(token_type_ids.tolist())

    return batch_token_type_ids


def _make_batched_images(images: Sequence["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
    r"""Make nested list of images."""
    batch_images = []
    for imglen in imglens:
        batch_images.append(images[:imglen])
        images = images[imglen:]

    return batch_images


@dataclass
class MMPluginMixin:
    image_token: Optional[str]
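For intuition, here is a minimal, self-contained sketch of what the new `_get_gemma3_token_type_ids` computes; the image token id below is an illustrative stand-in, not the real id provided by the Gemma 3 processor.

```python
import numpy as np

IMAGE_TOKEN_ID = 9  # illustrative stand-in; the real id comes from the Gemma 3 processor

def get_gemma3_token_type_ids(batch_ids: list[list[int]]) -> list[list[int]]:
    """Mark image-token positions with 1 and all other positions with 0."""
    batch_token_type_ids = []
    for token_ids in batch_ids:
        token_ids = np.array(token_ids)
        token_type_ids = np.zeros_like(token_ids)
        token_type_ids[token_ids == IMAGE_TOKEN_ID] = 1
        batch_token_type_ids.append(token_type_ids.tolist())
    return batch_token_type_ids

print(get_gemma3_token_type_ids([[2, 9, 9, 17, 35, 1]]))  # [[0, 1, 1, 0, 0, 0]]
```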
@@ -83,7 +139,7 @@ class MMPluginMixin:

    def _validate_input(
        self,
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
@@ -204,7 +260,8 @@ class MMPluginMixin:
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: "ProcessorMixin",
        processor: "MMProcessor",
        imglens: Optional[list[int]] = None,
    ) -> dict[str, "torch.Tensor"]:
        r"""Process visual inputs.

@@ -214,23 +271,34 @@ class MMPluginMixin:
        Returns: (qwen2-vl)
            pixel_values: tensor with shape (num_patches, patch_dim)
            image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
                            where num_patches == torch.prod(image_grid_thw)

        Returns: (mllama)
            pixel_values: tensor with shape
                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
                          For example, (2, 1, 4, 3, 560, 560).
            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).

        It holds num_patches == torch.prod(image_grid_thw)
        """
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        video_processor: BaseImageProcessor = getattr(processor, "video_processor", image_processor)
        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
        mm_inputs = {}

        if len(images) != 0:
            image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
            if imglens is not None:
                images = _make_batched_images(images, imglens)

            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            video_processor: BaseImageProcessor = getattr(
                processor, "video_processor", getattr(processor, "image_processor", None)
            )
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
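A quick sketch of the nesting that `_make_batched_images` performs when `imglens` is given, so that processors expecting `List[List[ImageInput]]` (e.g. mllama) receive one inner list per sample; the string stand-ins below replace real PIL images.

```python
def make_batched_images(images: list, imglens: list[int]) -> list[list]:
    """Nest a flat image list by per-sample counts."""
    batch_images = []
    for imglen in imglens:
        batch_images.append(images[:imglen])
        images = images[imglen:]
    return batch_images

print(make_batched_images(["img0", "img1", "img2"], [2, 1]))  # [['img0', 'img1'], ['img2']]
```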
@@ -244,6 +312,7 @@ class MMPluginMixin:
                mm_inputs.update(video_processor(videos, return_tensors="pt"))

        if len(audios) != 0:
            feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
@@ -270,9 +339,9 @@ class BasePlugin(MMPluginMixin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Pre-processes input messages before tokenization for VLMs."""
        r"""Pre-process input messages before tokenization for VLMs."""
        self._validate_input(processor, images, videos, audios)
        return messages

@@ -284,9 +353,9 @@ class BasePlugin(MMPluginMixin):
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> tuple[list[int], Optional[list[int]]]:
        r"""Pre-processes token ids after tokenization for VLMs."""
        r"""Pre-process token ids after tokenization for VLMs."""
        self._validate_input(processor, images, videos, audios)
        return input_ids, labels

@@ -299,7 +368,7 @@ class BasePlugin(MMPluginMixin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Build batched multimodal inputs for VLMs.

@@ -315,11 +384,11 @@ class BasePlugin(MMPluginMixin):

        """
        self._validate_input(processor, images, videos, audios)
        return {}
        return self._get_mm_inputs(images, videos, audios, processor)


@dataclass
class LlavaPlugin(BasePlugin):
class Gemma3Plugin(BasePlugin):
    @override
    def process_messages(
        self,
@@ -327,19 +396,21 @@ class LlavaPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
        messages = deepcopy(messages)
        boi_token: str = getattr(processor, "boi_token")
        full_image_sequence: str = getattr(processor, "full_image_sequence")
        image_str = full_image_sequence if self.expand_mm_tokens else boi_token
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                num_image_tokens += 1

            message["content"] = content.replace("{{image}}", self.image_token)
            message["content"] = content.replace("{{image}}", image_str)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
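To see the rewrite above end to end, here is a self-contained sketch of the Gemma3Plugin message pass; the placeholder and the boi/full-image-sequence strings are shortened stand-ins for what the real Gemma 3 processor exposes.

```python
from copy import deepcopy

IMAGE_PLACEHOLDER = "<image>"  # dataset-side placeholder
BOI_TOKEN = "<start_of_image>"  # stand-in for processor.boi_token
FULL_IMAGE_SEQUENCE = "<start_of_image>" + "<image_soft_token>" * 4 + "<end_of_image>"  # shortened stand-in

def expand_gemma3_images(messages, num_images, expand_mm_tokens=True):
    image_str = FULL_IMAGE_SEQUENCE if expand_mm_tokens else BOI_TOKEN
    num_image_tokens = 0
    messages = deepcopy(messages)
    for message in messages:
        content = message["content"]
        while IMAGE_PLACEHOLDER in content:
            content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
            num_image_tokens += 1
        message["content"] = content.replace("{{image}}", image_str)
    if num_images != num_image_tokens:
        raise ValueError("The number of images does not match the number of placeholders.")
    return messages

print(expand_gemma3_images([{"role": "user", "content": "<image>Describe this picture."}], num_images=1))
```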
@@ -356,10 +427,53 @@ class LlavaPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("num_crops", None)
        mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor)
        return mm_inputs


@dataclass
class LlavaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0]))
                image_seqlen = (height // processor.patch_size) * (
                    width // processor.patch_size
                ) + processor.num_additional_image_tokens
                if processor.vision_feature_select_strategy == "default":
                    image_seqlen -= 1
        else:
            image_seqlen = 1

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                num_image_tokens += 1

            message["content"] = content.replace("{{image}}", self.image_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages


@dataclass
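A worked example of the patch-grid arithmetic used above, with hypothetical LLaVA-style numbers (a 336x336 input, patch size 14, one extra CLS-like token); the real values come from the processor config.

```python
# Hypothetical values for illustration only.
height, width = 336, 336
patch_size = 14
num_additional_image_tokens = 1
vision_feature_select_strategy = "default"

image_seqlen = (height // patch_size) * (width // patch_size) + num_additional_image_tokens
if vision_feature_select_strategy == "default":
    image_seqlen -= 1  # the extra feature is dropped under the "default" strategy

print(image_seqlen)  # 24 * 24 + 1 - 1 = 576 image tokens per image
```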
@@ -371,15 +485,16 @@ class LlavaNextPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        if "pixel_values" in mm_inputs:
            image_sizes = iter(mm_inputs["image_sizes"].tolist())
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_sizes = iter(mm_inputs["image_sizes"].tolist())
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        for message in messages:
            content = message["content"]
@@ -387,7 +502,7 @@ class LlavaNextPlugin(BasePlugin):
                if self.expand_mm_tokens:
                    orig_height, orig_width = next(image_sizes)
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                    if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                    if processor.vision_feature_select_strategy == "default":
                        image_seqlen -= 1
                else:
                    image_seqlen = 1
@@ -402,21 +517,6 @@ class LlavaNextPlugin(BasePlugin):

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)


@dataclass
class LlavaNextVideoPlugin(BasePlugin):
@@ -427,48 +527,50 @@ class LlavaNextVideoPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        if "pixel_values" in mm_inputs:
            image_sizes = iter(mm_inputs["image_sizes"].tolist())
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
            for message in messages:
                content = message["content"]
                while IMAGE_PLACEHOLDER in content:
                    if self.expand_mm_tokens:
                        orig_height, orig_width = next(image_sizes)
                        image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                        if getattr(processor, "vision_feature_select_strategy", "default") == "default":
                            image_seqlen -= 1
                    else:
                        image_seqlen = 1
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_sizes = iter(mm_inputs["image_sizes"].tolist())
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                    num_image_tokens += 1
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    orig_height, orig_width = next(image_sizes)
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                    if processor.vision_feature_select_strategy == "default":
                        image_seqlen -= 1
                else:
                    image_seqlen = 1

                message["content"] = content.replace("{{image}}", self.image_token)
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                num_image_tokens += 1

        if "pixel_values_videos" in mm_inputs:
            if self.expand_mm_tokens:
                pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                height, width = get_image_size(pixel_values_video[0])
                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
            message["content"] = content.replace("{{image}}", self.image_token)

        if self.expand_mm_tokens:
            if "pixel_values_videos" in mm_inputs:
                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                height, width = get_image_size(one_video[0])
                num_frames = one_video.shape[0]  # frame dim is always after batch dim
                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
                video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
            else:
                video_seqlen = 1
        else:
            video_seqlen = 1

            for message in messages:
                content = message["content"]
                while VIDEO_PLACEHOLDER in content:
                    num_video_tokens += 1
                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
        for message in messages:
            content = message["content"]
            while VIDEO_PLACEHOLDER in content:
                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
                num_video_tokens += 1

                message["content"] = content.replace("{{video}}", self.video_token)
            message["content"] = content.replace("{{video}}", self.video_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
@@ -478,21 +580,6 @@ class LlavaNextVideoPlugin(BasePlugin):

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)


@dataclass
class MiniCPMVPlugin(BasePlugin):
@@ -503,7 +590,7 @@ class MiniCPMVPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
@@ -602,7 +689,7 @@ class MiniCPMVPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: "ProcessorMixin",
        processor: "MMProcessor",
        **kwargs,
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
@@ -677,7 +764,7 @@ class MiniCPMVPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        # image bound
@@ -745,7 +832,7 @@ class MllamaPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
@@ -760,43 +847,6 @@ class MllamaPlugin(BasePlugin):

        return messages

    @override
    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: "ProcessorMixin",
        imglens: list[int],
    ) -> dict[str, "torch.Tensor"]:
        r"""Process visual inputs for mllama because its image processor only accepts List[List[ImageInput]].

        Returns:
            pixel_values: tensor with shape
                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
                          For example, (2, 1, 4, 3, 560, 560).
            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).

        """
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        mm_inputs = {}
        if len(images) > 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
            batch_images = []
            for image_length in imglens:
                batch_images.append(images[:image_length])
                images = images[image_length:]

            mm_inputs.update(image_processor(batch_images, return_tensors="pt"))

        return mm_inputs

    @override
    def get_mm_inputs(
        self,
@@ -807,14 +857,14 @@ class MllamaPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
        if mm_inputs:
            num_tiles = mm_inputs.pop("num_tiles")
            image_token_id = getattr(processor, "image_token_id")
            max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
            image_token_id: int = getattr(processor, "image_token_id")
            max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles")
            cross_attention_token_mask = [
                get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
            ]
@@ -839,7 +889,7 @@ class PaliGemmaPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
@@ -847,10 +897,10 @@ class PaliGemmaPlugin(BasePlugin):
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                content = content.replace(IMAGE_PLACEHOLDER, "", 1)
                num_image_tokens += 1

            message["content"] = content.replace("{{image}}", "")
            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
@@ -866,15 +916,15 @@ class PaliGemmaPlugin(BasePlugin):
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> tuple[list[int], Optional[list[int]]]:
        self._validate_input(processor, images, videos, audios)
        num_images = len(images)
        image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
        image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0  # skip mm token
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        input_ids = [image_token_id] * image_seqlen + input_ids
        input_ids = [image_token_id] * num_images * image_seqlen + input_ids
        if labels is not None:
            labels = [IGNORE_INDEX] * image_seqlen + labels
            labels = [IGNORE_INDEX] * num_images * image_seqlen + labels

        return input_ids, labels
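A toy run of the prepend step above, with made-up ids and a short image sequence length, to show how `input_ids` and `labels` are padded so that image positions never contribute to the loss.

```python
IGNORE_INDEX = -100
image_token_id = 7        # illustrative; the real id comes from the tokenizer
image_seq_length = 4      # illustrative; PaliGemma's real value is much larger
num_images = 1

input_ids = [101, 102, 103]
labels = [101, 102, 103]

input_ids = [image_token_id] * num_images * image_seq_length + input_ids
labels = [IGNORE_INDEX] * num_images * image_seq_length + labels  # image positions are masked out of the loss

print(input_ids)  # [7, 7, 7, 7, 101, 102, 103]
print(labels)     # [-100, -100, -100, -100, 101, 102, 103]
```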
@@ -888,7 +938,7 @@ class PaliGemmaPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        seqlens = [len(input_ids) for input_ids in batch_ids]
@@ -906,33 +956,31 @@ class PixtralPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        patch_size = getattr(processor, "patch_size")
        image_token = getattr(processor, "image_token")
        image_break_token = getattr(processor, "image_break_token")
        image_end_token = getattr(processor, "image_end_token")

        num_image_tokens = 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        if "pixel_values" in mm_inputs:
            image_sizes = iter(mm_inputs["image_sizes"].tolist())
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_sizes = iter(mm_inputs["image_sizes"].tolist())
                image_break_token: str = getattr(processor, "image_break_token")
                image_end_token: str = getattr(processor, "image_end_token")

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    height, width = next(image_sizes)
                    num_height_tokens = height // patch_size
                    num_width_tokens = width // patch_size
                    replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
                    num_height_tokens = height // processor.patch_size
                    num_width_tokens = width // processor.patch_size
                    replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
                    replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
                    replace_tokens[-1] = image_end_token
                    replace_str = "".join(replace_tokens)
                else:
                    replace_str = image_token
                    replace_str = self.image_token

                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
                num_image_tokens += 1
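The row/flatten/terminate pattern above is easier to see with tiny hypothetical values (a 2x3 patch grid and short stand-in tokens):

```python
# Hypothetical stand-ins; the real tokens and patch size come from the Pixtral processor.
image_token, image_break_token, image_end_token = "[IMG]", "[IMG_BREAK]", "[IMG_END]"
num_height_tokens, num_width_tokens = 2, 3

replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
replace_tokens[-1] = image_end_token  # the final break marker becomes the end-of-image marker

print("".join(replace_tokens))
# [IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END]
```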
@@ -954,7 +1002,7 @@ class PixtralPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -971,17 +1019,18 @@ class Qwen2AudioPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        bos_token: str = getattr(processor, "audio_bos_token")
        eos_token: str = getattr(processor, "audio_eos_token")
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs([], [], audios, processor)
        if "feature_attention_mask" in mm_inputs:
            audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()

        num_audio_tokens = 0
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs([], [], audios, processor)
            if "feature_attention_mask" in mm_inputs:
                audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()

        for message in messages:
            content = message["content"]
            while AUDIO_PLACEHOLDER in content:
@@ -1014,7 +1063,7 @@ class Qwen2AudioPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)
@@ -1072,7 +1121,7 @@ class Qwen2VLPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: "ProcessorMixin",
        processor: "MMProcessor",
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        mm_inputs = {}
@@ -1104,7 +1153,7 @@ class Qwen2VLPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
@@ -1162,14 +1211,15 @@ class Qwen2VLPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        fps_per_video = mm_inputs.pop("fps_per_video", [])
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
        if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
            mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
            mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in fps_per_video]

        return mm_inputs
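A quick arithmetic check of the `second_per_grid_ts` value computed above, assuming the fallback temporal patch size of 2 and two videos sampled at different frame rates:

```python
temporal_patch_size = 2        # falls back to 2 when the image processor does not define it
fps_per_video = [2.0, 4.0]     # frames per second actually sampled for each video

second_per_grid_ts = [temporal_patch_size / fps for fps in fps_per_video]
print(second_per_grid_ts)  # [1.0, 0.5] seconds covered by each temporal grid step
```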
@ -1183,45 +1233,45 @@ class VideoLlavaPlugin(BasePlugin):
 | 
			
		||||
        images: Sequence["ImageInput"],
 | 
			
		||||
        videos: Sequence["VideoInput"],
 | 
			
		||||
        audios: Sequence["AudioInput"],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
        processor: Optional["MMProcessor"],
 | 
			
		||||
    ) -> list[dict[str, str]]:
 | 
			
		||||
        self._validate_input(processor, images, videos, audios)
 | 
			
		||||
        num_image_tokens, num_video_tokens = 0, 0
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
        num_frames = 0
 | 
			
		||||
        has_images = "pixel_values_images" in mm_inputs
 | 
			
		||||
        has_videos = "pixel_values_videos" in mm_inputs
 | 
			
		||||
        if has_images or has_videos:
 | 
			
		||||
            if self.expand_mm_tokens:
 | 
			
		||||
                if has_images:
 | 
			
		||||
                    height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
 | 
			
		||||
                    num_frames = 1
 | 
			
		||||
        if self.expand_mm_tokens:
 | 
			
		||||
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
            if "pixel_values_images" in mm_inputs:
 | 
			
		||||
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0]))
 | 
			
		||||
                num_frames = 1
 | 
			
		||||
 | 
			
		||||
                if has_videos:
 | 
			
		||||
                    pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
 | 
			
		||||
                    height, width = get_image_size(pixel_values_video[0])
 | 
			
		||||
                    num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
 | 
			
		||||
            if "pixel_values_videos" in mm_inputs:
 | 
			
		||||
                one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0])
 | 
			
		||||
                height, width = get_image_size(one_video[0])
 | 
			
		||||
                num_frames = one_video.shape[0]  # frame dim is always after batch dim
 | 
			
		||||
 | 
			
		||||
                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
 | 
			
		||||
            if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs:
 | 
			
		||||
                image_seqlen = (height // processor.patch_size) * (
 | 
			
		||||
                    width // processor.patch_size
 | 
			
		||||
                ) + processor.num_additional_image_tokens
 | 
			
		||||
                video_seqlen = image_seqlen * num_frames
 | 
			
		||||
                if getattr(processor, "vision_feature_select_strategy", "default") == "default":
 | 
			
		||||
                if processor.vision_feature_select_strategy == "default":
 | 
			
		||||
                    image_seqlen -= 1
 | 
			
		||||
            else:
 | 
			
		||||
                image_seqlen, video_seqlen = 1, 1
 | 
			
		||||
        else:
 | 
			
		||||
            image_seqlen, video_seqlen = 1, 1
 | 
			
		||||
 | 
			
		||||
            for message in messages:
 | 
			
		||||
                content = message["content"]
 | 
			
		||||
                while IMAGE_PLACEHOLDER in content:
 | 
			
		||||
                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
 | 
			
		||||
                    num_image_tokens += 1
 | 
			
		||||
        for message in messages:
 | 
			
		||||
            content = message["content"]
 | 
			
		||||
            while IMAGE_PLACEHOLDER in content:
 | 
			
		||||
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
 | 
			
		||||
                num_image_tokens += 1
 | 
			
		||||
 | 
			
		||||
                while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
 | 
			
		||||
                    num_video_tokens += 1
 | 
			
		||||
            while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
 | 
			
		||||
                num_video_tokens += 1
 | 
			
		||||
 | 
			
		||||
                content = content.replace("{{image}}", self.image_token)
 | 
			
		||||
                message["content"] = content.replace("{{video}}", self.video_token)
 | 
			
		||||
            content = content.replace("{{image}}", self.image_token)
 | 
			
		||||
            message["content"] = content.replace("{{video}}", self.video_token)
 | 
			
		||||
 | 
			
		||||
        if len(images) != num_image_tokens:
 | 
			
		||||
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 | 
			
		||||
@@ -1231,24 +1281,10 @@ class VideoLlavaPlugin(BasePlugin):

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
        processor: Optional["ProcessorMixin"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)


PLUGINS = {
    "base": BasePlugin,
    "gemma3": Gemma3Plugin,
    "llava": LlavaPlugin,
    "llava_next": LlavaNextPlugin,
    "llava_next_video": LlavaNextVideoPlugin,

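# --- Illustrative aside (not part of this diff): a sketch of how a plugin name such as
# "gemma3" is presumably resolved from the PLUGINS registry by get_mm_plugin (used in the
# template definitions below). The signature and constructor kwargs here are assumptions,
# not the real helper.
def get_mm_plugin_sketch(name: str, image_token: str = None, video_token: str = None):
    plugin_class = PLUGINS.get(name)
    if plugin_class is None:
        raise ValueError(f"Multimodal plugin `{name}` not found.")

    # each plugin instance carries the special tokens it expands placeholders into
    return plugin_class(image_token=image_token, video_token=video_token)

# e.g. get_mm_plugin_sketch("gemma3", image_token="<image_soft_token>")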
@@ -310,6 +310,8 @@ class Template:

@dataclass
class Llama2Template(Template):
    r"""A template that fuses the system message into the first user message."""

    @override
    def _encode(
        self,
@@ -815,10 +817,29 @@ register_template(
    name="gemma",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<end_of_turn>"],
    template_class=Llama2Template,
)


# copied from gemma template
register_template(
    name="gemma3",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<end_of_turn>"],
    mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
    template_class=Llama2Template,
)

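# --- Illustrative aside (not part of this diff): how a single-turn gemma3 prompt plausibly
# renders once Llama2Template fuses the system message into the first user turn. The BOS
# token string and exact whitespace are assumptions; the real Template._encode applies the
# formatters above internally.
bos_token = "<bos>"                          # assumed Gemma BOS token
system = "You are a helpful assistant."      # example system message
user = "Describe this image. <image_soft_token>"

prompt = (
    bos_token                                  # format_prefix: [{"bos_token"}]
    + "<start_of_turn>user\n"
    + system + "\n\n"                          # format_system, fused into the first user turn
    + user
    + "<end_of_turn>\n<start_of_turn>model\n"  # the model's reply starts here
)
# the assistant reply is appended as "{{content}}<end_of_turn>\n";
# "<end_of_turn>" also serves as a stop word at inference time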
@@ -1255,6 +1276,7 @@ register_template(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<end_of_turn>"],
    mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)

@@ -650,11 +650,51 @@ register_model_group(
            DownloadSource.DEFAULT: "google/gemma-2-27b-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
        },
        "Gemma-3-1B": {
            DownloadSource.DEFAULT: "google/gemma-3-1b-pt",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-pt",
        },
        "Gemma-3-1B-Instruct": {
            DownloadSource.DEFAULT: "google/gemma-3-1b-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
        },
    },
    template="gemma",
)


register_model_group(
    models={
        "Gemma-3-4B": {
            DownloadSource.DEFAULT: "google/gemma-3-4b-pt",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-pt",
        },
        "Gemma-3-12B": {
            DownloadSource.DEFAULT: "google/gemma-3-12b-pt",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-pt",
        },
        "Gemma-3-27B": {
            DownloadSource.DEFAULT: "google/gemma-3-27b-pt",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-pt",
        },
        "Gemma-3-4B-Instruct": {
            DownloadSource.DEFAULT: "google/gemma-3-4b-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-it",
        },
        "Gemma-3-12B-Instruct": {
            DownloadSource.DEFAULT: "google/gemma-3-12b-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-it",
        },
        "Gemma-3-27B-Instruct": {
            DownloadSource.DEFAULT: "google/gemma-3-27b-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-it",
        },
    },
    template="gemma3",
    multimodal=True,
)

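# --- Illustrative aside (not part of this diff): what the registrations above make available.
# This assumes register_model_group records each entry in SUPPORTED_MODELS
# (name -> {DownloadSource: repo id}) and DEFAULT_TEMPLATE (name -> template name);
# the helper below is hypothetical.
def resolve_model_sketch(name: str, source: "DownloadSource") -> tuple[str, str]:
    repo_id = SUPPORTED_MODELS[name][source]        # e.g. "google/gemma-3-4b-it"
    template = DEFAULT_TEMPLATE.get(name, "default")
    return repo_id, template

# resolve_model_sketch("Gemma-3-4B-Instruct", DownloadSource.DEFAULT)
# -> ("google/gemma-3-4b-it", "gemma3")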
register_model_group(
    models={
        "GLM-4-9B": {

@@ -19,6 +19,7 @@ import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoModelForSeq2SeqLM,
    AutoModelForVision2Seq,
    AutoProcessor,
@@ -72,7 +73,6 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
    Note: this also modifies model_args in place.
    """
    init_kwargs = _get_init_kwargs(model_args)
    config = load_config(model_args)
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
@@ -94,7 +94,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
    patch_tokenizer(tokenizer, model_args)
    try:
        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
        patch_processor(processor, config, tokenizer, model_args)
        patch_processor(processor, tokenizer, model_args)
    except Exception as e:
        logger.debug(f"Processor was not found: {e}.")
        processor = None
@@ -141,9 +141,11 @@ def load_model(
        if model_args.mixture_of_depths == "load":
            model = load_mod_pretrained_model(**init_kwargs)
        else:
            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
                load_class = AutoModelForVision2Seq
            elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():
            elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
                load_class = AutoModelForImageTextToText
            elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                load_class = AutoModelForSeq2SeqLM
            else:
                load_class = AutoModelForCausalLM

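# --- Illustrative aside (not part of this diff): the dispatch above as a standalone helper,
# mirroring the new branch order (vision-to-seq, then image-text-to-text, then seq2seq,
# then causal LM). A sketch rather than a copy of load_model.
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoModelForSeq2SeqLM,
    AutoModelForVision2Seq,
)

def pick_auto_class_sketch(config):
    """Return the Auto* class whose model mapping contains this config type."""
    if type(config) in AutoModelForVision2Seq._model_mapping.keys():         # image-text (e.g. LLaVA)
        return AutoModelForVision2Seq
    elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text (e.g. Gemma 3)
        return AutoModelForImageTextToText
    elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():        # audio-text / encoder-decoder
        return AutoModelForSeq2SeqLM
    else:
        return AutoModelForCausalLM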
@@ -1,4 +1,4 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
@@ -28,7 +28,7 @@ from ...extras import logging


if TYPE_CHECKING:
    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel

    from ...hparams import FinetuningArguments, ModelArguments

@@ -62,6 +62,16 @@ def _register_composite_model(
    language_model_keys: Optional[list[str]] = None,
    lora_conflict_keys: Optional[list[str]] = None,
):
    r"""Register a new composite model.

    Args:
        model_type: model type of the composite model.
        projector_key: key of the multimodal projector module, defaults to "multi_modal_projector".
        vision_model_keys: keys of the vision tower modules, defaults to ["vision_tower"].
        language_model_keys: keys of the language model modules, defaults to ["language_model"].
        lora_conflict_keys: module keys that conflict with LoRA targets, defaults to None.

    """
    COMPOSITE_MODELS[model_type] = CompositeModel(
        model_type=model_type,
        projector_key=projector_key or "multi_modal_projector",
@@ -169,39 +179,10 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
    return forbidden_modules


def get_image_seqlen(config: "PretrainedConfig") -> int:
    r"""Compute the number of special tokens per image."""
    model_type = getattr(config, "model_type", None)
    if model_type == "llava":
        image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
        if getattr(config, "vision_feature_select_strategy", "default") == "full":  # add [CLS] token
            image_seqlen += 1
    elif model_type == "paligemma":
        image_seqlen = config.vision_config.num_image_tokens
    else:
        image_seqlen = -1

    return image_seqlen


def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
    r"""Compute the patch size of the vit."""
    patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
    return patch_size


def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
    r"""Get the vision_feature_select_strategy."""
    vision_feature_select_strategy = getattr(
        config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
    )
    return vision_feature_select_strategy


def patch_target_modules(
    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> list[str]:
    r"""Freezes vision tower for VLM LoRA tuning."""
    r"""Freeze vision tower for VLM LoRA tuning."""
    model_type = getattr(model.config, "model_type", None)
    if model_type in COMPOSITE_MODELS:
        forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
@@ -218,6 +199,11 @@ def patch_target_modules(
        return target_modules


_register_composite_model(
    model_type="gemma3",
)


_register_composite_model(
    model_type="llava",
)

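# --- Illustrative aside (not part of this diff): the defaults-only "gemma3" registration above
# relies on the fallbacks documented in _register_composite_model. Under those assumed defaults,
# the stored entry is roughly equivalent to:
gemma3_entry_sketch = dict(
    model_type="gemma3",
    projector_key="multi_modal_projector",   # default projector key
    vision_model_keys=["vision_tower"],      # assumed default, per the docstring above
    language_model_keys=["language_model"],  # assumed default, per the docstring above
    lora_conflict_keys=None,                 # per the docstring above
)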
@@ -33,13 +33,7 @@ from .model_utils.packing import configure_packing
from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model
from .model_utils.visual import (
    autocast_projector_dtype,
    configure_visual_model,
    get_image_seqlen,
    get_patch_size,
    get_vision_feature_select_strategy,
)
from .model_utils.visual import autocast_projector_dtype, configure_visual_model


if TYPE_CHECKING:
@@ -72,21 +66,16 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument

def patch_processor(
    processor: "ProcessorMixin",
    config: "PretrainedConfig",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
) -> None:
    setattr(processor, "tokenizer", tokenizer)
    if getattr(config, "vision_config", None) is not None:  # visual models
        setattr(processor, "image_seqlen", get_image_seqlen(config))
        setattr(processor, "patch_size", get_patch_size(config, processor))
        setattr(processor, "image_max_pixels", model_args.image_max_pixels)
        setattr(processor, "image_min_pixels", model_args.image_min_pixels)
        setattr(processor, "video_max_pixels", model_args.video_max_pixels)
        setattr(processor, "video_min_pixels", model_args.video_min_pixels)
        setattr(processor, "video_fps", model_args.video_fps)
        setattr(processor, "video_maxlen", model_args.video_maxlen)
        setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
    setattr(processor, "image_max_pixels", model_args.image_max_pixels)
    setattr(processor, "image_min_pixels", model_args.image_min_pixels)
    setattr(processor, "video_max_pixels", model_args.video_max_pixels)
    setattr(processor, "video_min_pixels", model_args.video_min_pixels)
    setattr(processor, "video_fps", model_args.video_fps)
    setattr(processor, "video_maxlen", model_args.video_maxlen)


def patch_config(
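# --- Illustrative aside (not part of this diff): the attributes attached by the new
# patch_processor are plain attributes on the processor, which the multimodal plugins read
# back later (e.g. when sampling video frames). All values below are hypothetical stand-ins,
# not real ModelArguments defaults.
class _ProcessorStub:                        # stand-in for a HF ProcessorMixin instance
    pass

class _ArgsStub:                             # stand-in for ModelArguments
    image_max_pixels, image_min_pixels = 768 * 768, 32 * 32
    video_max_pixels, video_min_pixels = 256 * 256, 16 * 16
    video_fps, video_maxlen = 2.0, 128

processor, model_args = _ProcessorStub(), _ArgsStub()
for key in ("image_max_pixels", "image_min_pixels", "video_max_pixels",
            "video_min_pixels", "video_fps", "video_maxlen"):
    setattr(processor, key, getattr(model_args, key))

assert processor.video_fps == 2.0            # later read by the plugins' _get_mm_inputs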