mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
[data] fix internvl plugin (#7817)
This commit is contained in:
parent
2b7d564e3b
commit
1dd67eb042
@ -21,7 +21,7 @@ import re
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Optional, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -86,20 +86,6 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
def _concatenate_list(input_list: list[Any]) -> Union[list[Any], "NDArray", "torch.Tensor"]:
|
||||
r"""Concatenate a list of lists, numpy arrays or torch tensors.
|
||||
|
||||
Returns:
|
||||
a list of numpy arrays or torch tensors.
|
||||
"""
|
||||
if isinstance(input_list[0], list):
|
||||
return [item for sublist in input_list for item in sublist]
|
||||
elif isinstance(input_list[0], np.ndarray):
|
||||
return np.concatenate(input_list, axis=0)
|
||||
elif isinstance(input_list[0], torch.Tensor):
|
||||
return torch.cat(input_list, dim=0)
|
||||
|
||||
|
||||
def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
|
||||
r"""Get paligemma token type ids for computing loss.
|
||||
|
||||
@ -496,8 +482,15 @@ class InternVLPlugin(BasePlugin):
|
||||
**kwargs,
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||
attributes = ["crop_to_patches", "min_patches", "max_patches"] # need for image processor
|
||||
image_kwargs = {attr: getattr(image_processor, attr, None) for attr in attributes}
|
||||
image_processor_kwargs = {}
|
||||
if getattr(processor, "crop_to_patches", False):
|
||||
image_processor_kwargs.update(
|
||||
{
|
||||
"crop_to_patches": True,
|
||||
"max_patches": 12,
|
||||
"min_patches": 1,
|
||||
}
|
||||
)
|
||||
|
||||
mm_inputs = {}
|
||||
image_video_patches = []
|
||||
@ -520,7 +513,7 @@ class InternVLPlugin(BasePlugin):
|
||||
|
||||
if len(images) != 0:
|
||||
images = make_flat_list_of_images(images)
|
||||
image_inputs = image_processor(images=images, **image_kwargs)
|
||||
image_inputs = image_processor(images=images, return_tensors="pt", **image_processor_kwargs)
|
||||
image_num_patches = image_inputs.pop("num_patches")
|
||||
image_pixel_values = image_inputs.pop("pixel_values")
|
||||
image_num_patches_indices = np.cumsum(image_num_patches)
|
||||
@ -529,8 +522,8 @@ class InternVLPlugin(BasePlugin):
|
||||
videos = make_batched_videos(videos)
|
||||
num_frames_per_video = [len(video) for video in videos]
|
||||
patch_indices = np.cumsum(num_frames_per_video)
|
||||
image_kwargs["crop_to_patches"] = False
|
||||
video_inputs = image_processor(images=videos, **image_kwargs)
|
||||
image_processor_kwargs["crop_to_patches"] = False
|
||||
video_inputs = image_processor(images=videos, return_tensors="pt", **image_processor_kwargs)
|
||||
video_num_patches = video_inputs.pop("num_patches")
|
||||
video_pixel_values = video_inputs.pop("pixel_values")
|
||||
video_num_patches_indices = np.cumsum(video_num_patches)
|
||||
@ -543,18 +536,16 @@ class InternVLPlugin(BasePlugin):
|
||||
image_video_patches.append(image_pixel_values[start_index:end_index])
|
||||
|
||||
if len(videos) != 0 and video_pixel_values is not None:
|
||||
patch_indices_with_prefix = [0] + list(patch_indices)
|
||||
for i in range(len(videos)):
|
||||
current_patch_index = patch_indices[i - 1] if i > 0 else 0
|
||||
end_patch_index = patch_indices[i]
|
||||
start_index = video_num_patches_indices[current_patch_index] if i > 0 else 0
|
||||
current_patch_index = patch_indices_with_prefix[i]
|
||||
end_patch_index = patch_indices_with_prefix[i + 1]
|
||||
start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0
|
||||
end_index = video_num_patches_indices[end_patch_index - 1]
|
||||
image_video_patches.append(video_pixel_values[start_index:end_index])
|
||||
|
||||
if len(images) != 0 or len(videos) != 0:
|
||||
pixel_values_list = _concatenate_list(image_video_patches)
|
||||
# in the latest version of transformers,
|
||||
# the pixel_values is a list of tensors not ndarray
|
||||
mm_inputs["pixel_values"] = torch.stack(pixel_values_list)
|
||||
mm_inputs["pixel_values"] = torch.cat(image_video_patches, dim=0)
|
||||
|
||||
if len(images) != 0:
|
||||
mm_inputs.update({"image_num_patches": image_num_patches})
|
||||
|
@ -231,6 +231,10 @@ class ProcessorArguments:
|
||||
default=False,
|
||||
metadata={"help": "Use pan and scan to process image for gemma3."},
|
||||
)
|
||||
crop_to_patches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to crop the image to patches for internvl."},
|
||||
)
|
||||
use_audio_in_video: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use audio in video inputs."},
|
||||
|
@ -80,6 +80,7 @@ def patch_processor(
|
||||
setattr(processor, "image_max_pixels", model_args.image_max_pixels)
|
||||
setattr(processor, "image_min_pixels", model_args.image_min_pixels)
|
||||
setattr(processor, "image_do_pan_and_scan", model_args.image_do_pan_and_scan)
|
||||
setattr(processor, "crop_to_patches", model_args.crop_to_patches)
|
||||
setattr(processor, "video_max_pixels", model_args.video_max_pixels)
|
||||
setattr(processor, "video_min_pixels", model_args.video_min_pixels)
|
||||
setattr(processor, "video_fps", model_args.video_fps)
|
||||
|
@ -1,2 +1,2 @@
|
||||
# change if test fails or cache is outdated
|
||||
0.9.3.103
|
||||
0.9.3.104
|
||||
|
Loading…
x
Reference in New Issue
Block a user