[model] support gemma3 (#7273)

Author: hoshi-hiyouga
Date: 2025-03-13 01:35:23 +08:00
Committed by: GitHub
Parent: 142fd7e755
Commit: 165d3ed084
9 changed files with 356 additions and 274 deletions


@@ -84,10 +84,10 @@ Choose your path:

### Day-N Support for Fine-Tuning Cutting-Edge Models

| Support Date | Model Name |
-| ------------ | ---------------------------------------------------------- |
+| ------------ | ------------------------------------------------------------ |
-| Day 0 | Qwen2.5 / Qwen2-VL / QwQ / QvQ / InternLM3 / MiniCPM-o-2.6 |
+| Day 0 | Qwen2.5 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 |
| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 |

## Benchmark

@@ -106,6 +106,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

## Changelog

+[25/03/12] We supported fine-tuning the **[Gemma-3](https://huggingface.co/blog/gemma3)** model.
+
[25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training.

[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage.

@@ -120,7 +122,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

[25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR.

-[25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
+[25/01/14] We supported fine-tuning the **[InternLM 3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.

[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.

@@ -229,6 +231,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/

| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
+| [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3 |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
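For context, a minimal usage sketch that is not part of this commit: the new checkpoints can be loaded through `AutoModelForImageTextToText`, the auto class this diff adds to the model loader further below. It assumes access to the gated `google/gemma-3-4b-it` checkpoint and a transformers release that already ships Gemma 3 support.

```python
# Hedged sketch, not library code: load a Gemma 3 checkpoint with the
# image-text-to-text auto class and run one gemma-style chat turn.
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "google/gemma-3-4b-it"  # one of the checkpoints registered later in this diff
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id)

# Prompt format matching the gemma3 template slots registered later in this diff.
prompt = "<start_of_turn>user\nWhy is the sky blue?<end_of_turn>\n<start_of_turn>model\n"
inputs = processor(text=prompt, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(output, skip_special_tokens=True)[0])
```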


@@ -86,10 +86,10 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

### Day-N Fine-Tuning Support for Cutting-Edge Models

| Support Date | Model Name |
-| ------------ | ---------------------------------------------------------- |
+| ------------ | ------------------------------------------------------------ |
-| Day 0 | Qwen2.5 / Qwen2-VL / QwQ / QvQ / InternLM3 / MiniCPM-o-2.6 |
+| Day 0 | Qwen2.5 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 |
| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 |

## Benchmark

@@ -108,6 +108,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

## Changelog

+[25/03/12] We supported fine-tuning the **[Gemma-3](https://huggingface.co/blog/gemma3)** model.
+
[25/02/24] We open-sourced **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient and scalable multimodal reinforcement learning framework that supports efficient GRPO training.

[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting models. See [examples](examples/README_zh.md) for usage.

@@ -122,7 +124,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

[25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thanks to [@BUAADreamer](https://github.com/BUAADreamer)'s PR.

-[25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thanks to [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
+[25/01/14] We supported fine-tuning the **[InternLM 3](https://huggingface.co/collections/internlm/)** models. Thanks to [@hhaAndroid](https://github.com/hhaAndroid)'s PR.

[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.

@@ -231,6 +233,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272

| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
+| [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3 |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |


@@ -27,6 +27,7 @@ indent-width = 4
ignore = [
    "C408", # collection
    "C901", # complex
+   "E501", # line too long
    "E731", # lambda function
    "E741", # ambiguous var name
    "D100", # no doc public module


@@ -1,3 +1,20 @@
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's Transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import inspect
import math
import re

@@ -5,7 +22,7 @@ from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
-from typing import TYPE_CHECKING, Optional, TypedDict, Union
+from typing import TYPE_CHECKING, Literal, Optional, TypedDict, Union

import numpy as np
import torch
@@ -56,24 +73,63 @@ if TYPE_CHECKING:
    VideoInput = str
    AudioInput = Union[str, NDArray]

+    class MMProcessor(ProcessorMixin):
+        patch_size: int
+        image_seq_length: int
+        num_additional_image_tokens: int
+        vision_feature_select_strategy: Literal["default", "full"]
+
+        def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
+            pass
+

def _get_paligemma_token_type_ids(
-    imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
+    imglens: Sequence[int], seqlens: Sequence[int], processor: "MMProcessor"
) -> list[list[int]]:
    r"""Get paligemma token type ids for computing loss.

    It is slightly different with the original token type ids where the prompt part is 0.

    Returns:
-        batch_token_type_ids: shape (batch_size, sequence_length)
+        batch_token_type_ids: shape (batch_size, seq_length)
    """
    batch_token_type_ids = []
    for imglen, seqlen in zip(imglens, seqlens):
-        image_seqlen = imglen * getattr(processor, "image_seqlen")
+        image_seqlen = imglen * processor.image_seq_length
        batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))

    return batch_token_type_ids

+
+def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"):
+    r"""Get gemma3 token type ids for computing loss.
+
+    Returns:
+        batch_token_type_ids: shape (batch_size, seq_length)
+    """
+    image_token_id: int = getattr(processor, "image_token_id")
+    batch_token_type_ids = []
+    for token_ids in batch_ids:
+        token_ids = np.array(token_ids)
+        token_type_ids = np.zeros_like(token_ids)
+        token_type_ids[token_ids == image_token_id] = 1
+        batch_token_type_ids.append(token_type_ids.tolist())
+
+    return batch_token_type_ids
+
+
+def _make_batched_images(images: Sequence["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
+    r"""Make nested list of images."""
+    batch_images = []
+    for imglen in imglens:
+        batch_images.append(images[:imglen])
+        images = images[imglen:]
+
+    return batch_images
+

@dataclass
class MMPluginMixin:
    image_token: Optional[str]
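To make the two new helpers above concrete, here is a standalone re-implementation run on toy data; it is not the library code itself, and the image token id is a made-up value (the real one comes from the processor).

```python
import numpy as np

IMAGE_TOKEN_ID = 262144  # hypothetical id; the real value is read from the processor


def gemma3_token_type_ids(batch_ids: list[list[int]]) -> list[list[int]]:
    """Mark image-token positions with 1 and text positions with 0."""
    batch_token_type_ids = []
    for token_ids in batch_ids:
        token_ids = np.array(token_ids)
        token_type_ids = np.zeros_like(token_ids)
        token_type_ids[token_ids == IMAGE_TOKEN_ID] = 1
        batch_token_type_ids.append(token_type_ids.tolist())

    return batch_token_type_ids


def make_batched_images(images: list[str], imglens: list[int]) -> list[list[str]]:
    """Split a flat image list into per-sample sublists of the given lengths."""
    batch_images = []
    for imglen in imglens:
        batch_images.append(images[:imglen])
        images = images[imglen:]

    return batch_images


print(gemma3_token_type_ids([[2, 262144, 262144, 10, 11]]))       # [[0, 1, 1, 0, 0]]
print(make_batched_images(["a.png", "b.png", "c.png"], [2, 1]))   # [['a.png', 'b.png'], ['c.png']]
```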
@@ -83,7 +139,7 @@ class MMPluginMixin:
    def _validate_input(
        self,
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],

@@ -204,7 +260,8 @@ class MMPluginMixin:
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
-        processor: "ProcessorMixin",
+        processor: "MMProcessor",
+        imglens: Optional[list[int]] = None,
    ) -> dict[str, "torch.Tensor"]:
        r"""Process visual inputs.

@@ -214,23 +271,34 @@ class MMPluginMixin:
        Returns: (qwen2-vl)
            pixel_values: tensor with shape (num_patches, patch_dim)
            image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
-                            where num_patches == torch.prod(image_grid_thw)
+
+        Returns: (mllama)
+            pixel_values: tensor with shape
+                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
+                          For example, (2, 1, 4, 3, 560, 560).
+            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
+            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
+            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
+
+        It holds num_patches == torch.prod(image_grid_thw)
        """
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        video_processor: BaseImageProcessor = getattr(processor, "video_processor", image_processor)
+        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
        mm_inputs = {}
        if len(images) != 0:
-            image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
+            if imglens is not None:
+                images = _make_batched_images(images, imglens)
+
            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
-            video_processor: BaseImageProcessor = getattr(
-                processor, "video_processor", getattr(processor, "image_processor", None)
-            )
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),

@@ -244,6 +312,7 @@ class MMPluginMixin:
            mm_inputs.update(video_processor(videos, return_tensors="pt"))

        if len(audios) != 0:
-            feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
@@ -270,9 +339,9 @@ class BasePlugin(MMPluginMixin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
-        r"""Pre-processes input messages before tokenization for VLMs."""
+        r"""Pre-process input messages before tokenization for VLMs."""
        self._validate_input(processor, images, videos, audios)
        return messages

@@ -284,9 +353,9 @@ class BasePlugin(MMPluginMixin):
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> tuple[list[int], Optional[list[int]]]:
-        r"""Pre-processes token ids after tokenization for VLMs."""
+        r"""Pre-process token ids after tokenization for VLMs."""
        self._validate_input(processor, images, videos, audios)
        return input_ids, labels

@@ -299,7 +368,7 @@ class BasePlugin(MMPluginMixin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Build batched multimodal inputs for VLMs.

@@ -315,11 +384,11 @@ class BasePlugin(MMPluginMixin):
        """
        self._validate_input(processor, images, videos, audios)
-        return {}
+        return self._get_mm_inputs(images, videos, audios, processor)


@dataclass
-class LlavaPlugin(BasePlugin):
+class Gemma3Plugin(BasePlugin):
    @override
    def process_messages(
        self,
@@ -327,19 +396,21 @@ class LlavaPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
-        image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
        messages = deepcopy(messages)
+        boi_token: str = getattr(processor, "boi_token")
+        full_image_sequence: str = getattr(processor, "full_image_sequence")
+        image_str = full_image_sequence if self.expand_mm_tokens else boi_token
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                num_image_tokens += 1

-            message["content"] = content.replace("{{image}}", self.image_token)
+            message["content"] = content.replace("{{image}}", image_str)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
@@ -356,10 +427,53 @@ class LlavaPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
-        return self._get_mm_inputs(images, videos, audios, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        mm_inputs.pop("num_crops", None)
+        mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor)
+        return mm_inputs
+
+
+@dataclass
+class LlavaPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            if "pixel_values" in mm_inputs:
+                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0]))
+                image_seqlen = (height // processor.patch_size) * (
+                    width // processor.patch_size
+                ) + processor.num_additional_image_tokens
+                if processor.vision_feature_select_strategy == "default":
+                    image_seqlen -= 1
+        else:
+            image_seqlen = 1
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+                num_image_tokens += 1
+
+            message["content"] = content.replace("{{image}}", self.image_token)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
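A quick worked example of the image token count that the re-indented LlavaPlugin above computes, using typical LLaVA-1.5 values (336x336 input, patch size 14, one extra token, "default" selection strategy). The numbers are illustrative and not taken from the diff.

```python
# Worked example of the LLaVA image_seqlen computation (assumed values).
height, width = 336, 336
patch_size = 14
num_additional_image_tokens = 1
vision_feature_select_strategy = "default"

image_seqlen = (height // patch_size) * (width // patch_size) + num_additional_image_tokens
if vision_feature_select_strategy == "default":
    image_seqlen -= 1  # the extra feature is dropped under the default strategy

print(image_seqlen)  # 576 image placeholders per image
```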
@dataclass

@@ -371,15 +485,16 @@ class LlavaNextPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-        if "pixel_values" in mm_inputs:
-            image_sizes = iter(mm_inputs["image_sizes"].tolist())
-            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            if "pixel_values" in mm_inputs:
+                image_sizes = iter(mm_inputs["image_sizes"].tolist())
+                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        for message in messages:
            content = message["content"]

@@ -387,7 +502,7 @@ class LlavaNextPlugin(BasePlugin):
                if self.expand_mm_tokens:
                    orig_height, orig_width = next(image_sizes)
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                    if getattr(processor, "vision_feature_select_strategy", "default") == "default":
+                    if processor.vision_feature_select_strategy == "default":
                        image_seqlen -= 1
                else:
                    image_seqlen = 1

@@ -402,21 +517,6 @@ class LlavaNextPlugin(BasePlugin):

        return messages

-    @override
-    def get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
-        self._validate_input(processor, images, videos, audios)
-        return self._get_mm_inputs(images, videos, audios, processor)


@dataclass
class LlavaNextVideoPlugin(BasePlugin):
@@ -427,48 +527,50 @@ class LlavaNextVideoPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-        if "pixel_values" in mm_inputs:
-            image_sizes = iter(mm_inputs["image_sizes"].tolist())
-            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
-            for message in messages:
-                content = message["content"]
-                while IMAGE_PLACEHOLDER in content:
-                    if self.expand_mm_tokens:
-                        orig_height, orig_width = next(image_sizes)
-                        image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                        if getattr(processor, "vision_feature_select_strategy", "default") == "default":
-                            image_seqlen -= 1
-                    else:
-                        image_seqlen = 1
-
-                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
-                    num_image_tokens += 1
-
-                message["content"] = content.replace("{{image}}", self.image_token)
-
-        if "pixel_values_videos" in mm_inputs:
-            if self.expand_mm_tokens:
-                pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-                height, width = get_image_size(pixel_values_video[0])
-                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            if "pixel_values" in mm_inputs:
+                image_sizes = iter(mm_inputs["image_sizes"].tolist())
+                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if self.expand_mm_tokens:
+                    orig_height, orig_width = next(image_sizes)
+                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                    if processor.vision_feature_select_strategy == "default":
+                        image_seqlen -= 1
+                else:
+                    image_seqlen = 1
+
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+                num_image_tokens += 1
+
+            message["content"] = content.replace("{{image}}", self.image_token)
+
+        if self.expand_mm_tokens:
+            if "pixel_values_videos" in mm_inputs:
+                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                height, width = get_image_size(one_video[0])
+                num_frames = one_video.shape[0]  # frame dim is always after batch dim
                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
                video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
-            else:
-                video_seqlen = 1
+        else:
+            video_seqlen = 1

        for message in messages:
            content = message["content"]
            while VIDEO_PLACEHOLDER in content:
-                num_video_tokens += 1
-                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
+                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
+                num_video_tokens += 1

            message["content"] = content.replace("{{video}}", self.video_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-    @override
-    def get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
-        self._validate_input(processor, images, videos, audios)
-        return self._get_mm_inputs(images, videos, audios, processor)


@dataclass
class MiniCPMVPlugin(BasePlugin):

@@ -503,7 +590,7 @@ class MiniCPMVPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0

@@ -602,7 +689,7 @@ class MiniCPMVPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
-        processor: "ProcessorMixin",
+        processor: "MMProcessor",
        **kwargs,
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")

@@ -677,7 +764,7 @@ class MiniCPMVPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        # image bound
@@ -745,7 +832,7 @@ class MllamaPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0

@@ -760,43 +847,6 @@ class MllamaPlugin(BasePlugin):
        return messages

-    @override
-    def _get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: "ProcessorMixin",
-        imglens: list[int],
-    ) -> dict[str, "torch.Tensor"]:
-        r"""Process visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
-
-        Returns:
-            pixel_values: tensor with shape
-                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
-                          For example, (2, 1, 4, 3, 560, 560).
-            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
-            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
-            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
-        """
-        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
-        mm_inputs = {}
-        if len(images) > 0:
-            images = self._regularize_images(
-                images,
-                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
-                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
-            )
-            batch_images = []
-            for image_length in imglens:
-                batch_images.append(images[:image_length])
-                images = images[image_length:]
-
-            mm_inputs.update(image_processor(batch_images, return_tensors="pt"))
-
-        return mm_inputs

    @override
    def get_mm_inputs(
        self,

@@ -807,14 +857,14 @@ class MllamaPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
        if mm_inputs:
            num_tiles = mm_inputs.pop("num_tiles")
-            image_token_id = getattr(processor, "image_token_id")
-            max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
+            image_token_id: int = getattr(processor, "image_token_id")
+            max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles")
            cross_attention_token_mask = [
                get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
            ]
@@ -839,7 +889,7 @@ class PaliGemmaPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0

@@ -847,10 +897,10 @@ class PaliGemmaPlugin(BasePlugin):
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+                content = content.replace(IMAGE_PLACEHOLDER, "", 1)
                num_image_tokens += 1

-            message["content"] = content.replace("{{image}}", "")
+            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

@@ -866,15 +916,15 @@ class PaliGemmaPlugin(BasePlugin):
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> tuple[list[int], Optional[list[int]]]:
        self._validate_input(processor, images, videos, audios)
        num_images = len(images)
-        image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
+        image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0  # skip mm token
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
-        input_ids = [image_token_id] * image_seqlen + input_ids
+        input_ids = [image_token_id] * num_images * image_seqlen + input_ids
        if labels is not None:
-            labels = [IGNORE_INDEX] * image_seqlen + labels
+            labels = [IGNORE_INDEX] * num_images * image_seqlen + labels

        return input_ids, labels
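Toy illustration (not library code) of the PaliGemma prefix logic above: image tokens are prepended `num_images * image_seq_length` times and the corresponding label positions are masked out. The token id and sequence length below are made-up values.

```python
IGNORE_INDEX = -100
image_token_id = 257152   # hypothetical id of the image token
image_seq_length = 4      # hypothetical processor.image_seq_length
num_images = 2

input_ids = [10, 11, 12]
labels = [10, 11, 12]

input_ids = [image_token_id] * num_images * image_seq_length + input_ids
labels = [IGNORE_INDEX] * num_images * image_seq_length + labels

print(len(input_ids), input_ids[:3])  # 11 [257152, 257152, 257152]
print(labels[:3])                     # [-100, -100, -100]
```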
@@ -888,7 +938,7 @@ class PaliGemmaPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        seqlens = [len(input_ids) for input_ids in batch_ids]
@@ -906,33 +956,31 @@ class PixtralPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
-        patch_size = getattr(processor, "patch_size")
-        image_token = getattr(processor, "image_token")
-        image_break_token = getattr(processor, "image_break_token")
-        image_end_token = getattr(processor, "image_end_token")
        num_image_tokens = 0
        messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-        if "pixel_values" in mm_inputs:
-            image_sizes = iter(mm_inputs["image_sizes"].tolist())
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            if "pixel_values" in mm_inputs:
+                image_sizes = iter(mm_inputs["image_sizes"].tolist())
+
+        image_break_token: str = getattr(processor, "image_break_token")
+        image_end_token: str = getattr(processor, "image_end_token")
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    height, width = next(image_sizes)
-                    num_height_tokens = height // patch_size
-                    num_width_tokens = width // patch_size
-                    replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
+                    num_height_tokens = height // processor.patch_size
+                    num_width_tokens = width // processor.patch_size
+                    replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
                    replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
                    replace_tokens[-1] = image_end_token
                    replace_str = "".join(replace_tokens)
                else:
-                    replace_str = image_token
+                    replace_str = self.image_token

                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
                num_image_tokens += 1

@@ -954,7 +1002,7 @@ class PixtralPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -971,17 +1019,18 @@ class Qwen2AudioPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        bos_token: str = getattr(processor, "audio_bos_token")
        eos_token: str = getattr(processor, "audio_eos_token")
-        messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs([], [], audios, processor)
-        if "feature_attention_mask" in mm_inputs:
-            audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
-
        num_audio_tokens = 0
+        messages = deepcopy(messages)
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs([], [], audios, processor)
+            if "feature_attention_mask" in mm_inputs:
+                audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
+
        for message in messages:
            content = message["content"]
            while AUDIO_PLACEHOLDER in content:

@@ -1014,7 +1063,7 @@ class Qwen2AudioPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)
@@ -1072,7 +1121,7 @@ class Qwen2VLPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
-        processor: "ProcessorMixin",
+        processor: "MMProcessor",
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        mm_inputs = {}

@@ -1104,7 +1153,7 @@ class Qwen2VLPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0

@@ -1162,14 +1211,15 @@ class Qwen2VLPlugin(BasePlugin):
        vidlens: Sequence[int],
        audlens: Sequence[int],
        batch_ids: Sequence[list[int]],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        fps_per_video = mm_inputs.pop("fps_per_video", [])
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+        temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
        if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
-            mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
+            mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in fps_per_video]

        return mm_inputs
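A quick numeric check of the `second_per_grid_ts` computation above, assuming the default `temporal_patch_size` of 2 that the new fallback uses; the sampling rates are example values.

```python
temporal_patch_size = 2
fps_per_video = [2.0, 1.0, 0.5]  # example per-video sampling rates

second_per_grid_ts = [temporal_patch_size / fps for fps in fps_per_video]
print(second_per_grid_ts)  # [1.0, 2.0, 4.0] seconds per temporal grid step
```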
@@ -1183,45 +1233,45 @@ class VideoLlavaPlugin(BasePlugin):
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
+        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        num_frames = 0
-        has_images = "pixel_values_images" in mm_inputs
-        has_videos = "pixel_values_videos" in mm_inputs
-        if has_images or has_videos:
-            if self.expand_mm_tokens:
-                if has_images:
-                    height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
-                    num_frames = 1
-
-                if has_videos:
-                    pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-                    height, width = get_image_size(pixel_values_video[0])
-                    num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
-
-                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            if "pixel_values_images" in mm_inputs:
+                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0]))
+                num_frames = 1
+
+            if "pixel_values_videos" in mm_inputs:
+                one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0])
+                height, width = get_image_size(one_video[0])
+                num_frames = one_video.shape[0]  # frame dim is always after batch dim
+
+            if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs:
+                image_seqlen = (height // processor.patch_size) * (
+                    width // processor.patch_size
+                ) + processor.num_additional_image_tokens
                video_seqlen = image_seqlen * num_frames
-                if getattr(processor, "vision_feature_select_strategy", "default") == "default":
+                if processor.vision_feature_select_strategy == "default":
                    image_seqlen -= 1
-            else:
-                image_seqlen, video_seqlen = 1, 1
+        else:
+            image_seqlen, video_seqlen = 1, 1

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
                num_video_tokens += 1

            content = content.replace("{{image}}", self.image_token)
            message["content"] = content.replace("{{video}}", self.video_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

@@ -1231,24 +1281,10 @@ class VideoLlavaPlugin(BasePlugin):

        return messages

-    @override
-    def get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
-        self._validate_input(processor, images, videos, audios)
-        return self._get_mm_inputs(images, videos, audios, processor)


PLUGINS = {
    "base": BasePlugin,
+    "gemma3": Gemma3Plugin,
    "llava": LlavaPlugin,
    "llava_next": LlavaNextPlugin,
    "llava_next_video": LlavaNextVideoPlugin,


@@ -310,6 +310,8 @@ class Template:

@dataclass
class Llama2Template(Template):
+    r"""A template that fuse the system message to first user message."""
+
    @override
    def _encode(
        self,

@@ -815,10 +817,29 @@ register_template(
    name="gemma",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<end_of_turn>"],
+    template_class=Llama2Template,
+)
+
+
+# copied from gemma template
+register_template(
+    name="gemma3",
+    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
+    format_system=StringFormatter(slots=["{{content}}\n\n"]),
+    format_observation=StringFormatter(
+        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
+    ),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<end_of_turn>"],
+    mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
+    template_class=Llama2Template,
)

@@ -1255,6 +1276,7 @@ register_template(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<end_of_turn>"],
    mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)
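For reference, a small sketch of the string the gemma3 template above produces for one exchange, assembled by hand from its formatter slots. It assumes `bos_token` renders as `<bos>`; per the `Llama2Template` docstring, the system message is fused into the first user turn.

```python
# Hand-assembled from the gemma3 formatter slots above; illustrative only.
system = "You are helpful.\n\n"   # format_system slot: "{{content}}\n\n"
user = "Hi!"
assistant = "Hello!"

prompt = (
    "<bos>"  # format_prefix slot, assuming bos_token == "<bos>"
    # Llama2Template fuses the system message into the first user turn:
    + "<start_of_turn>user\n" + system + user + "<end_of_turn>\n<start_of_turn>model\n"
    + assistant + "<end_of_turn>\n"  # format_assistant slot
)
print(prompt)
```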


@@ -650,11 +650,51 @@ register_model_group(
            DownloadSource.DEFAULT: "google/gemma-2-27b-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
        },
+        "Gemma-3-1B": {
+            DownloadSource.DEFAULT: "google/gemma-3-1b-pt",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-pt",
+        },
+        "Gemma-3-1B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3-1b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
+        },
    },
    template="gemma",
)


+register_model_group(
+    models={
+        "Gemma-3-4B": {
+            DownloadSource.DEFAULT: "google/gemma-3-4b-pt",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-pt",
+        },
+        "Gemma-3-12B": {
+            DownloadSource.DEFAULT: "google/gemma-3-12b-pt",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-pt",
+        },
+        "Gemma-3-27B": {
+            DownloadSource.DEFAULT: "google/gemma-3-27b-pt",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-pt",
+        },
+        "Gemma-3-4B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3-4b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-it",
+        },
+        "Gemma-3-12B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3-12b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-it",
+        },
+        "Gemma-3-27B-Instruct": {
+            DownloadSource.DEFAULT: "google/gemma-3-27b-it",
+            DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-it",
+        },
+    },
+    template="gemma3",
+    multimodal=True,
+)
+
+
register_model_group(
    models={
        "GLM-4-9B": {


@@ -19,6 +19,7 @@ import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
+    AutoModelForImageTextToText,
    AutoModelForSeq2SeqLM,
    AutoModelForVision2Seq,
    AutoProcessor,

@@ -72,7 +73,6 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
    Note: including inplace operation of model_args.
    """
    init_kwargs = _get_init_kwargs(model_args)
-    config = load_config(model_args)
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,

@@ -94,7 +94,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
    patch_tokenizer(tokenizer, model_args)
    try:
        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-        patch_processor(processor, config, tokenizer, model_args)
+        patch_processor(processor, tokenizer, model_args)
    except Exception as e:
        logger.debug(f"Processor was not found: {e}.")
        processor = None

@@ -141,9 +141,11 @@ def load_model(
        if model_args.mixture_of_depths == "load":
            model = load_mod_pretrained_model(**init_kwargs)
        else:
-            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # assume built-in models
+            if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
                load_class = AutoModelForVision2Seq
-            elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():
+            elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
+                load_class = AutoModelForImageTextToText
+            elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
                load_class = AutoModelForSeq2SeqLM
            else:
                load_class = AutoModelForCausalLM
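The fallback order above can be exercised in isolation. This sketch mirrors the branch verbatim (the `_model_mapping` lookups are the same ones the diff itself relies on); it needs network access to fetch the config and a transformers build that knows Gemma 3, and the printed expectation is an assumption rather than a guaranteed result.

```python
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoModelForSeq2SeqLM,
    AutoModelForVision2Seq,
)


def resolve_load_class(model_id: str):
    """Return the auto class the loader branch above would pick for a checkpoint."""
    config = AutoConfig.from_pretrained(model_id)
    if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
        return AutoModelForVision2Seq
    elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text (covers gemma3)
        return AutoModelForImageTextToText
    elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
        return AutoModelForSeq2SeqLM
    else:
        return AutoModelForCausalLM


print(resolve_load_class("google/gemma-3-4b-it"))  # expected: AutoModelForImageTextToText
```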


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py

@@ -28,7 +28,7 @@ from ...extras import logging
if TYPE_CHECKING:
-    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel

    from ...hparams import FinetuningArguments, ModelArguments

@@ -62,6 +62,16 @@ def _register_composite_model(
    language_model_keys: Optional[list[str]] = None,
    lora_conflict_keys: Optional[list[str]] = None,
):
+    r"""Register a new composite model.
+
+    Args:
+        model_type: model type
+        projector_key: multi_modal_projector
+        vision_model_keys: vision_tower
+        language_model_keys: language_model
+        lora_conflict_keys: None
+    """
    COMPOSITE_MODELS[model_type] = CompositeModel(
        model_type=model_type,
        projector_key=projector_key or "multi_modal_projector",

@@ -169,39 +179,10 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
    return forbidden_modules

-def get_image_seqlen(config: "PretrainedConfig") -> int:
-    r"""Compute the number of special tokens per image."""
-    model_type = getattr(config, "model_type", None)
-    if model_type == "llava":
-        image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
-        if getattr(config, "vision_feature_select_strategy", "default") == "full":  # add [CLS] token
-            image_seqlen += 1
-    elif model_type == "paligemma":
-        image_seqlen = config.vision_config.num_image_tokens
-    else:
-        image_seqlen = -1
-
-    return image_seqlen
-
-
-def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
-    r"""Compute the patch size of the vit."""
-    patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
-    return patch_size
-
-
-def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
-    r"""Get the vision_feature_select_strategy."""
-    vision_feature_select_strategy = getattr(
-        config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
-    )
-    return vision_feature_select_strategy
-
-
def patch_target_modules(
    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> list[str]:
-    r"""Freezes vision tower for VLM LoRA tuning."""
+    r"""Freeze vision tower for VLM LoRA tuning."""
    model_type = getattr(model.config, "model_type", None)
    if model_type in COMPOSITE_MODELS:
        forbidden_modules = get_forbidden_modules(model.config, finetuning_args)

@@ -218,6 +199,11 @@ def patch_target_modules(
    return target_modules


+_register_composite_model(
+    model_type="gemma3",
+)
+
+
_register_composite_model(
    model_type="llava",
)
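A standalone illustration of what the registration defaults documented above mean for the bare `gemma3` entry: with only `model_type` given, the projector and component keys fall back to the documented defaults. This mirrors the added docstring, not the library's actual `CompositeModel` class.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class CompositeModelSketch:
    """Toy stand-in showing the documented default component keys."""

    model_type: str
    projector_key: str = "multi_modal_projector"
    vision_model_keys: list[str] = field(default_factory=lambda: ["vision_tower"])
    language_model_keys: list[str] = field(default_factory=lambda: ["language_model"])
    lora_conflict_keys: Optional[list[str]] = None


print(CompositeModelSketch(model_type="gemma3"))
```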


@@ -33,13 +33,7 @@ from .model_utils.packing import configure_packing
from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model
-from .model_utils.visual import (
-    autocast_projector_dtype,
-    configure_visual_model,
-    get_image_seqlen,
-    get_patch_size,
-    get_vision_feature_select_strategy,
-)
+from .model_utils.visual import autocast_projector_dtype, configure_visual_model


if TYPE_CHECKING:

@@ -72,21 +66,16 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
def patch_processor(
    processor: "ProcessorMixin",
-    config: "PretrainedConfig",
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
) -> None:
    setattr(processor, "tokenizer", tokenizer)
-    if getattr(config, "vision_config", None) is not None:  # visual models
-        setattr(processor, "image_seqlen", get_image_seqlen(config))
-        setattr(processor, "patch_size", get_patch_size(config, processor))
-        setattr(processor, "image_max_pixels", model_args.image_max_pixels)
-        setattr(processor, "image_min_pixels", model_args.image_min_pixels)
-        setattr(processor, "video_max_pixels", model_args.video_max_pixels)
-        setattr(processor, "video_min_pixels", model_args.video_min_pixels)
-        setattr(processor, "video_fps", model_args.video_fps)
-        setattr(processor, "video_maxlen", model_args.video_maxlen)
-        setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
+    setattr(processor, "image_max_pixels", model_args.image_max_pixels)
+    setattr(processor, "image_min_pixels", model_args.image_min_pixels)
+    setattr(processor, "video_max_pixels", model_args.video_max_pixels)
+    setattr(processor, "video_min_pixels", model_args.video_min_pixels)
+    setattr(processor, "video_fps", model_args.video_fps)
+    setattr(processor, "video_maxlen", model_args.video_maxlen)


def patch_config(