mirror of
				https://github.com/hiyouga/LLaMA-Factory.git
				synced 2025-11-04 18:02:19 +08:00 
			
		
		
		
	[model] support intern-VL 2.5-3 series (#7258)
* add internvl and rebase * fix for internvl2&3 * remove lines * fix video_inputs & lint * nit * add constants * remove lines * fix * fix error * pass ci * pass ci * skip internvl & nit
This commit is contained in:
		
							parent
							
								
									8f88a4e6a4
								
							
						
					
					
						commit
						2e518f255f
					
				@ -247,6 +247,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | 
			
		||||
| [Hunyuan](https://huggingface.co/tencent/)                        | 7B                               | hunyuan             |
 | 
			
		||||
| [Index](https://huggingface.co/IndexTeam)                         | 1.9B                             | index               |
 | 
			
		||||
| [InternLM 2-3](https://huggingface.co/internlm)                   | 7B/8B/20B                        | intern2             |
 | 
			
		||||
| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL)        | 1B/2B/4B/8B/9B/14B/26B/38B/78B   | intern_vl           |
 | 
			
		||||
| [Kimi-VL](https://huggingface.co/moonshotai)                      | 16B                              | kimi_vl             |
 | 
			
		||||
| [Llama](https://github.com/facebookresearch/llama)                | 7B/13B/33B/65B                   | -                   |
 | 
			
		||||
| [Llama 2](https://huggingface.co/meta-llama)                      | 7B/13B/70B                       | llama2              |
 | 
			
		||||
 | 
			
		||||
@ -250,6 +250,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | 
			
		||||
| [Hunyuan](https://huggingface.co/tencent/)                        | 7B                               | hunyuan             |
 | 
			
		||||
| [Index](https://huggingface.co/IndexTeam)                         | 1.9B                             | index               |
 | 
			
		||||
| [InternLM 2-3](https://huggingface.co/internlm)                   | 7B/8B/20B                        | intern2             |
 | 
			
		||||
| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL)        | 1B/2B/4B/8B/9B/14B/26B/38B/78B   | intern_vl           |
 | 
			
		||||
| [Kimi-VL](https://huggingface.co/moonshotai)                      | 16B                              | kimi_vl             |
 | 
			
		||||
| [Llama](https://github.com/facebookresearch/llama)                | 7B/13B/33B/65B                   | -                   |
 | 
			
		||||
| [Llama 2](https://huggingface.co/meta-llama)                      | 7B/13B/70B                       | llama2              |
 | 
			
		||||
 | 
			
		||||
@ -25,7 +25,12 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from transformers.image_utils import get_image_size, to_numpy_array
 | 
			
		||||
from transformers.image_utils import (
 | 
			
		||||
    get_image_size,
 | 
			
		||||
    make_batched_videos,
 | 
			
		||||
    make_flat_list_of_images,
 | 
			
		||||
    to_numpy_array,
 | 
			
		||||
)
 | 
			
		||||
from typing_extensions import override
 | 
			
		||||
 | 
			
		||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 | 
			
		||||
@ -82,6 +87,20 @@ if TYPE_CHECKING:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _concatenate_list(input_list):
 | 
			
		||||
    r"""Concatenate a list of lists, numpy arrays or torch tensors.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        a list of numpy arrays or torch tensors.
 | 
			
		||||
    """
 | 
			
		||||
    if isinstance(input_list[0], list):
 | 
			
		||||
        return [item for sublist in input_list for item in sublist]
 | 
			
		||||
    elif isinstance(input_list[0], np.ndarray):
 | 
			
		||||
        return np.concatenate(input_list, axis=0)
 | 
			
		||||
    elif isinstance(input_list[0], torch.Tensor):
 | 
			
		||||
        return torch.cat(input_list, dim=0)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
 | 
			
		||||
    r"""Get paligemma token type ids for computing loss.
 | 
			
		||||
 | 
			
		||||
@ -467,6 +486,163 @@ class Gemma3Plugin(BasePlugin):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class InternVLPlugin(BasePlugin):
 | 
			
		||||
    @override
 | 
			
		||||
    def process_messages(
 | 
			
		||||
        self,
 | 
			
		||||
        messages: list[dict[str, str]],
 | 
			
		||||
        images: list["ImageInput"],
 | 
			
		||||
        videos: list["VideoInput"],
 | 
			
		||||
        audios: list["AudioInput"],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> list[dict[str, str]]:
 | 
			
		||||
        self._validate_input(processor, images, videos, audios)
 | 
			
		||||
        num_image_tokens = 0
 | 
			
		||||
        num_video_tokens = 0
 | 
			
		||||
        image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
 | 
			
		||||
        messages = deepcopy(messages)
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
 | 
			
		||||
        image_pixel_patch_list = mm_inputs.get("image_num_patches", None)  # pathes of images
 | 
			
		||||
        video_num_patches = mm_inputs.get("video_num_patches", None)  # all patches for frames of videos
 | 
			
		||||
        video_patch_indices = mm_inputs.get("video_patch_indices", None)  # num frames of per video
 | 
			
		||||
 | 
			
		||||
        for message in messages:
 | 
			
		||||
            content = message["content"]
 | 
			
		||||
            while IMAGE_PLACEHOLDER in content:
 | 
			
		||||
                if num_image_tokens >= len(image_pixel_patch_list):
 | 
			
		||||
                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
 | 
			
		||||
                content = content.replace(
 | 
			
		||||
                    IMAGE_PLACEHOLDER,
 | 
			
		||||
                    f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
 | 
			
		||||
                    1,
 | 
			
		||||
                )
 | 
			
		||||
                num_image_tokens += 1
 | 
			
		||||
            message["content"] = content
 | 
			
		||||
 | 
			
		||||
            while VIDEO_PLACEHOLDER in content:
 | 
			
		||||
                if num_video_tokens >= len(video_patch_indices):
 | 
			
		||||
                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
 | 
			
		||||
                current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
 | 
			
		||||
                end_patch_index = video_patch_indices[num_video_tokens]
 | 
			
		||||
                num_patches = list(video_num_patches[current_patch_index:end_patch_index])
 | 
			
		||||
                video_replaced_prompt = "\n".join(
 | 
			
		||||
                    f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
 | 
			
		||||
                    for i in range(len(num_patches))
 | 
			
		||||
                )
 | 
			
		||||
                content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
 | 
			
		||||
                num_video_tokens += 1
 | 
			
		||||
            message["content"] = content
 | 
			
		||||
 | 
			
		||||
        if len(images) != num_image_tokens:
 | 
			
		||||
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
        if len(videos) != num_video_tokens:
 | 
			
		||||
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
 | 
			
		||||
 | 
			
		||||
        return messages
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def _get_mm_inputs(
 | 
			
		||||
        self,
 | 
			
		||||
        images: list["ImageInput"],
 | 
			
		||||
        videos: list["VideoInput"],
 | 
			
		||||
        audios: list["AudioInput"],
 | 
			
		||||
        processor: "ProcessorMixin",
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> dict[str, "torch.Tensor"]:
 | 
			
		||||
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
 | 
			
		||||
        attributes = ["crop_to_patches", "min_patches", "max_patches"]  # need for image processor
 | 
			
		||||
        image_kwargs = {attr: getattr(image_processor, attr, None) for attr in attributes}
 | 
			
		||||
 | 
			
		||||
        mm_inputs = {}
 | 
			
		||||
        image_video_patches = []
 | 
			
		||||
 | 
			
		||||
        if len(images) != 0 and isinstance(images[0], str):
 | 
			
		||||
            images = self._regularize_images(
 | 
			
		||||
                images,
 | 
			
		||||
                image_max_pixels=getattr(processor, "image_max_pixels", 1024 * 1024),
 | 
			
		||||
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
 | 
			
		||||
            )["images"]
 | 
			
		||||
 | 
			
		||||
        if len(videos) != 0 and isinstance(videos[0], str):
 | 
			
		||||
            videos = self._regularize_videos(
 | 
			
		||||
                videos,
 | 
			
		||||
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
 | 
			
		||||
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
 | 
			
		||||
                video_fps=getattr(processor, "video_fps", 2.0),
 | 
			
		||||
                video_maxlen=getattr(processor, "video_maxlen", 128),
 | 
			
		||||
            )["videos"]
 | 
			
		||||
 | 
			
		||||
        if len(images) != 0:
 | 
			
		||||
            images = make_flat_list_of_images(images)
 | 
			
		||||
            image_inputs = image_processor(images=images, **image_kwargs)
 | 
			
		||||
            image_num_patches = image_inputs.pop("num_patches")
 | 
			
		||||
            image_pixel_values = image_inputs.pop("pixel_values")
 | 
			
		||||
            image_num_patches_indices = np.cumsum(image_num_patches)
 | 
			
		||||
 | 
			
		||||
        if len(videos) != 0:
 | 
			
		||||
            videos = make_batched_videos(videos)
 | 
			
		||||
            num_frames_per_video = [len(video) for video in videos]
 | 
			
		||||
            patch_indices = np.cumsum(num_frames_per_video)
 | 
			
		||||
            image_kwargs["crop_to_patches"] = False
 | 
			
		||||
            video_inputs = image_processor(images=videos, **image_kwargs)
 | 
			
		||||
            video_num_patches = video_inputs.pop("num_patches")
 | 
			
		||||
            video_pixel_values = video_inputs.pop("pixel_values")
 | 
			
		||||
            video_num_patches_indices = np.cumsum(video_num_patches)
 | 
			
		||||
 | 
			
		||||
        # NOT SUPPORT IMAGE VIDEO INTERLEAVED
 | 
			
		||||
        if len(images) != 0 and image_pixel_values is not None:
 | 
			
		||||
            for i in range(len(images)):
 | 
			
		||||
                start_index = image_num_patches_indices[i - 1] if i > 0 else 0
 | 
			
		||||
                end_index = image_num_patches_indices[i]
 | 
			
		||||
                image_video_patches.append(image_pixel_values[start_index:end_index])
 | 
			
		||||
 | 
			
		||||
        if len(videos) != 0 and video_pixel_values is not None:
 | 
			
		||||
            for i in range(len(videos)):
 | 
			
		||||
                current_patch_index = patch_indices[i - 1] if i > 0 else 0
 | 
			
		||||
                end_patch_index = patch_indices[i]
 | 
			
		||||
                start_index = video_num_patches_indices[current_patch_index] if i > 0 else 0
 | 
			
		||||
                end_index = video_num_patches_indices[end_patch_index - 1]
 | 
			
		||||
                image_video_patches.append(video_pixel_values[start_index:end_index])
 | 
			
		||||
 | 
			
		||||
        if len(images) != 0 or len(videos) != 0:
 | 
			
		||||
            pixel_values_list = _concatenate_list(image_video_patches)
 | 
			
		||||
            mm_inputs["pixel_values"] = torch.stack(
 | 
			
		||||
                [torch.tensor(patch_ndarray) for patch_ndarray in pixel_values_list]
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if len(images) != 0:
 | 
			
		||||
            mm_inputs.update({"image_num_patches": image_num_patches})
 | 
			
		||||
 | 
			
		||||
        if len(videos) != 0:
 | 
			
		||||
            mm_inputs.update({"video_patch_indices": patch_indices})
 | 
			
		||||
            mm_inputs.update({"video_num_patches": video_num_patches})
 | 
			
		||||
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
    @override
 | 
			
		||||
    def get_mm_inputs(
 | 
			
		||||
        self,
 | 
			
		||||
        images: list["ImageInput"],
 | 
			
		||||
        videos: list["VideoInput"],
 | 
			
		||||
        audios: list["AudioInput"],
 | 
			
		||||
        imglens: list[int],
 | 
			
		||||
        vidlens: list[int],
 | 
			
		||||
        audlens: list[int],
 | 
			
		||||
        batch_ids: list[list[int]],
 | 
			
		||||
        processor: Optional["ProcessorMixin"],
 | 
			
		||||
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
 | 
			
		||||
        self._validate_input(processor, images, videos, audios)
 | 
			
		||||
 | 
			
		||||
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
 | 
			
		||||
        mm_inputs.pop("image_num_patches", None)
 | 
			
		||||
        mm_inputs.pop("video_patch_indices", None)
 | 
			
		||||
        mm_inputs.pop("video_num_patches", None)
 | 
			
		||||
 | 
			
		||||
        return mm_inputs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class KimiVLPlugin(BasePlugin):
 | 
			
		||||
    @override
 | 
			
		||||
    def process_messages(self, messages, images, videos, audios, processor):
 | 
			
		||||
@ -1603,6 +1779,7 @@ class VideoLlavaPlugin(BasePlugin):
 | 
			
		||||
PLUGINS = {
 | 
			
		||||
    "base": BasePlugin,
 | 
			
		||||
    "gemma3": Gemma3Plugin,
 | 
			
		||||
    "intern_vl": InternVLPlugin,
 | 
			
		||||
    "kimi_vl": KimiVLPlugin,
 | 
			
		||||
    "llama4": Llama4Plugin,
 | 
			
		||||
    "llava": LlavaPlugin,
 | 
			
		||||
 | 
			
		||||
@ -923,6 +923,20 @@ register_template(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_template(
 | 
			
		||||
    name="intern_vl",
 | 
			
		||||
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
 | 
			
		||||
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
 | 
			
		||||
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
 | 
			
		||||
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 | 
			
		||||
    default_system=(
 | 
			
		||||
        "你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。"
 | 
			
		||||
    ),
 | 
			
		||||
    stop_words=["<|im_end|>"],
 | 
			
		||||
    mm_plugin=get_mm_plugin(name="intern_vl", image_token="<image>", video_token="<video>"),
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_template(
 | 
			
		||||
    name="kimi_vl",
 | 
			
		||||
    format_user=StringFormatter(
 | 
			
		||||
 | 
			
		||||
@ -965,6 +965,35 @@ register_model_group(
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "InternVL2_5-1B-MPO": {
 | 
			
		||||
            DownloadSource.DEFAULT: "kingsley01/InternVL2_5-1B-MPO-hf",
 | 
			
		||||
        },
 | 
			
		||||
        "InternVL2_5-2B-MPO": {
 | 
			
		||||
            DownloadSource.DEFAULT: "kingsley01/InternVL2_5-2B-MPO-hf",
 | 
			
		||||
        },
 | 
			
		||||
        "InternVL2_5-4B-MPO": {
 | 
			
		||||
            DownloadSource.DEFAULT: "kingsley01/InternVL2_5-4B-MPO-hf",
 | 
			
		||||
        },
 | 
			
		||||
        "InternVL2_5-8B-MPO": {
 | 
			
		||||
            DownloadSource.DEFAULT: "kingsley01/InternVL2_5-8B-MPO-hf",
 | 
			
		||||
        },
 | 
			
		||||
        "InternVL3-1B-hf": {
 | 
			
		||||
            DownloadSource.DEFAULT: "kingsley01/InternVL3-1B-hf",
 | 
			
		||||
        },
 | 
			
		||||
        "InternVL3-2B-hf": {
 | 
			
		||||
            DownloadSource.DEFAULT: "kingsley01/InternVL3-2B-hf",
 | 
			
		||||
        },
 | 
			
		||||
        "InternVL3-8B-hf": {
 | 
			
		||||
            DownloadSource.DEFAULT: "kingsley01/InternVL3-8B-hf",
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    template="intern_vl",
 | 
			
		||||
    multimodal=True,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
register_model_group(
 | 
			
		||||
    models={
 | 
			
		||||
        "Jamba-v0.1": {
 | 
			
		||||
 | 
			
		||||
@ -198,6 +198,11 @@ def patch_target_modules(
 | 
			
		||||
        return target_modules
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_register_composite_model(
 | 
			
		||||
    model_type="internvl",
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_register_composite_model(
 | 
			
		||||
    model_type="gemma3",
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@ -157,6 +157,24 @@ def test_gemma3_plugin():
 | 
			
		||||
    _check_plugin(**check_inputs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.xfail(reason="cache failure.")
 | 
			
		||||
def test_internvl_plugin():
 | 
			
		||||
    image_seqlen = 256
 | 
			
		||||
    tokenizer_module = _load_tokenizer_module(model_name_or_path="kingsley01/InternVL2_5-1B-MPO-hf")
 | 
			
		||||
    internvl_plugin = get_mm_plugin("intern_vl", image_token="<image>", video_token="<video>")
 | 
			
		||||
    check_inputs = {"plugin": internvl_plugin, **tokenizer_module}
 | 
			
		||||
    check_inputs["expected_mm_messages"] = [
 | 
			
		||||
        {
 | 
			
		||||
            key: value.replace("<image>", f"<img>{'<IMG_CONTEXT>' * image_seqlen * 1}</img>")
 | 
			
		||||
            for key, value in message.items()
 | 
			
		||||
        }
 | 
			
		||||
        for message in MM_MESSAGES
 | 
			
		||||
    ]
 | 
			
		||||
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
 | 
			
		||||
    check_inputs["expected_mm_inputs"].pop("num_patches", None)
 | 
			
		||||
    _check_plugin(**check_inputs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.xfail(reason="Unknown error.")
 | 
			
		||||
def test_llama4_plugin():
 | 
			
		||||
    tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
 | 
			
		||||
 | 
			
		||||
@ -1,2 +1,2 @@
 | 
			
		||||
# change if test fails
 | 
			
		||||
0.9.3.102
 | 
			
		||||
0.9.3.103
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user