[model] support gemma3 (#7273)

This commit is contained in:
hoshi-hiyouga 2025-03-13 01:35:23 +08:00 committed by GitHub
parent 142fd7e755
commit 165d3ed084
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 356 additions and 274 deletions

View File

@ -84,10 +84,10 @@ Choose your path:
### Day-N Support for Fine-Tuning Cutting-Edge Models
| Support Date | Model Name |
| ------------ | ---------------------------------------------------------- |
| Day 0 | Qwen2.5 / Qwen2-VL / QwQ / QvQ / InternLM3 / MiniCPM-o-2.6 |
| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 |
| Support Date | Model Name |
| ------------ | ------------------------------------------------------------ |
| Day 0 | Qwen2.5 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 |
| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 |
## Benchmark
@ -106,6 +106,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
## Changelog
[25/03/12] We supported fine-tuning the **[Gemma-3](https://huggingface.co/blog/gemma3)** model.
[25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training.
[25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage.
@ -120,7 +122,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
[25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR.
[25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
[25/01/14] We supported fine-tuning the **[InternLM 3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.
[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.
@ -229,6 +231,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3 |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |

View File

@ -86,10 +86,10 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
### 最新模型的 Day-N 微调适配
| 适配时间 | 模型名称 |
| ------------ | ---------------------------------------------------------- |
| Day 0 | Qwen2.5 / Qwen2-VL / QwQ / QvQ / InternLM3 / MiniCPM-o-2.6 |
| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 |
| 适配时间 | 模型名称 |
| ------------ | ------------------------------------------------------------ |
| Day 0 | Qwen2.5 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 |
| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 |
## 性能指标
@ -108,6 +108,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
## 更新日志
[25/03/12] 我们支持了 **[Gemma-3](https://huggingface.co/blog/gemma3)** 模型的微调。
[25/02/24] 我们宣布开源 **[EasyR1](https://github.com/hiyouga/EasyR1)**,一个高效可扩展的多模态强化学习框架,支持高效的 GRPO 训练。
[25/02/11] 我们支持了在导出模型时保存 **[Ollama](https://github.com/ollama/ollama)** 配置文件。详细用法请参照 [examples](examples/README_zh.md)。
@ -122,7 +124,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
[25/01/14] 我们支持了 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 和 **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** 模型的微调。 感谢 [@BUAADreamer](https://github.com/BUAADreamer) 的 PR.
[25/01/14] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。
[25/01/14] 我们支持了 **[InternLM 3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。
[25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。
@ -231,6 +233,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 |
| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
| [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3 |
| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |

View File

@ -27,6 +27,7 @@ indent-width = 4
ignore = [
"C408", # collection
"C901", # complex
"E501", # line too long
"E731", # lambda function
"E741", # ambiguous var name
"D100", # no doc public module

View File

@ -1,3 +1,20 @@
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
import re
@ -5,7 +22,7 @@ from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Optional, TypedDict, Union
from typing import TYPE_CHECKING, Literal, Optional, TypedDict, Union
import numpy as np
import torch
@ -56,24 +73,63 @@ if TYPE_CHECKING:
VideoInput = str
AudioInput = Union[str, NDArray]
class MMProcessor(ProcessorMixin):
patch_size: int
image_seq_length: int
num_additional_image_tokens: int
vision_feature_select_strategy: Literal["default", "full"]
def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
pass
def _get_paligemma_token_type_ids(
imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
imglens: Sequence[int], seqlens: Sequence[int], processor: "MMProcessor"
) -> list[list[int]]:
r"""Get paligemma token type ids for computing loss.
It is slightly different with the original token type ids where the prompt part is 0.
Returns:
batch_token_type_ids: shape (batch_size, sequence_length)
batch_token_type_ids: shape (batch_size, seq_length)
"""
batch_token_type_ids = []
for imglen, seqlen in zip(imglens, seqlens):
image_seqlen = imglen * getattr(processor, "image_seqlen")
image_seqlen = imglen * processor.image_seq_length
batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
return batch_token_type_ids
def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"):
r"""Get gemma3 token type ids for computing loss.
Returns:
batch_token_type_ids: shape (batch_size, seq_length)
"""
image_token_id: int = getattr(processor, "image_token_id")
batch_token_type_ids = []
for token_ids in batch_ids:
token_ids = np.array(token_ids)
token_type_ids = np.zeros_like(token_ids)
token_type_ids[token_ids == image_token_id] = 1
batch_token_type_ids.append(token_type_ids.tolist())
return batch_token_type_ids
def _make_batched_images(images: Sequence["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
r"""Make nested list of images."""
batch_images = []
for imglen in imglens:
batch_images.append(images[:imglen])
images = images[imglen:]
return batch_images
@dataclass
class MMPluginMixin:
image_token: Optional[str]
@ -83,7 +139,7 @@ class MMPluginMixin:
def _validate_input(
self,
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
@ -204,7 +260,8 @@ class MMPluginMixin:
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
processor: "MMProcessor",
imglens: Optional[list[int]] = None,
) -> dict[str, "torch.Tensor"]:
r"""Process visual inputs.
@ -214,23 +271,34 @@ class MMPluginMixin:
Returns: (qwen2-vl)
pixel_values: tensor with shape (num_patches, patch_dim)
image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
where num_patches == torch.prod(image_grid_thw)
Returns: (mllama)
pixel_values: tensor with shape
(batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
For example, (2, 1, 4, 3, 560, 560).
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
It holds num_patches == torch.prod(image_grid_thw)
"""
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
video_processor: BaseImageProcessor = getattr(processor, "video_processor", image_processor)
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
mm_inputs = {}
if len(images) != 0:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)
if imglens is not None:
images = _make_batched_images(images, imglens)
mm_inputs.update(image_processor(images, return_tensors="pt"))
if len(videos) != 0:
video_processor: BaseImageProcessor = getattr(
processor, "video_processor", getattr(processor, "image_processor", None)
)
videos = self._regularize_videos(
videos,
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
@ -244,6 +312,7 @@ class MMPluginMixin:
mm_inputs.update(video_processor(videos, return_tensors="pt"))
if len(audios) != 0:
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
audios = self._regularize_audios(
audios,
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
@ -270,9 +339,9 @@ class BasePlugin(MMPluginMixin):
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
r"""Pre-processes input messages before tokenization for VLMs."""
r"""Pre-process input messages before tokenization for VLMs."""
self._validate_input(processor, images, videos, audios)
return messages
@ -284,9 +353,9 @@ class BasePlugin(MMPluginMixin):
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> tuple[list[int], Optional[list[int]]]:
r"""Pre-processes token ids after tokenization for VLMs."""
r"""Pre-process token ids after tokenization for VLMs."""
self._validate_input(processor, images, videos, audios)
return input_ids, labels
@ -299,7 +368,7 @@ class BasePlugin(MMPluginMixin):
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
r"""Build batched multimodal inputs for VLMs.
@ -315,11 +384,11 @@ class BasePlugin(MMPluginMixin):
"""
self._validate_input(processor, images, videos, audios)
return {}
return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class LlavaPlugin(BasePlugin):
class Gemma3Plugin(BasePlugin):
@override
def process_messages(
self,
@ -327,19 +396,21 @@ class LlavaPlugin(BasePlugin):
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
messages = deepcopy(messages)
boi_token: str = getattr(processor, "boi_token")
full_image_sequence: str = getattr(processor, "full_image_sequence")
image_str = full_image_sequence if self.expand_mm_tokens else boi_token
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
num_image_tokens += 1
message["content"] = content.replace("{{image}}", self.image_token)
message["content"] = content.replace("{{image}}", image_str)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
@ -356,10 +427,53 @@ class LlavaPlugin(BasePlugin):
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs.pop("num_crops", None)
mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor)
return mm_inputs
@dataclass
class LlavaPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0]))
image_seqlen = (height // processor.patch_size) * (
width // processor.patch_size
) + processor.num_additional_image_tokens
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
else:
image_seqlen = 1
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
message["content"] = content.replace("{{image}}", self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@dataclass
@ -371,15 +485,16 @@ class LlavaNextPlugin(BasePlugin):
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
content = message["content"]
@ -387,7 +502,7 @@ class LlavaNextPlugin(BasePlugin):
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy", "default") == "default":
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
else:
image_seqlen = 1
@ -402,21 +517,6 @@ class LlavaNextPlugin(BasePlugin):
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class LlavaNextVideoPlugin(BasePlugin):
@ -427,48 +527,50 @@ class LlavaNextVideoPlugin(BasePlugin):
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if getattr(processor, "vision_feature_select_strategy", "default") == "default":
image_seqlen -= 1
else:
image_seqlen = 1
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if self.expand_mm_tokens:
orig_height, orig_width = next(image_sizes)
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
else:
image_seqlen = 1
message["content"] = content.replace("{{image}}", self.image_token)
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
if "pixel_values_videos" in mm_inputs:
if self.expand_mm_tokens:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
message["content"] = content.replace("{{image}}", self.image_token)
if self.expand_mm_tokens:
if "pixel_values_videos" in mm_inputs:
one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(one_video[0])
num_frames = one_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
else:
video_seqlen = 1
else:
video_seqlen = 1
for message in messages:
content = message["content"]
while VIDEO_PLACEHOLDER in content:
num_video_tokens += 1
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
for message in messages:
content = message["content"]
while VIDEO_PLACEHOLDER in content:
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
num_video_tokens += 1
message["content"] = content.replace("{{video}}", self.video_token)
message["content"] = content.replace("{{video}}", self.video_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
@ -478,21 +580,6 @@ class LlavaNextVideoPlugin(BasePlugin):
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@dataclass
class MiniCPMVPlugin(BasePlugin):
@ -503,7 +590,7 @@ class MiniCPMVPlugin(BasePlugin):
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
@ -602,7 +689,7 @@ class MiniCPMVPlugin(BasePlugin):
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
processor: "MMProcessor",
**kwargs,
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
@ -677,7 +764,7 @@ class MiniCPMVPlugin(BasePlugin):
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
# image bound
@ -745,7 +832,7 @@ class MllamaPlugin(BasePlugin):
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
@ -760,43 +847,6 @@ class MllamaPlugin(BasePlugin):
return messages
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
imglens: list[int],
) -> dict[str, "torch.Tensor"]:
r"""Process visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
Returns:
pixel_values: tensor with shape
(batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
For example, (2, 1, 4, 3, 560, 560).
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
"""
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
mm_inputs = {}
if len(images) > 0:
images = self._regularize_images(
images,
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
)
batch_images = []
for image_length in imglens:
batch_images.append(images[:image_length])
images = images[image_length:]
mm_inputs.update(image_processor(batch_images, return_tensors="pt"))
return mm_inputs
@override
def get_mm_inputs(
self,
@ -807,14 +857,14 @@ class MllamaPlugin(BasePlugin):
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
if mm_inputs:
num_tiles = mm_inputs.pop("num_tiles")
image_token_id = getattr(processor, "image_token_id")
max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
image_token_id: int = getattr(processor, "image_token_id")
max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles")
cross_attention_token_mask = [
get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
]
@ -839,7 +889,7 @@ class PaliGemmaPlugin(BasePlugin):
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens = 0
@ -847,10 +897,10 @@ class PaliGemmaPlugin(BasePlugin):
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
content = content.replace(IMAGE_PLACEHOLDER, "", 1)
num_image_tokens += 1
message["content"] = content.replace("{{image}}", "")
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
@ -866,15 +916,15 @@ class PaliGemmaPlugin(BasePlugin):
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> tuple[list[int], Optional[list[int]]]:
self._validate_input(processor, images, videos, audios)
num_images = len(images)
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
input_ids = [image_token_id] * image_seqlen + input_ids
input_ids = [image_token_id] * num_images * image_seqlen + input_ids
if labels is not None:
labels = [IGNORE_INDEX] * image_seqlen + labels
labels = [IGNORE_INDEX] * num_images * image_seqlen + labels
return input_ids, labels
@ -888,7 +938,7 @@ class PaliGemmaPlugin(BasePlugin):
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
seqlens = [len(input_ids) for input_ids in batch_ids]
@ -906,33 +956,31 @@ class PixtralPlugin(BasePlugin):
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
patch_size = getattr(processor, "patch_size")
image_token = getattr(processor, "image_token")
image_break_token = getattr(processor, "image_break_token")
image_end_token = getattr(processor, "image_end_token")
num_image_tokens = 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"].tolist())
image_break_token: str = getattr(processor, "image_break_token")
image_end_token: str = getattr(processor, "image_end_token")
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if self.expand_mm_tokens:
height, width = next(image_sizes)
num_height_tokens = height // patch_size
num_width_tokens = width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
num_height_tokens = height // processor.patch_size
num_width_tokens = width // processor.patch_size
replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list
replace_tokens[-1] = image_end_token
replace_str = "".join(replace_tokens)
else:
replace_str = image_token
replace_str = self.image_token
content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
num_image_tokens += 1
@ -954,7 +1002,7 @@ class PixtralPlugin(BasePlugin):
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@ -971,17 +1019,18 @@ class Qwen2AudioPlugin(BasePlugin):
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
bos_token: str = getattr(processor, "audio_bos_token")
eos_token: str = getattr(processor, "audio_eos_token")
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs([], [], audios, processor)
if "feature_attention_mask" in mm_inputs:
audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
num_audio_tokens = 0
messages = deepcopy(messages)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs([], [], audios, processor)
if "feature_attention_mask" in mm_inputs:
audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
for message in messages:
content = message["content"]
while AUDIO_PLACEHOLDER in content:
@ -1014,7 +1063,7 @@ class Qwen2AudioPlugin(BasePlugin):
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
@ -1072,7 +1121,7 @@ class Qwen2VLPlugin(BasePlugin):
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: "ProcessorMixin",
processor: "MMProcessor",
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
mm_inputs = {}
@ -1104,7 +1153,7 @@ class Qwen2VLPlugin(BasePlugin):
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
@ -1162,14 +1211,15 @@ class Qwen2VLPlugin(BasePlugin):
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
fps_per_video = mm_inputs.pop("fps_per_video", [])
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in fps_per_video]
return mm_inputs
@ -1183,45 +1233,45 @@ class VideoLlavaPlugin(BasePlugin):
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
processor: Optional["ProcessorMixin"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
num_frames = 0
has_images = "pixel_values_images" in mm_inputs
has_videos = "pixel_values_videos" in mm_inputs
if has_images or has_videos:
if self.expand_mm_tokens:
if has_images:
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
num_frames = 1
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values_images" in mm_inputs:
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0]))
num_frames = 1
if has_videos:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
if "pixel_values_videos" in mm_inputs:
one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0])
height, width = get_image_size(one_video[0])
num_frames = one_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs:
image_seqlen = (height // processor.patch_size) * (
width // processor.patch_size
) + processor.num_additional_image_tokens
video_seqlen = image_seqlen * num_frames
if getattr(processor, "vision_feature_select_strategy", "default") == "default":
if processor.vision_feature_select_strategy == "default":
image_seqlen -= 1
else:
image_seqlen, video_seqlen = 1, 1
else:
image_seqlen, video_seqlen = 1, 1
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
num_video_tokens += 1
while VIDEO_PLACEHOLDER in content:
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
num_video_tokens += 1
content = content.replace("{{image}}", self.image_token)
message["content"] = content.replace("{{video}}", self.video_token)
content = content.replace("{{image}}", self.image_token)
message["content"] = content.replace("{{video}}", self.video_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
@ -1231,24 +1281,10 @@ class VideoLlavaPlugin(BasePlugin):
return messages
@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
audios: Sequence["AudioInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
audlens: Sequence[int],
batch_ids: Sequence[list[int]],
processor: Optional["ProcessorMixin"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
return self._get_mm_inputs(images, videos, audios, processor)
PLUGINS = {
"base": BasePlugin,
"gemma3": Gemma3Plugin,
"llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,

View File

@ -310,6 +310,8 @@ class Template:
@dataclass
class Llama2Template(Template):
r"""A template that fuse the system message to first user message."""
@override
def _encode(
self,
@ -815,10 +817,29 @@ register_template(
name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
template_class=Llama2Template,
)
# copied from gemma template
register_template(
name="gemma3",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
format_system=StringFormatter(slots=["{{content}}\n\n"]),
format_observation=StringFormatter(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
mm_plugin=get_mm_plugin("gemma3", image_token="<image_soft_token>"),
template_class=Llama2Template,
)
@ -1255,6 +1276,7 @@ register_template(
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<end_of_turn>"],
mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)

View File

@ -650,11 +650,51 @@ register_model_group(
DownloadSource.DEFAULT: "google/gemma-2-27b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
},
"Gemma-3-1B": {
DownloadSource.DEFAULT: "google/gemma-3-1b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-pt",
},
"Gemma-3-1B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-1b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it",
},
},
template="gemma",
)
register_model_group(
models={
"Gemma-3-4B": {
DownloadSource.DEFAULT: "google/gemma-3-4b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-pt",
},
"Gemma-3-12B": {
DownloadSource.DEFAULT: "google/gemma-3-12b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-pt",
},
"Gemma-3-27B": {
DownloadSource.DEFAULT: "google/gemma-3-27b-pt",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-pt",
},
"Gemma-3-4B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-4b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-it",
},
"Gemma-3-12B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-12b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-it",
},
"Gemma-3-27B-Instruct": {
DownloadSource.DEFAULT: "google/gemma-3-27b-it",
DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-it",
},
},
template="gemma3",
multimodal=True,
)
register_model_group(
models={
"GLM-4-9B": {

View File

@ -19,6 +19,7 @@ import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoModelForImageTextToText,
AutoModelForSeq2SeqLM,
AutoModelForVision2Seq,
AutoProcessor,
@ -72,7 +73,6 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
Note: including inplace operation of model_args.
"""
init_kwargs = _get_init_kwargs(model_args)
config = load_config(model_args)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
@ -94,7 +94,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
patch_tokenizer(tokenizer, model_args)
try:
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs)
patch_processor(processor, config, tokenizer, model_args)
patch_processor(processor, tokenizer, model_args)
except Exception as e:
logger.debug(f"Processor was not found: {e}.")
processor = None
@ -141,9 +141,11 @@ def load_model(
if model_args.mixture_of_depths == "load":
model = load_mod_pretrained_model(**init_kwargs)
else:
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # assume built-in models
if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # image-text
load_class = AutoModelForVision2Seq
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():
elif type(config) in AutoModelForImageTextToText._model_mapping.keys(): # image-text
load_class = AutoModelForImageTextToText
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text
load_class = AutoModelForSeq2SeqLM
else:
load_class = AutoModelForCausalLM

View File

@ -1,4 +1,4 @@
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
@ -28,7 +28,7 @@ from ...extras import logging
if TYPE_CHECKING:
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
from ...hparams import FinetuningArguments, ModelArguments
@ -62,6 +62,16 @@ def _register_composite_model(
language_model_keys: Optional[list[str]] = None,
lora_conflict_keys: Optional[list[str]] = None,
):
r"""Register a new composite model.
Args:
model_type: model type
projector_key: multi_modal_projector
vision_model_keys: vision_tower
language_model_keys: language_model
lora_conflict_keys: None
"""
COMPOSITE_MODELS[model_type] = CompositeModel(
model_type=model_type,
projector_key=projector_key or "multi_modal_projector",
@ -169,39 +179,10 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
return forbidden_modules
def get_image_seqlen(config: "PretrainedConfig") -> int:
r"""Compute the number of special tokens per image."""
model_type = getattr(config, "model_type", None)
if model_type == "llava":
image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
if getattr(config, "vision_feature_select_strategy", "default") == "full": # add [CLS] token
image_seqlen += 1
elif model_type == "paligemma":
image_seqlen = config.vision_config.num_image_tokens
else:
image_seqlen = -1
return image_seqlen
def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r"""Compute the patch size of the vit."""
patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
return patch_size
def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r"""Get the vision_feature_select_strategy."""
vision_feature_select_strategy = getattr(
config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
)
return vision_feature_select_strategy
def patch_target_modules(
model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
) -> list[str]:
r"""Freezes vision tower for VLM LoRA tuning."""
r"""Freeze vision tower for VLM LoRA tuning."""
model_type = getattr(model.config, "model_type", None)
if model_type in COMPOSITE_MODELS:
forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
@ -218,6 +199,11 @@ def patch_target_modules(
return target_modules
_register_composite_model(
model_type="gemma3",
)
_register_composite_model(
model_type="llava",
)

View File

@ -33,13 +33,7 @@ from .model_utils.packing import configure_packing
from .model_utils.quantization import configure_quantization
from .model_utils.rope import configure_rope
from .model_utils.valuehead import prepare_valuehead_model
from .model_utils.visual import (
autocast_projector_dtype,
configure_visual_model,
get_image_seqlen,
get_patch_size,
get_vision_feature_select_strategy,
)
from .model_utils.visual import autocast_projector_dtype, configure_visual_model
if TYPE_CHECKING:
@ -72,21 +66,16 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
def patch_processor(
processor: "ProcessorMixin",
config: "PretrainedConfig",
tokenizer: "PreTrainedTokenizer",
model_args: "ModelArguments",
) -> None:
setattr(processor, "tokenizer", tokenizer)
if getattr(config, "vision_config", None) is not None: # visual models
setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "patch_size", get_patch_size(config, processor))
setattr(processor, "image_max_pixels", model_args.image_max_pixels)
setattr(processor, "image_min_pixels", model_args.image_min_pixels)
setattr(processor, "video_max_pixels", model_args.video_max_pixels)
setattr(processor, "video_min_pixels", model_args.video_min_pixels)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
setattr(processor, "image_max_pixels", model_args.image_max_pixels)
setattr(processor, "image_min_pixels", model_args.image_min_pixels)
setattr(processor, "video_max_pixels", model_args.video_max_pixels)
setattr(processor, "video_min_pixels", model_args.video_min_pixels)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
def patch_config(