From 2e518f255fe807b6bf6fb0eef9b16fa36b3edecc Mon Sep 17 00:00:00 2001
From: Kingsley <82590017+Kuangdd01@users.noreply.github.com>
Date: Thu, 17 Apr 2025 00:31:30 +0800
Subject: [PATCH] [model] support InternVL 2.5-3 series (#7258)

* add internvl and rebase

* fix for internvl2&3

* remove lines

* fix video_inputs & lint

* nit

* add constants

* remove lines

* fix

* fix error

* pass ci

* pass ci

* skip internvl & nit

---
 README.md                                    |   1 +
 README_zh.md                                 |   1 +
 src/llamafactory/data/mm_plugin.py           | 179 ++++++++++++++++++-
 src/llamafactory/data/template.py            |  14 ++
 src/llamafactory/extras/constants.py         |  29 +++
 src/llamafactory/model/model_utils/visual.py |   5 +
 tests/data/test_mm_plugin.py                 |  18 ++
 tests/version.txt                            |   2 +-
 8 files changed, 247 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 72d73688..699b85e8 100644
--- a/README.md
+++ b/README.md
@@ -247,6 +247,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
 | [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
+| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
 | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
diff --git a/README_zh.md b/README_zh.md
index dc0aabe2..058d443a 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -250,6 +250,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
 | [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
+| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
 | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index e07b764d..cfbeefd2 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -25,7 +25,12 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
 import numpy as np
 import torch
-from transformers.image_utils import get_image_size, to_numpy_array
+from transformers.image_utils import (
+    get_image_size,
+    make_batched_videos,
+    make_flat_list_of_images,
+    to_numpy_array,
+)
 from typing_extensions import override
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -82,6 +87,20 @@ if TYPE_CHECKING:
     pass
 
 
+def _concatenate_list(input_list):
+    r"""Concatenate a list of lists, numpy arrays or torch tensors.
+
+    Returns:
+        the flattened list, or the concatenated numpy array or torch tensor.
+    """
+    if isinstance(input_list[0], list):
+        return [item for sublist in input_list for item in sublist]
+    elif isinstance(input_list[0], np.ndarray):
+        return np.concatenate(input_list, axis=0)
+    elif isinstance(input_list[0], torch.Tensor):
+        return torch.cat(input_list, dim=0)
+
+
 def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
     r"""Get paligemma token type ids for computing loss.
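Reviewer note: a quick sketch of what the new `_concatenate_list` helper returns for each supported element type (assumes the patched `llamafactory.data.mm_plugin` module is importable; the inputs are made up for illustration):

```python
import numpy as np
import torch

from llamafactory.data.mm_plugin import _concatenate_list

# A list of lists is flattened into one list.
assert _concatenate_list([[1, 2], [3]]) == [1, 2, 3]

# A list of numpy arrays is concatenated along axis 0.
assert _concatenate_list([np.zeros((2, 4)), np.ones((3, 4))]).shape == (5, 4)

# A list of torch tensors is concatenated along dim 0.
assert _concatenate_list([torch.zeros(2, 4), torch.ones(3, 4)]).shape == (5, 4)
```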
@@ -467,6 +486,163 @@ class Gemma3Plugin(BasePlugin):
 
 
 @dataclass
+class InternVLPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        num_image_tokens = 0
+        num_video_tokens = 0
+        image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+
+        image_pixel_patch_list = mm_inputs.get("image_num_patches", None)  # patch counts per image
+        video_num_patches = mm_inputs.get("video_num_patches", None)  # patch counts for all frames across videos
+        video_patch_indices = mm_inputs.get("video_patch_indices", None)  # cumulative frame counts per video
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(image_pixel_patch_list):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+                content = content.replace(
+                    IMAGE_PLACEHOLDER,
+                    f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
+                    1,
+                )
+                num_image_tokens += 1
+            message["content"] = content
+
+            while VIDEO_PLACEHOLDER in content:
+                if num_video_tokens >= len(video_patch_indices):
+                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+                current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
+                end_patch_index = video_patch_indices[num_video_tokens]
+                num_patches = list(video_num_patches[current_patch_index:end_patch_index])
+                video_replaced_prompt = "\n".join(
+                    f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
+                    for i in range(len(num_patches))
+                )
+                content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
+                num_video_tokens += 1
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: "ProcessorMixin",
+        **kwargs,
+    ) -> dict[str, "torch.Tensor"]:
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
+        attributes = ["crop_to_patches", "min_patches", "max_patches"]  # needed by the image processor
+        image_kwargs = {attr: getattr(image_processor, attr, None) for attr in attributes}
+
+        mm_inputs = {}
+        image_video_patches = []
+
+        if len(images) != 0 and isinstance(images[0], str):
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )["images"]
+
+        if len(videos) != 0 and isinstance(videos[0], str):
+            videos = self._regularize_videos(
+                videos,
+                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
+            )["videos"]
+
+        if len(images) != 0:
+            images = make_flat_list_of_images(images)
+            image_inputs = image_processor(images=images, **image_kwargs)
+            image_num_patches = image_inputs.pop("num_patches")
+            image_pixel_values = image_inputs.pop("pixel_values")
+            image_num_patches_indices = np.cumsum(image_num_patches)
+
+        if len(videos) != 0:
+            videos = make_batched_videos(videos)
+            num_frames_per_video = [len(video) for video in videos]
+            patch_indices = np.cumsum(num_frames_per_video)
+            image_kwargs["crop_to_patches"] = False
+            video_inputs = image_processor(images=videos, **image_kwargs)
+            video_num_patches = video_inputs.pop("num_patches")
+            video_pixel_values = video_inputs.pop("pixel_values")
+            video_num_patches_indices = np.cumsum(video_num_patches)
+
+        # NOTE: interleaved image-video inputs are not supported yet
+        if len(images) != 0 and image_pixel_values is not None:
+            for i in range(len(images)):
+                start_index = image_num_patches_indices[i - 1] if i > 0 else 0
+                end_index = image_num_patches_indices[i]
+                image_video_patches.append(image_pixel_values[start_index:end_index])
+
+        if len(videos) != 0 and video_pixel_values is not None:
+            for i in range(len(videos)):
+                current_patch_index = patch_indices[i - 1] if i > 0 else 0
+                end_patch_index = patch_indices[i]
+                # cumsum offset: total patches of all frames before this video
+                start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0
+                end_index = video_num_patches_indices[end_patch_index - 1]
+                image_video_patches.append(video_pixel_values[start_index:end_index])
+
+        if len(images) != 0 or len(videos) != 0:
+            pixel_values_list = _concatenate_list(image_video_patches)
+            mm_inputs["pixel_values"] = torch.stack(
+                [torch.tensor(patch_ndarray) for patch_ndarray in pixel_values_list]
+            )
+
+        if len(images) != 0:
+            mm_inputs.update({"image_num_patches": image_num_patches})
+
+        if len(videos) != 0:
+            mm_inputs.update({"video_patch_indices": patch_indices})
+            mm_inputs.update({"video_num_patches": video_num_patches})
+
+        return mm_inputs
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+        self._validate_input(processor, images, videos, audios)
+
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        mm_inputs.pop("image_num_patches", None)
+        mm_inputs.pop("video_patch_indices", None)
+        mm_inputs.pop("video_num_patches", None)
+
+        return mm_inputs
+
+
+@dataclass
 class KimiVLPlugin(BasePlugin):
     @override
     def process_messages(self, messages, images, videos, audios, processor):
@@ -1603,6 +1779,7 @@ class VideoLlavaPlugin(BasePlugin):
 PLUGINS = {
     "base": BasePlugin,
     "gemma3": Gemma3Plugin,
+    "intern_vl": InternVLPlugin,
     "kimi_vl": KimiVLPlugin,
     "llama4": Llama4Plugin,
     "llava": LlavaPlugin,
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 0a334787..b02d6df2 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -923,6 +923,20 @@ register_template(
 )
 
 
+register_template(
+    name="intern_vl",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    default_system=(
+        "你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
+    ),
+    stop_words=["<|im_end|>"],
image_token="", video_token="