From 2e518f255fe807b6bf6fb0eef9b16fa36b3edecc Mon Sep 17 00:00:00 2001
From: Kingsley <82590017+Kuangdd01@users.noreply.github.com>
Date: Thu, 17 Apr 2025 00:31:30 +0800
Subject: [PATCH] [model] support intern-VL 2.5-3 series (#7258)
* add internvl and rebase
* fix for internvl2&3
* remove lines
* fix video_inputs & lint
* nit
* add constants
* remove lines
* fix
* fix error
* pass ci
* pass ci
* skip internvl & nit
---
README.md | 1 +
README_zh.md | 1 +
src/llamafactory/data/mm_plugin.py | 179 ++++++++++++++++++-
src/llamafactory/data/template.py | 14 ++
src/llamafactory/extras/constants.py | 29 +++
src/llamafactory/model/model_utils/visual.py | 5 +
tests/data/test_mm_plugin.py | 18 ++
tests/version.txt | 2 +-
8 files changed, 247 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 72d73688..699b85e8 100644
--- a/README.md
+++ b/README.md
@@ -247,6 +247,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
+| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
diff --git a/README_zh.md b/README_zh.md
index dc0aabe2..058d443a 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -250,6 +250,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
+| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index e07b764d..cfbeefd2 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -25,7 +25,12 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
import numpy as np
import torch
-from transformers.image_utils import get_image_size, to_numpy_array
+from transformers.image_utils import (
+ get_image_size,
+ make_batched_videos,
+ make_flat_list_of_images,
+ to_numpy_array,
+)
from typing_extensions import override
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -82,6 +87,20 @@ if TYPE_CHECKING:
pass
+def _concatenate_list(input_list):
+ r"""Concatenate a list of lists, numpy arrays or torch tensors.
+
+ Returns:
+        the flattened list, or the concatenated numpy array or torch tensor.
+ """
+ if isinstance(input_list[0], list):
+ return [item for sublist in input_list for item in sublist]
+ elif isinstance(input_list[0], np.ndarray):
+ return np.concatenate(input_list, axis=0)
+ elif isinstance(input_list[0], torch.Tensor):
+ return torch.cat(input_list, dim=0)
+
+
def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
r"""Get paligemma token type ids for computing loss.
@@ -467,6 +486,163 @@ class Gemma3Plugin(BasePlugin):
@dataclass
+class InternVLPlugin(BasePlugin):
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["ProcessorMixin"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ num_image_tokens = 0
+ num_video_tokens = 0
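+        # image_seq_length is the number of context tokens the model expects per image patch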
+ image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
+ messages = deepcopy(messages)
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+
+        image_pixel_patch_list = mm_inputs.get("image_num_patches", None)  # number of patches per image
+        video_num_patches = mm_inputs.get("video_num_patches", None)  # number of patches for each video frame
+        video_patch_indices = mm_inputs.get("video_patch_indices", None)  # cumulative frame count per video
+
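+        # replace every image/video placeholder with the expanded InternVL patch tokens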
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ if num_image_tokens >= len(image_pixel_patch_list):
+ raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+ content = content.replace(
+ IMAGE_PLACEHOLDER,
+                    f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
+ 1,
+ )
+ num_image_tokens += 1
+ message["content"] = content
+
+ while VIDEO_PLACEHOLDER in content:
+ if num_video_tokens >= len(video_patch_indices):
+ raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
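+                # video_patch_indices is cumulative, so consecutive entries delimit the frames of one video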
+ current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
+ end_patch_index = video_patch_indices[num_video_tokens]
+ num_patches = list(video_num_patches[current_patch_index:end_patch_index])
+ video_replaced_prompt = "\n".join(
+                    f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
+ for i in range(len(num_patches))
+ )
+ content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
+ num_video_tokens += 1
+ message["content"] = content
+
+ if len(images) != num_image_tokens:
+ raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+ if len(videos) != num_video_tokens:
+ raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+ return messages
+
+ @override
+ def _get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: "ProcessorMixin",
+ **kwargs,
+ ) -> dict[str, "torch.Tensor"]:
+ image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+        attributes = ["crop_to_patches", "min_patches", "max_patches"]  # required by the image processor
+ image_kwargs = {attr: getattr(image_processor, attr, None) for attr in attributes}
+
+ mm_inputs = {}
+ image_video_patches = []
+
+ if len(images) != 0 and isinstance(images[0], str):
+ images = self._regularize_images(
+ images,
+ image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024),
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+ )["images"]
+
+ if len(videos) != 0 and isinstance(videos[0], str):
+ videos = self._regularize_videos(
+ videos,
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+ video_fps=getattr(processor, "video_fps", 2.0),
+ video_maxlen=getattr(processor, "video_maxlen", 128),
+ )["videos"]
+
+ if len(images) != 0:
+ images = make_flat_list_of_images(images)
+ image_inputs = image_processor(images=images, **image_kwargs)
+ image_num_patches = image_inputs.pop("num_patches")
+ image_pixel_values = image_inputs.pop("pixel_values")
+ image_num_patches_indices = np.cumsum(image_num_patches)
+
+ if len(videos) != 0:
+ videos = make_batched_videos(videos)
+ num_frames_per_video = [len(video) for video in videos]
+ patch_indices = np.cumsum(num_frames_per_video)
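+            # video frames are not tiled into sub-patches, so each frame is processed as a single patch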
+ image_kwargs["crop_to_patches"] = False
+ video_inputs = image_processor(images=videos, **image_kwargs)
+ video_num_patches = video_inputs.pop("num_patches")
+ video_pixel_values = video_inputs.pop("pixel_values")
+ video_num_patches_indices = np.cumsum(video_num_patches)
+
+        # interleaved image-video inputs are not supported: all image patches are grouped first, then all video frames
+ if len(images) != 0 and image_pixel_values is not None:
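+            # regroup the flat pixel values into one chunk per image via the cumulative patch indices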
+ for i in range(len(images)):
+ start_index = image_num_patches_indices[i - 1] if i > 0 else 0
+ end_index = image_num_patches_indices[i]
+ image_video_patches.append(image_pixel_values[start_index:end_index])
+
+ if len(videos) != 0 and video_pixel_values is not None:
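+            # map each video to its frame range, then to its pixel-value range, using the two cumulative indices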
+ for i in range(len(videos)):
+ current_patch_index = patch_indices[i - 1] if i > 0 else 0
+ end_patch_index = patch_indices[i]
+ start_index = video_num_patches_indices[current_patch_index] if i > 0 else 0
+ end_index = video_num_patches_indices[end_patch_index - 1]
+ image_video_patches.append(video_pixel_values[start_index:end_index])
+
+ if len(images) != 0 or len(videos) != 0:
+ pixel_values_list = _concatenate_list(image_video_patches)
+ mm_inputs["pixel_values"] = torch.stack(
+ [torch.tensor(patch_ndarray) for patch_ndarray in pixel_values_list]
+ )
+
+ if len(images) != 0:
+ mm_inputs.update({"image_num_patches": image_num_patches})
+
+ if len(videos) != 0:
+ mm_inputs.update({"video_patch_indices": patch_indices})
+ mm_inputs.update({"video_num_patches": video_num_patches})
+
+ return mm_inputs
+
+ @override
+ def get_mm_inputs(
+ self,
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ imglens: list[int],
+ vidlens: list[int],
+ audlens: list[int],
+ batch_ids: list[list[int]],
+ processor: Optional["ProcessorMixin"],
+ ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+ self._validate_input(processor, images, videos, audios)
+
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
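+        # the patch bookkeeping is only needed for prompt expansion, so drop it before returning model inputs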
+ mm_inputs.pop("image_num_patches", None)
+ mm_inputs.pop("video_patch_indices", None)
+ mm_inputs.pop("video_num_patches", None)
+
+ return mm_inputs
+
+
class KimiVLPlugin(BasePlugin):
@override
def process_messages(self, messages, images, videos, audios, processor):
@@ -1603,6 +1779,7 @@ class VideoLlavaPlugin(BasePlugin):
PLUGINS = {
"base": BasePlugin,
"gemma3": Gemma3Plugin,
+ "intern_vl": InternVLPlugin,
"kimi_vl": KimiVLPlugin,
"llama4": Llama4Plugin,
"llava": LlavaPlugin,
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 0a334787..b02d6df2 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -923,6 +923,20 @@ register_template(
)
+register_template(
+ name="intern_vl",
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+ default_system=(
+ "你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
+ ),
+ stop_words=["<|im_end|>"],
+    mm_plugin=get_mm_plugin(name="intern_vl", image_token="<image>", video_token="<video>"),