From 165d3ed084b093accb2aa1d1209d1903183245e1 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Thu, 13 Mar 2025 01:35:23 +0800 Subject: [PATCH] [model] support gemma3 (#7273) --- README.md | 13 +- README_zh.md | 13 +- pyproject.toml | 1 + src/llamafactory/data/mm_plugin.py | 456 ++++++++++--------- src/llamafactory/data/template.py | 22 + src/llamafactory/extras/constants.py | 40 ++ src/llamafactory/model/loader.py | 10 +- src/llamafactory/model/model_utils/visual.py | 50 +- src/llamafactory/model/patcher.py | 25 +- 9 files changed, 356 insertions(+), 274 deletions(-) diff --git a/README.md b/README.md index 93006e8b..cdc3f05d 100644 --- a/README.md +++ b/README.md @@ -84,10 +84,10 @@ Choose your path: ### Day-N Support for Fine-Tuning Cutting-Edge Models -| Support Date | Model Name | -| ------------ | ---------------------------------------------------------- | -| Day 0 | Qwen2.5 / Qwen2-VL / QwQ / QvQ / InternLM3 / MiniCPM-o-2.6 | -| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 | +| Support Date | Model Name | +| ------------ | ------------------------------------------------------------ | +| Day 0 | Qwen2.5 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 | +| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 | ## Benchmark @@ -106,6 +106,8 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog +[25/03/12] We supported fine-tuning the **[Gemma-3](https://huggingface.co/blog/gemma3)** model. + [25/02/24] Announcing **[EasyR1](https://github.com/hiyouga/EasyR1)**, an efficient, scalable and multi-modality RL training framework for efficient GRPO training. [25/02/11] We supported saving the **[Ollama](https://github.com/ollama/ollama)** modelfile when exporting the model checkpoints. See [examples](examples/README.md) for usage. @@ -120,7 +122,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ [25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR. -[25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR. +[25/01/14] We supported fine-tuning the **[InternLM 3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR. [25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model. 
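The changelog entry above announces Gemma 3 fine-tuning support; further down, this patch teaches `loader.py` to route such checkpoints through `AutoModelForImageTextToText`. The following sketch is not part of the patch: it only illustrates that dispatch, assuming a transformers build that ships Gemma 3 and using `google/gemma-3-4b-it`, one of the checkpoints registered in `constants.py`.

```python
# Illustrative sketch (not part of this patch) of the auto-class dispatch
# that the updated loader.py performs. Gemma 3 configs are registered under
# AutoModelForImageTextToText rather than AutoModelForVision2Seq.
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoModelForSeq2SeqLM,
    AutoModelForVision2Seq,
)

model_id = "google/gemma-3-4b-it"  # assumes access to this checkpoint
config = AutoConfig.from_pretrained(model_id)

if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
    load_class = AutoModelForVision2Seq
elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text (Gemma 3 lands here)
    load_class = AutoModelForImageTextToText
elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
    load_class = AutoModelForSeq2SeqLM
else:
    load_class = AutoModelForCausalLM

model = load_class.from_pretrained(model_id)
```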
@@ -229,6 +231,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 | | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | +| [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3 | | [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 | | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - | | [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | diff --git a/README_zh.md b/README_zh.md index 92349517..d3a1c0c7 100644 --- a/README_zh.md +++ b/README_zh.md @@ -86,10 +86,10 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 ### 最新模型的 Day-N 微调适配 -| 适配时间 | 模型名称 | -| ------------ | ---------------------------------------------------------- | -| Day 0 | Qwen2.5 / Qwen2-VL / QwQ / QvQ / InternLM3 / MiniCPM-o-2.6 | -| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 | +| 适配时间 | 模型名称 | +| ------------ | ------------------------------------------------------------ | +| Day 0 | Qwen2.5 / Qwen2.5-VL / Gemma 3 / InternLM 3 / MiniCPM-o-2.6 | +| Day 1 | Llama 3 / GLM-4 / Mistral Small / PaliGemma2 | ## 性能指标 @@ -108,6 +108,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 ## 更新日志 +[25/03/12] 我们支持了 **[Gemma-3](https://huggingface.co/blog/gemma3)** 模型的微调。 + [25/02/24] 我们宣布开源 **[EasyR1](https://github.com/hiyouga/EasyR1)**,一个高效可扩展的多模态强化学习框架,支持高效的 GRPO 训练。 [25/02/11] 我们支持了在导出模型时保存 **[Ollama](https://github.com/ollama/ollama)** 配置文件。详细用法请参照 [examples](examples/README_zh.md)。 @@ -122,7 +124,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 [25/01/14] 我们支持了 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 和 **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** 模型的微调。 感谢 [@BUAADreamer](https://github.com/BUAADreamer) 的 PR. 
-[25/01/14] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。 +[25/01/14] 我们支持了 **[InternLM 3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。 [25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。 @@ -231,6 +233,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 | [DeepSeek R1 (Distill)](https://huggingface.co/deepseek-ai) | 1.5B/7B/8B/14B/32B/70B/671B | deepseek3 | | [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon | | [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma | +| [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3 | | [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 | | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - | | [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 | diff --git a/pyproject.toml b/pyproject.toml index cf011762..7a348c36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ indent-width = 4 ignore = [ "C408", # collection "C901", # complex + "E501", # line too long "E731", # lambda function "E741", # ambiguous var name "D100", # no doc public module diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 422b10aa..ffcb1408 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -1,3 +1,20 @@ +# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's Transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect import math import re @@ -5,7 +22,7 @@ from collections.abc import Sequence from copy import deepcopy from dataclasses import dataclass from io import BytesIO -from typing import TYPE_CHECKING, Optional, TypedDict, Union +from typing import TYPE_CHECKING, Literal, Optional, TypedDict, Union import numpy as np import torch @@ -56,24 +73,63 @@ if TYPE_CHECKING: VideoInput = str AudioInput = Union[str, NDArray] + class MMProcessor(ProcessorMixin): + patch_size: int + image_seq_length: int + num_additional_image_tokens: int + vision_feature_select_strategy: Literal["default", "full"] + + def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int: + pass + def _get_paligemma_token_type_ids( - imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin" + imglens: Sequence[int], seqlens: Sequence[int], processor: "MMProcessor" ) -> list[list[int]]: r"""Get paligemma token type ids for computing loss. + It is slightly different with the original token type ids where the prompt part is 0. 
+ Returns: - batch_token_type_ids: shape (batch_size, sequence_length) + batch_token_type_ids: shape (batch_size, seq_length) """ batch_token_type_ids = [] for imglen, seqlen in zip(imglens, seqlens): - image_seqlen = imglen * getattr(processor, "image_seqlen") + image_seqlen = imglen * processor.image_seq_length batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen)) return batch_token_type_ids +def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"): + r"""Get gemma3 token type ids for computing loss. + + Returns: + batch_token_type_ids: shape (batch_size, seq_length) + + """ + image_token_id: int = getattr(processor, "image_token_id") + batch_token_type_ids = [] + for token_ids in batch_ids: + token_ids = np.array(token_ids) + token_type_ids = np.zeros_like(token_ids) + token_type_ids[token_ids == image_token_id] = 1 + batch_token_type_ids.append(token_type_ids.tolist()) + + return batch_token_type_ids + + +def _make_batched_images(images: Sequence["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]: + r"""Make nested list of images.""" + batch_images = [] + for imglen in imglens: + batch_images.append(images[:imglen]) + images = images[imglen:] + + return batch_images + + @dataclass class MMPluginMixin: image_token: Optional[str] @@ -83,7 +139,7 @@ class MMPluginMixin: def _validate_input( self, - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], @@ -204,7 +260,8 @@ class MMPluginMixin: images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - processor: "ProcessorMixin", + processor: "MMProcessor", + imglens: Optional[list[int]] = None, ) -> dict[str, "torch.Tensor"]: r"""Process visual inputs. @@ -214,23 +271,34 @@ class MMPluginMixin: Returns: (qwen2-vl) pixel_values: tensor with shape (num_patches, patch_dim) image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height + where num_patches == torch.prod(image_grid_thw) + + Returns: (mllama) + pixel_values: tensor with shape + (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width) + For example, (2, 1, 4, 3, 560, 560). + aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1). + aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4). + num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1). 
- It holds num_patches == torch.prod(image_grid_thw) """ - image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) - video_processor: BaseImageProcessor = getattr(processor, "video_processor", image_processor) - feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) mm_inputs = {} - if len(images) != 0: + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) images = self._regularize_images( images, image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), ) + if imglens is not None: + images = _make_batched_images(images, imglens) + mm_inputs.update(image_processor(images, return_tensors="pt")) if len(videos) != 0: + video_processor: BaseImageProcessor = getattr( + processor, "video_processor", getattr(processor, "image_processor", None) + ) videos = self._regularize_videos( videos, image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256), @@ -244,6 +312,7 @@ class MMPluginMixin: mm_inputs.update(video_processor(videos, return_tensors="pt")) if len(audios) != 0: + feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) audios = self._regularize_audios( audios, sampling_rate=getattr(feature_extractor, "sampling_rate", 16000), @@ -270,9 +339,9 @@ class BasePlugin(MMPluginMixin): images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: - r"""Pre-processes input messages before tokenization for VLMs.""" + r"""Pre-process input messages before tokenization for VLMs.""" self._validate_input(processor, images, videos, audios) return messages @@ -284,9 +353,9 @@ class BasePlugin(MMPluginMixin): videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> tuple[list[int], Optional[list[int]]]: - r"""Pre-processes token ids after tokenization for VLMs.""" + r"""Pre-process token ids after tokenization for VLMs.""" self._validate_input(processor, images, videos, audios) return input_ids, labels @@ -299,7 +368,7 @@ class BasePlugin(MMPluginMixin): vidlens: Sequence[int], audlens: Sequence[int], batch_ids: Sequence[list[int]], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: r"""Build batched multimodal inputs for VLMs. 
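As a quick illustration of the `_get_gemma3_token_type_ids` helper added earlier in this file, the sketch below (not part of the patch; the token ids and the image token id are invented for the example) shows how image-token positions are flagged with type 1 while text tokens stay 0:

```python
# Standalone illustration of the _get_gemma3_token_type_ids logic added above.
# The token ids and image_token_id below are hypothetical values for the example.
import numpy as np

image_token_id = 262144  # hypothetical id of the Gemma 3 image token

batch_ids = [
    [2, 106, 262144, 262144, 1596, 107],  # one sequence containing two image tokens
    [2, 106, 1596, 107],                  # one text-only sequence
]

batch_token_type_ids = []
for token_ids in batch_ids:
    token_ids = np.array(token_ids)
    token_type_ids = np.zeros_like(token_ids)
    token_type_ids[token_ids == image_token_id] = 1  # mark image positions for loss masking
    batch_token_type_ids.append(token_type_ids.tolist())

print(batch_token_type_ids)
# [[0, 0, 1, 1, 0, 0], [0, 0, 0, 0]]
```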
@@ -315,11 +384,11 @@ class BasePlugin(MMPluginMixin): """ self._validate_input(processor, images, videos, audios) - return {} + return self._get_mm_inputs(images, videos, audios, processor) @dataclass -class LlavaPlugin(BasePlugin): +class Gemma3Plugin(BasePlugin): @override def process_messages( self, @@ -327,19 +396,21 @@ class LlavaPlugin(BasePlugin): images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens = 0 - image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1 messages = deepcopy(messages) + boi_token: str = getattr(processor, "boi_token") + full_image_sequence: str = getattr(processor, "full_image_sequence") + image_str = full_image_sequence if self.expand_mm_tokens else boi_token for message in messages: content = message["content"] while IMAGE_PLACEHOLDER in content: - content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) num_image_tokens += 1 - message["content"] = content.replace("{{image}}", self.image_token) + message["content"] = content.replace("{{image}}", image_str) if len(images) != num_image_tokens: raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") @@ -356,10 +427,53 @@ class LlavaPlugin(BasePlugin): vidlens: Sequence[int], audlens: Sequence[int], batch_ids: Sequence[list[int]], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) - return self._get_mm_inputs(images, videos, audios, processor) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + mm_inputs.pop("num_crops", None) + mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor) + return mm_inputs + + +@dataclass +class LlavaPlugin(BasePlugin): + @override + def process_messages( + self, + messages: Sequence[dict[str, str]], + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + audios: Sequence["AudioInput"], + processor: Optional["MMProcessor"], + ) -> list[dict[str, str]]: + self._validate_input(processor, images, videos, audios) + num_image_tokens = 0 + messages = deepcopy(messages) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values" in mm_inputs: + height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0])) + image_seqlen = (height // processor.patch_size) * ( + width // processor.patch_size + ) + processor.num_additional_image_tokens + if processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + else: + image_seqlen = 1 + + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + num_image_tokens += 1 + + message["content"] = content.replace("{{image}}", self.image_token) + + if len(images) != num_image_tokens: + raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") + + return messages @dataclass @@ -371,15 +485,16 @@ class LlavaNextPlugin(BasePlugin): images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - processor: Optional["ProcessorMixin"], + 
processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens = 0 messages = deepcopy(messages) - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - if "pixel_values" in mm_inputs: - image_sizes = iter(mm_inputs["image_sizes"].tolist()) - height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values" in mm_inputs: + image_sizes = iter(mm_inputs["image_sizes"].tolist()) + height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) for message in messages: content = message["content"] @@ -387,7 +502,7 @@ class LlavaNextPlugin(BasePlugin): if self.expand_mm_tokens: orig_height, orig_width = next(image_sizes) image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) - if getattr(processor, "vision_feature_select_strategy", "default") == "default": + if processor.vision_feature_select_strategy == "default": image_seqlen -= 1 else: image_seqlen = 1 @@ -402,21 +517,6 @@ class LlavaNextPlugin(BasePlugin): return messages - @override - def get_mm_inputs( - self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - imglens: Sequence[int], - vidlens: Sequence[int], - audlens: Sequence[int], - batch_ids: Sequence[list[int]], - processor: Optional["ProcessorMixin"], - ) -> dict[str, Union[list[int], "torch.Tensor"]]: - self._validate_input(processor, images, videos, audios) - return self._get_mm_inputs(images, videos, audios, processor) - @dataclass class LlavaNextVideoPlugin(BasePlugin): @@ -427,48 +527,50 @@ class LlavaNextVideoPlugin(BasePlugin): images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens, num_video_tokens = 0, 0 messages = deepcopy(messages) - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - if "pixel_values" in mm_inputs: - image_sizes = iter(mm_inputs["image_sizes"].tolist()) - height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - if self.expand_mm_tokens: - orig_height, orig_width = next(image_sizes) - image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) - if getattr(processor, "vision_feature_select_strategy", "default") == "default": - image_seqlen -= 1 - else: - image_seqlen = 1 + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values" in mm_inputs: + image_sizes = iter(mm_inputs["image_sizes"].tolist()) + height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) - content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) - num_image_tokens += 1 + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + if self.expand_mm_tokens: + orig_height, orig_width = next(image_sizes) + image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) + if processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + else: + image_seqlen = 1 - message["content"] = content.replace("{{image}}", 
self.image_token) + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + num_image_tokens += 1 - if "pixel_values_videos" in mm_inputs: - if self.expand_mm_tokens: - pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) - height, width = get_image_size(pixel_values_video[0]) - num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim + message["content"] = content.replace("{{image}}", self.image_token) + + if self.expand_mm_tokens: + if "pixel_values_videos" in mm_inputs: + one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) + height, width = get_image_size(one_video[0]) + num_frames = one_video.shape[0] # frame dim is always after batch dim image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer - else: - video_seqlen = 1 + else: + video_seqlen = 1 - for message in messages: - content = message["content"] - while VIDEO_PLACEHOLDER in content: - num_video_tokens += 1 - content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) + for message in messages: + content = message["content"] + while VIDEO_PLACEHOLDER in content: + content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) + num_video_tokens += 1 - message["content"] = content.replace("{{video}}", self.video_token) + message["content"] = content.replace("{{video}}", self.video_token) if len(images) != num_image_tokens: raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") @@ -478,21 +580,6 @@ class LlavaNextVideoPlugin(BasePlugin): return messages - @override - def get_mm_inputs( - self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - imglens: Sequence[int], - vidlens: Sequence[int], - audlens: Sequence[int], - batch_ids: Sequence[list[int]], - processor: Optional["ProcessorMixin"], - ) -> dict[str, Union[list[int], "torch.Tensor"]]: - self._validate_input(processor, images, videos, audios) - return self._get_mm_inputs(images, videos, audios, processor) - @dataclass class MiniCPMVPlugin(BasePlugin): @@ -503,7 +590,7 @@ class MiniCPMVPlugin(BasePlugin): images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0 @@ -602,7 +689,7 @@ class MiniCPMVPlugin(BasePlugin): images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - processor: "ProcessorMixin", + processor: "MMProcessor", **kwargs, ) -> dict[str, "torch.Tensor"]: image_processor: BaseImageProcessor = getattr(processor, "image_processor") @@ -677,7 +764,7 @@ class MiniCPMVPlugin(BasePlugin): vidlens: Sequence[int], audlens: Sequence[int], batch_ids: Sequence[list[int]], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) # image bound @@ -745,7 +832,7 @@ class MllamaPlugin(BasePlugin): images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, 
videos, audios) num_image_tokens = 0 @@ -760,43 +847,6 @@ class MllamaPlugin(BasePlugin): return messages - @override - def _get_mm_inputs( - self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - processor: "ProcessorMixin", - imglens: list[int], - ) -> dict[str, "torch.Tensor"]: - r"""Process visual inputs for mllama because its image processor only accepts List[List[ImageInput]]. - - Returns: - pixel_values: tensor with shape - (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width) - For example, (2, 1, 4, 3, 560, 560). - aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1). - aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4). - num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1). - - """ - image_processor: BaseImageProcessor = getattr(processor, "image_processor") - mm_inputs = {} - if len(images) > 0: - images = self._regularize_images( - images, - image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), - image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), - ) - batch_images = [] - for image_length in imglens: - batch_images.append(images[:image_length]) - images = images[image_length:] - - mm_inputs.update(image_processor(batch_images, return_tensors="pt")) - - return mm_inputs - @override def get_mm_inputs( self, @@ -807,14 +857,14 @@ class MllamaPlugin(BasePlugin): vidlens: Sequence[int], audlens: Sequence[int], batch_ids: Sequence[list[int]], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens) if mm_inputs: num_tiles = mm_inputs.pop("num_tiles") - image_token_id = getattr(processor, "image_token_id") - max_image_tiles = getattr(processor.image_processor, "max_image_tiles") + image_token_id: int = getattr(processor, "image_token_id") + max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles") cross_attention_token_mask = [ get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids ] @@ -839,7 +889,7 @@ class PaliGemmaPlugin(BasePlugin): images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens = 0 @@ -847,10 +897,10 @@ class PaliGemmaPlugin(BasePlugin): for message in messages: content = message["content"] while IMAGE_PLACEHOLDER in content: - content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) + content = content.replace(IMAGE_PLACEHOLDER, "", 1) num_image_tokens += 1 - message["content"] = content.replace("{{image}}", "") + message["content"] = content if len(images) != num_image_tokens: raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") @@ -866,15 +916,15 @@ class PaliGemmaPlugin(BasePlugin): videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], tokenizer: "PreTrainedTokenizer", - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> tuple[list[int], Optional[list[int]]]: self._validate_input(processor, images, videos, audios) num_images = len(images) - image_seqlen = num_images * 
getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token + image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token image_token_id = tokenizer.convert_tokens_to_ids(self.image_token) - input_ids = [image_token_id] * image_seqlen + input_ids + input_ids = [image_token_id] * num_images * image_seqlen + input_ids if labels is not None: - labels = [IGNORE_INDEX] * image_seqlen + labels + labels = [IGNORE_INDEX] * num_images * image_seqlen + labels return input_ids, labels @@ -888,7 +938,7 @@ class PaliGemmaPlugin(BasePlugin): vidlens: Sequence[int], audlens: Sequence[int], batch_ids: Sequence[list[int]], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) seqlens = [len(input_ids) for input_ids in batch_ids] @@ -906,33 +956,31 @@ class PixtralPlugin(BasePlugin): images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) - patch_size = getattr(processor, "patch_size") - image_token = getattr(processor, "image_token") - image_break_token = getattr(processor, "image_break_token") - image_end_token = getattr(processor, "image_end_token") - num_image_tokens = 0 messages = deepcopy(messages) - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) - if "pixel_values" in mm_inputs: - image_sizes = iter(mm_inputs["image_sizes"].tolist()) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values" in mm_inputs: + image_sizes = iter(mm_inputs["image_sizes"].tolist()) + image_break_token: str = getattr(processor, "image_break_token") + image_end_token: str = getattr(processor, "image_end_token") for message in messages: content = message["content"] while IMAGE_PLACEHOLDER in content: if self.expand_mm_tokens: height, width = next(image_sizes) - num_height_tokens = height // patch_size - num_width_tokens = width // patch_size - replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens + num_height_tokens = height // processor.patch_size + num_width_tokens = width // processor.patch_size + replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens replace_tokens = [item for sublist in replace_tokens for item in sublist] # flatten list replace_tokens[-1] = image_end_token replace_str = "".join(replace_tokens) else: - replace_str = image_token + replace_str = self.image_token content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1) num_image_tokens += 1 @@ -954,7 +1002,7 @@ class PixtralPlugin(BasePlugin): vidlens: Sequence[int], audlens: Sequence[int], batch_ids: Sequence[list[int]], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) mm_inputs = self._get_mm_inputs(images, videos, audios, processor) @@ -971,17 +1019,18 @@ class Qwen2AudioPlugin(BasePlugin): images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) bos_token: str = getattr(processor, 
"audio_bos_token") eos_token: str = getattr(processor, "audio_eos_token") - messages = deepcopy(messages) - mm_inputs = self._get_mm_inputs([], [], audios, processor) - if "feature_attention_mask" in mm_inputs: - audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist() - num_audio_tokens = 0 + messages = deepcopy(messages) + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs([], [], audios, processor) + if "feature_attention_mask" in mm_inputs: + audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist() + for message in messages: content = message["content"] while AUDIO_PLACEHOLDER in content: @@ -1014,7 +1063,7 @@ class Qwen2AudioPlugin(BasePlugin): vidlens: Sequence[int], audlens: Sequence[int], batch_ids: Sequence[list[int]], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) return self._get_mm_inputs(images, videos, audios, processor) @@ -1072,7 +1121,7 @@ class Qwen2VLPlugin(BasePlugin): images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - processor: "ProcessorMixin", + processor: "MMProcessor", ) -> dict[str, "torch.Tensor"]: image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) mm_inputs = {} @@ -1104,7 +1153,7 @@ class Qwen2VLPlugin(BasePlugin): images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens, num_video_tokens = 0, 0 @@ -1162,14 +1211,15 @@ class Qwen2VLPlugin(BasePlugin): vidlens: Sequence[int], audlens: Sequence[int], batch_ids: Sequence[list[int]], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) mm_inputs = self._get_mm_inputs(images, videos, audios, processor) fps_per_video = mm_inputs.pop("fps_per_video", []) image_processor: BaseImageProcessor = getattr(processor, "image_processor") + temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2) if "second_per_grid_ts" in processor.model_input_names and fps_per_video: - mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video] + mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in fps_per_video] return mm_inputs @@ -1183,45 +1233,45 @@ class VideoLlavaPlugin(BasePlugin): images: Sequence["ImageInput"], videos: Sequence["VideoInput"], audios: Sequence["AudioInput"], - processor: Optional["ProcessorMixin"], + processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) num_image_tokens, num_video_tokens = 0, 0 messages = deepcopy(messages) - mm_inputs = self._get_mm_inputs(images, videos, audios, processor) num_frames = 0 - has_images = "pixel_values_images" in mm_inputs - has_videos = "pixel_values_videos" in mm_inputs - if has_images or has_videos: - if self.expand_mm_tokens: - if has_images: - height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0])) - num_frames = 1 + if self.expand_mm_tokens: + mm_inputs = self._get_mm_inputs(images, videos, audios, processor) + if "pixel_values_images" in mm_inputs: + height, width = 
get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0])) + num_frames = 1 - if has_videos: - pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) - height, width = get_image_size(pixel_values_video[0]) - num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim + if "pixel_values_videos" in mm_inputs: + one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0]) + height, width = get_image_size(one_video[0]) + num_frames = one_video.shape[0] # frame dim is always after batch dim - image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1 + if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs: + image_seqlen = (height // processor.patch_size) * ( + width // processor.patch_size + ) + processor.num_additional_image_tokens video_seqlen = image_seqlen * num_frames - if getattr(processor, "vision_feature_select_strategy", "default") == "default": + if processor.vision_feature_select_strategy == "default": image_seqlen -= 1 - else: - image_seqlen, video_seqlen = 1, 1 + else: + image_seqlen, video_seqlen = 1, 1 - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) - num_image_tokens += 1 + for message in messages: + content = message["content"] + while IMAGE_PLACEHOLDER in content: + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + num_image_tokens += 1 - while VIDEO_PLACEHOLDER in content: - content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) - num_video_tokens += 1 + while VIDEO_PLACEHOLDER in content: + content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) + num_video_tokens += 1 - content = content.replace("{{image}}", self.image_token) - message["content"] = content.replace("{{video}}", self.video_token) + content = content.replace("{{image}}", self.image_token) + message["content"] = content.replace("{{video}}", self.video_token) if len(images) != num_image_tokens: raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") @@ -1231,24 +1281,10 @@ class VideoLlavaPlugin(BasePlugin): return messages - @override - def get_mm_inputs( - self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - imglens: Sequence[int], - vidlens: Sequence[int], - audlens: Sequence[int], - batch_ids: Sequence[list[int]], - processor: Optional["ProcessorMixin"], - ) -> dict[str, Union[list[int], "torch.Tensor"]]: - self._validate_input(processor, images, videos, audios) - return self._get_mm_inputs(images, videos, audios, processor) - PLUGINS = { "base": BasePlugin, + "gemma3": Gemma3Plugin, "llava": LlavaPlugin, "llava_next": LlavaNextPlugin, "llava_next_video": LlavaNextVideoPlugin, diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 50d0da24..9454d40e 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -310,6 +310,8 @@ class Template: @dataclass class Llama2Template(Template): + r"""A template that fuse the system message to first user message.""" + @override def _encode( self, @@ -815,10 +817,29 @@ register_template( name="gemma", format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), format_observation=StringFormatter( 
slots=["tool\n{{content}}\nmodel\n"] ), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + template_class=Llama2Template, +) + + +# copied from gemma template +register_template( + name="gemma3", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter(slots=["{{content}}\n\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + mm_plugin=get_mm_plugin("gemma3", image_token=""), + template_class=Llama2Template, ) @@ -1255,6 +1276,7 @@ register_template( slots=["tool\n{{content}}\nmodel\n"] ), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], mm_plugin=get_mm_plugin(name="paligemma", image_token=""), ) diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index a3f222e9..4fd77d0a 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -650,11 +650,51 @@ register_model_group( DownloadSource.DEFAULT: "google/gemma-2-27b-it", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it", }, + "Gemma-3-1B": { + DownloadSource.DEFAULT: "google/gemma-3-1b-pt", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-pt", + }, + "Gemma-3-1B-Instruct": { + DownloadSource.DEFAULT: "google/gemma-3-1b-it", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-1b-it", + }, }, template="gemma", ) +register_model_group( + models={ + "Gemma-3-4B": { + DownloadSource.DEFAULT: "google/gemma-3-4b-pt", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-pt", + }, + "Gemma-3-12B": { + DownloadSource.DEFAULT: "google/gemma-3-12b-pt", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-pt", + }, + "Gemma-3-27B": { + DownloadSource.DEFAULT: "google/gemma-3-27b-pt", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-pt", + }, + "Gemma-3-4B-Instruct": { + DownloadSource.DEFAULT: "google/gemma-3-4b-it", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-4b-it", + }, + "Gemma-3-12B-Instruct": { + DownloadSource.DEFAULT: "google/gemma-3-12b-it", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-12b-it", + }, + "Gemma-3-27B-Instruct": { + DownloadSource.DEFAULT: "google/gemma-3-27b-it", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-3-27b-it", + }, + }, + template="gemma3", + multimodal=True, +) + + register_model_group( models={ "GLM-4-9B": { diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index fb7846fe..7b116397 100644 --- a/src/llamafactory/model/loader.py +++ b/src/llamafactory/model/loader.py @@ -19,6 +19,7 @@ import torch from transformers import ( AutoConfig, AutoModelForCausalLM, + AutoModelForImageTextToText, AutoModelForSeq2SeqLM, AutoModelForVision2Seq, AutoProcessor, @@ -72,7 +73,6 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": Note: including inplace operation of model_args. 
""" init_kwargs = _get_init_kwargs(model_args) - config = load_config(model_args) try: tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, @@ -94,7 +94,7 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": patch_tokenizer(tokenizer, model_args) try: processor = AutoProcessor.from_pretrained(model_args.model_name_or_path, **init_kwargs) - patch_processor(processor, config, tokenizer, model_args) + patch_processor(processor, tokenizer, model_args) except Exception as e: logger.debug(f"Processor was not found: {e}.") processor = None @@ -141,9 +141,11 @@ def load_model( if model_args.mixture_of_depths == "load": model = load_mod_pretrained_model(**init_kwargs) else: - if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # assume built-in models + if type(config) in AutoModelForVision2Seq._model_mapping.keys(): # image-text load_class = AutoModelForVision2Seq - elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): + elif type(config) in AutoModelForImageTextToText._model_mapping.keys(): # image-text + load_class = AutoModelForImageTextToText + elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys(): # audio-text load_class = AutoModelForSeq2SeqLM else: load_class = AutoModelForCausalLM diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index 8c1c3df9..01c20988 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -1,4 +1,4 @@ -# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. # # This code is inspired by the HuggingFace's Transformers library. # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py @@ -28,7 +28,7 @@ from ...extras import logging if TYPE_CHECKING: - from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin + from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel from ...hparams import FinetuningArguments, ModelArguments @@ -62,6 +62,16 @@ def _register_composite_model( language_model_keys: Optional[list[str]] = None, lora_conflict_keys: Optional[list[str]] = None, ): + r"""Register a new composite model. 
+ + Args: + model_type: model type + projector_key: multi_modal_projector + vision_model_keys: vision_tower + language_model_keys: language_model + lora_conflict_keys: None + + """ COMPOSITE_MODELS[model_type] = CompositeModel( model_type=model_type, projector_key=projector_key or "multi_modal_projector", @@ -169,39 +179,10 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni return forbidden_modules -def get_image_seqlen(config: "PretrainedConfig") -> int: - r"""Compute the number of special tokens per image.""" - model_type = getattr(config, "model_type", None) - if model_type == "llava": - image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2 - if getattr(config, "vision_feature_select_strategy", "default") == "full": # add [CLS] token - image_seqlen += 1 - elif model_type == "paligemma": - image_seqlen = config.vision_config.num_image_tokens - else: - image_seqlen = -1 - - return image_seqlen - - -def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int: - r"""Compute the patch size of the vit.""" - patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1)) - return patch_size - - -def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int: - r"""Get the vision_feature_select_strategy.""" - vision_feature_select_strategy = getattr( - config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default") - ) - return vision_feature_select_strategy - - def patch_target_modules( model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str] ) -> list[str]: - r"""Freezes vision tower for VLM LoRA tuning.""" + r"""Freeze vision tower for VLM LoRA tuning.""" model_type = getattr(model.config, "model_type", None) if model_type in COMPOSITE_MODELS: forbidden_modules = get_forbidden_modules(model.config, finetuning_args) @@ -218,6 +199,11 @@ def patch_target_modules( return target_modules +_register_composite_model( + model_type="gemma3", +) + + _register_composite_model( model_type="llava", ) diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index f732c792..8997757d 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -33,13 +33,7 @@ from .model_utils.packing import configure_packing from .model_utils.quantization import configure_quantization from .model_utils.rope import configure_rope from .model_utils.valuehead import prepare_valuehead_model -from .model_utils.visual import ( - autocast_projector_dtype, - configure_visual_model, - get_image_seqlen, - get_patch_size, - get_vision_feature_select_strategy, -) +from .model_utils.visual import autocast_projector_dtype, configure_visual_model if TYPE_CHECKING: @@ -72,21 +66,16 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument def patch_processor( processor: "ProcessorMixin", - config: "PretrainedConfig", tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments", ) -> None: setattr(processor, "tokenizer", tokenizer) - if getattr(config, "vision_config", None) is not None: # visual models - setattr(processor, "image_seqlen", get_image_seqlen(config)) - setattr(processor, "patch_size", get_patch_size(config, processor)) - setattr(processor, "image_max_pixels", model_args.image_max_pixels) - setattr(processor, "image_min_pixels", model_args.image_min_pixels) - setattr(processor, "video_max_pixels", 
model_args.video_max_pixels) - setattr(processor, "video_min_pixels", model_args.video_min_pixels) - setattr(processor, "video_fps", model_args.video_fps) - setattr(processor, "video_maxlen", model_args.video_maxlen) - setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor)) + setattr(processor, "image_max_pixels", model_args.image_max_pixels) + setattr(processor, "image_min_pixels", model_args.image_min_pixels) + setattr(processor, "video_max_pixels", model_args.video_max_pixels) + setattr(processor, "video_min_pixels", model_args.video_min_pixels) + setattr(processor, "video_fps", model_args.video_fps) + setattr(processor, "video_maxlen", model_args.video_maxlen) def patch_config(