[model] support InternVL 2.5-3 series (#7258)

* add internvl and rebase

* fix for internvl2&3

* remove lines

* fix video_inputs & lint

* nit

* add constants

* remove lines

* fix

* fix error

* pass ci

* pass ci

* skip internvl & nit
Kingsley 2025-04-17 00:31:30 +08:00 committed by GitHub
parent 8543400584
commit 125513fa5c
8 changed files with 247 additions and 2 deletions

View File

@@ -247,6 +247,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |

View File

@@ -250,6 +250,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |

View File

@@ -25,7 +25,12 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
import numpy as np
import torch
from transformers.image_utils import get_image_size, to_numpy_array
from transformers.image_utils import (
    get_image_size,
    make_batched_videos,
    make_flat_list_of_images,
    to_numpy_array,
)
from typing_extensions import override
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -82,6 +87,20 @@ if TYPE_CHECKING:
            pass


def _concatenate_list(input_list):
    r"""Concatenate a list of lists, numpy arrays or torch tensors.

    Returns:
        a list of numpy arrays or torch tensors.
    """
    if isinstance(input_list[0], list):
        return [item for sublist in input_list for item in sublist]
    elif isinstance(input_list[0], np.ndarray):
        return np.concatenate(input_list, axis=0)
    elif isinstance(input_list[0], torch.Tensor):
        return torch.cat(input_list, dim=0)
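For intuition, a quick usage sketch of the helper above (illustrative only, assuming `_concatenate_list` is in scope; shapes and values are made up):

import numpy as np
import torch

_concatenate_list([[1, 2], [3]])                                 # [1, 2, 3]
_concatenate_list([np.zeros((2, 3)), np.ones((1, 3))]).shape     # (3, 3), concatenated along axis 0
_concatenate_list([torch.zeros(2, 3), torch.ones(1, 3)]).shape   # torch.Size([3, 3]), concatenated along dim 0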


def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
    r"""Get paligemma token type ids for computing loss.
@@ -467,6 +486,163 @@ class Gemma3Plugin(BasePlugin):


@dataclass
class InternVLPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["ProcessorMixin"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        num_video_tokens = 0
        image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        image_pixel_patch_list = mm_inputs.get("image_num_patches", None)  # number of patches per image
        video_num_patches = mm_inputs.get("video_num_patches", None)  # number of patches per video frame
        video_patch_indices = mm_inputs.get("video_patch_indices", None)  # cumulative number of frames per video

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if num_image_tokens >= len(image_pixel_patch_list):
                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

                content = content.replace(
                    IMAGE_PLACEHOLDER,
                    f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
                    1,
                )
                num_image_tokens += 1

            message["content"] = content

            while VIDEO_PLACEHOLDER in content:
                if num_video_tokens >= len(video_patch_indices):
                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

                current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
                end_patch_index = video_patch_indices[num_video_tokens]
                num_patches = list(video_num_patches[current_patch_index:end_patch_index])
                video_replaced_prompt = "\n".join(
                    f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
                    for i in range(len(num_patches))
                )
                content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
                num_video_tokens += 1

            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages
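For intuition, each image placeholder above is replaced by an <img>...</img> span whose <IMG_CONTEXT> count is image_seq_length * num_patches for that image. A minimal standalone sketch of the arithmetic (the patch counts are made up; 256 matches the image_seq_length used in the test further down):

image_seqlen = 256          # processor.image_seq_length
image_num_patches = [3, 1]  # hypothetical: first image cropped into 3 patches, second into 1

prompt = "Compare <image> with <image>."
for num_patches in image_num_patches:
    prompt = prompt.replace("<image>", f"<img>{'<IMG_CONTEXT>' * image_seqlen * num_patches}</img>", 1)

print(prompt.count("<IMG_CONTEXT>"))  # 1024 = 3 * 256 + 1 * 256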
    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "ProcessorMixin",
        **kwargs,
    ) -> dict[str, "torch.Tensor"]:
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        attributes = ["crop_to_patches", "min_patches", "max_patches"]  # attributes required by the image processor
        image_kwargs = {attr: getattr(image_processor, attr, None) for attr in attributes}

        mm_inputs = {}
        image_video_patches = []

        if len(images) != 0 and isinstance(images[0], str):
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]

        if len(videos) != 0 and isinstance(videos[0], str):
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )["videos"]

        if len(images) != 0:
            images = make_flat_list_of_images(images)
            image_inputs = image_processor(images=images, **image_kwargs)
            image_num_patches = image_inputs.pop("num_patches")
            image_pixel_values = image_inputs.pop("pixel_values")
            image_num_patches_indices = np.cumsum(image_num_patches)

        if len(videos) != 0:
            videos = make_batched_videos(videos)
            num_frames_per_video = [len(video) for video in videos]
            patch_indices = np.cumsum(num_frames_per_video)
            image_kwargs["crop_to_patches"] = False
            video_inputs = image_processor(images=videos, **image_kwargs)
            video_num_patches = video_inputs.pop("num_patches")
            video_pixel_values = video_inputs.pop("pixel_values")
            video_num_patches_indices = np.cumsum(video_num_patches)

        # interleaved image-video inputs are not supported
        if len(images) != 0 and image_pixel_values is not None:
            for i in range(len(images)):
                start_index = image_num_patches_indices[i - 1] if i > 0 else 0
                end_index = image_num_patches_indices[i]
                image_video_patches.append(image_pixel_values[start_index:end_index])

        if len(videos) != 0 and video_pixel_values is not None:
            for i in range(len(videos)):
                current_patch_index = patch_indices[i - 1] if i > 0 else 0
                end_patch_index = patch_indices[i]
                start_index = video_num_patches_indices[current_patch_index] if i > 0 else 0
                end_index = video_num_patches_indices[end_patch_index - 1]
                image_video_patches.append(video_pixel_values[start_index:end_index])

        if len(images) != 0 or len(videos) != 0:
            pixel_values_list = _concatenate_list(image_video_patches)
            mm_inputs["pixel_values"] = torch.stack(
                [torch.tensor(patch_ndarray) for patch_ndarray in pixel_values_list]
            )

        if len(images) != 0:
            mm_inputs.update({"image_num_patches": image_num_patches})

        if len(videos) != 0:
            mm_inputs.update({"video_patch_indices": patch_indices})
            mm_inputs.update({"video_num_patches": video_num_patches})

        return mm_inputs
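The per-image slicing above uses np.cumsum over the patch counts to locate each image's patches in the flat pixel_values array. A small sketch of that index arithmetic with made-up patch counts:

import numpy as np

image_num_patches = [3, 1, 2]                              # hypothetical crops per image
image_num_patches_indices = np.cumsum(image_num_patches)   # array([3, 4, 6])

for i in range(len(image_num_patches)):
    start_index = image_num_patches_indices[i - 1] if i > 0 else 0
    end_index = image_num_patches_indices[i]
    print(i, (int(start_index), int(end_index)))           # (0, 3), (1, (3, 4)) -> prints (0, 3), (3, 4), (4, 6)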
    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["ProcessorMixin"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("image_num_patches", None)
        mm_inputs.pop("video_patch_indices", None)
        mm_inputs.pop("video_num_patches", None)
        return mm_inputs


class KimiVLPlugin(BasePlugin):
    @override
    def process_messages(self, messages, images, videos, audios, processor):
@@ -1603,6 +1779,7 @@ class VideoLlavaPlugin(BasePlugin):
PLUGINS = {
    "base": BasePlugin,
    "gemma3": Gemma3Plugin,
    "intern_vl": InternVLPlugin,
    "kimi_vl": KimiVLPlugin,
    "llama4": Llama4Plugin,
    "llava": LlavaPlugin,

View File

@@ -923,6 +923,20 @@ register_template(
)


register_template(
    name="intern_vl",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    default_system=(
        "你是书生·万象，英文名是InternVL，是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
    ),
    stop_words=["<|im_end|>"],
    mm_plugin=get_mm_plugin(name="intern_vl", image_token="<image>", video_token="<video>"),
)
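The formatters above follow a ChatML-style layout. A hand-assembled sketch of a rendered single-turn prompt (hedged: the optional bos_token prefix from format_prefix is omitted here, since it depends on the tokenizer; the image placeholder is expanded separately by the mm plugin):

system = "You are a helpful assistant."
user = "Describe the image. <image>"

prompt = (
    f"<|im_start|>system\n{system}<|im_end|>\n"
    f"<|im_start|>user\n{user}<|im_end|>\n"
    "<|im_start|>assistant\n"
)
# The assistant reply is appended as "{content}<|im_end|>\n", and generation stops at <|im_end|>.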


register_template(
    name="kimi_vl",
    format_user=StringFormatter(

View File

@@ -965,6 +965,35 @@ register_model_group(
)


register_model_group(
    models={
        "InternVL2_5-1B-MPO": {
            DownloadSource.DEFAULT: "kingsley01/InternVL2_5-1B-MPO-hf",
        },
        "InternVL2_5-2B-MPO": {
            DownloadSource.DEFAULT: "kingsley01/InternVL2_5-2B-MPO-hf",
        },
        "InternVL2_5-4B-MPO": {
            DownloadSource.DEFAULT: "kingsley01/InternVL2_5-4B-MPO-hf",
        },
        "InternVL2_5-8B-MPO": {
            DownloadSource.DEFAULT: "kingsley01/InternVL2_5-8B-MPO-hf",
        },
        "InternVL3-1B-hf": {
            DownloadSource.DEFAULT: "kingsley01/InternVL3-1B-hf",
        },
        "InternVL3-2B-hf": {
            DownloadSource.DEFAULT: "kingsley01/InternVL3-2B-hf",
        },
        "InternVL3-8B-hf": {
            DownloadSource.DEFAULT: "kingsley01/InternVL3-8B-hf",
        },
    },
    template="intern_vl",
    multimodal=True,
)


register_model_group(
    models={
        "Jamba-v0.1": {

View File

@@ -198,6 +198,11 @@ def patch_target_modules(
    return target_modules


_register_composite_model(
    model_type="internvl",
)


_register_composite_model(
    model_type="gemma3",
)

View File

@@ -157,6 +157,24 @@ def test_gemma3_plugin():
    _check_plugin(**check_inputs)


@pytest.mark.xfail(reason="cache failure.")
def test_internvl_plugin():
    image_seqlen = 256
    tokenizer_module = _load_tokenizer_module(model_name_or_path="kingsley01/InternVL2_5-1B-MPO-hf")
    internvl_plugin = get_mm_plugin("intern_vl", image_token="<image>", video_token="<video>")
    check_inputs = {"plugin": internvl_plugin, **tokenizer_module}
    check_inputs["expected_mm_messages"] = [
        {
            key: value.replace("<image>", f"<img>{'<IMG_CONTEXT>' * image_seqlen * 1}</img>")
            for key, value in message.items()
        }
        for message in MM_MESSAGES
    ]
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
    check_inputs["expected_mm_inputs"].pop("num_patches", None)
    _check_plugin(**check_inputs)


@pytest.mark.xfail(reason="Unknown error.")
def test_llama4_plugin():
    tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)

View File

@@ -1,2 +1,2 @@
# change if test fails
0.9.3.102
0.9.3.103