[data] fix internvl plugin (#7817)

This commit is contained in:
Kingsley 2025-04-23 00:58:22 +08:00 committed by GitHub
parent 2b7d564e3b
commit 1dd67eb042
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 24 additions and 28 deletions

View File

@ -21,7 +21,7 @@ import re
from copy import deepcopy
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Optional, TypedDict, Union
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
import numpy as np
import torch
@ -86,20 +86,6 @@ if TYPE_CHECKING:
pass
def _concatenate_list(input_list: list[Any]) -> Union[list[Any], "NDArray", "torch.Tensor"]:
r"""Concatenate a list of lists, numpy arrays or torch tensors.
Returns:
a list of numpy arrays or torch tensors.
"""
if isinstance(input_list[0], list):
return [item for sublist in input_list for item in sublist]
elif isinstance(input_list[0], np.ndarray):
return np.concatenate(input_list, axis=0)
elif isinstance(input_list[0], torch.Tensor):
return torch.cat(input_list, dim=0)
def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
r"""Get paligemma token type ids for computing loss.
@ -496,8 +482,15 @@ class InternVLPlugin(BasePlugin):
**kwargs,
) -> dict[str, "torch.Tensor"]:
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
attributes = ["crop_to_patches", "min_patches", "max_patches"] # need for image processor
image_kwargs = {attr: getattr(image_processor, attr, None) for attr in attributes}
image_processor_kwargs = {}
if getattr(processor, "crop_to_patches", False):
image_processor_kwargs.update(
{
"crop_to_patches": True,
"max_patches": 12,
"min_patches": 1,
}
)
mm_inputs = {}
image_video_patches = []
@ -520,7 +513,7 @@ class InternVLPlugin(BasePlugin):
if len(images) != 0:
images = make_flat_list_of_images(images)
image_inputs = image_processor(images=images, **image_kwargs)
image_inputs = image_processor(images=images, return_tensors="pt", **image_processor_kwargs)
image_num_patches = image_inputs.pop("num_patches")
image_pixel_values = image_inputs.pop("pixel_values")
image_num_patches_indices = np.cumsum(image_num_patches)
@ -529,8 +522,8 @@ class InternVLPlugin(BasePlugin):
videos = make_batched_videos(videos)
num_frames_per_video = [len(video) for video in videos]
patch_indices = np.cumsum(num_frames_per_video)
image_kwargs["crop_to_patches"] = False
video_inputs = image_processor(images=videos, **image_kwargs)
image_processor_kwargs["crop_to_patches"] = False
video_inputs = image_processor(images=videos, return_tensors="pt", **image_processor_kwargs)
video_num_patches = video_inputs.pop("num_patches")
video_pixel_values = video_inputs.pop("pixel_values")
video_num_patches_indices = np.cumsum(video_num_patches)
@ -543,18 +536,16 @@ class InternVLPlugin(BasePlugin):
image_video_patches.append(image_pixel_values[start_index:end_index])
if len(videos) != 0 and video_pixel_values is not None:
patch_indices_with_prefix = [0] + list(patch_indices)
for i in range(len(videos)):
current_patch_index = patch_indices[i - 1] if i > 0 else 0
end_patch_index = patch_indices[i]
start_index = video_num_patches_indices[current_patch_index] if i > 0 else 0
current_patch_index = patch_indices_with_prefix[i]
end_patch_index = patch_indices_with_prefix[i + 1]
start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0
end_index = video_num_patches_indices[end_patch_index - 1]
image_video_patches.append(video_pixel_values[start_index:end_index])
if len(images) != 0 or len(videos) != 0:
pixel_values_list = _concatenate_list(image_video_patches)
# in the latest version of transformers,
# the pixel_values is a list of tensors not ndarray
mm_inputs["pixel_values"] = torch.stack(pixel_values_list)
mm_inputs["pixel_values"] = torch.cat(image_video_patches, dim=0)
if len(images) != 0:
mm_inputs.update({"image_num_patches": image_num_patches})

View File

@ -231,6 +231,10 @@ class ProcessorArguments:
default=False,
metadata={"help": "Use pan and scan to process image for gemma3."},
)
crop_to_patches: bool = field(
default=False,
metadata={"help": "Whether to crop the image to patches for internvl."},
)
use_audio_in_video: bool = field(
default=False,
metadata={"help": "Whether or not to use audio in video inputs."},

View File

@ -80,6 +80,7 @@ def patch_processor(
setattr(processor, "image_max_pixels", model_args.image_max_pixels)
setattr(processor, "image_min_pixels", model_args.image_min_pixels)
setattr(processor, "image_do_pan_and_scan", model_args.image_do_pan_and_scan)
setattr(processor, "crop_to_patches", model_args.crop_to_patches)
setattr(processor, "video_max_pixels", model_args.video_max_pixels)
setattr(processor, "video_min_pixels", model_args.video_min_pixels)
setattr(processor, "video_fps", model_args.video_fps)

View File

@ -1,2 +1,2 @@
# change if test fails or cache is outdated
0.9.3.103
0.9.3.104