Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-14 15:52:49 +08:00)
[data] fix internvl plugin (#7817)
commit fa0eb91f1f (parent 49f9ed0232)
src/llamafactory/data/mm_plugin.py
@@ -21,7 +21,7 @@ import re
 from copy import deepcopy
 from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Optional, TypedDict, Union
+from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
 
 import numpy as np
 import torch
@@ -86,20 +86,6 @@ if TYPE_CHECKING:
             pass
 
 
-def _concatenate_list(input_list: list[Any]) -> Union[list[Any], "NDArray", "torch.Tensor"]:
-    r"""Concatenate a list of lists, numpy arrays or torch tensors.
-
-    Returns:
-        a list of numpy arrays or torch tensors.
-    """
-    if isinstance(input_list[0], list):
-        return [item for sublist in input_list for item in sublist]
-    elif isinstance(input_list[0], np.ndarray):
-        return np.concatenate(input_list, axis=0)
-    elif isinstance(input_list[0], torch.Tensor):
-        return torch.cat(input_list, dim=0)
-
-
 def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
     r"""Get paligemma token type ids for computing loss.
 
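With the image processor now always called with return_tensors="pt" (see the InternVLPlugin hunks below), every entry collected into image_video_patches is a torch.Tensor, so the type-dispatching _concatenate_list helper becomes dead code and a plain torch.cat replaces it. A minimal sketch of the tensor case it used to cover (the shapes are illustrative, not InternVL's actual ones):

    import torch

    # Patches gathered from two media items, as the plugin collects them.
    patches = [torch.zeros(3, 3, 448, 448), torch.zeros(2, 3, 448, 448)]

    # The old helper's tensor branch and the new inline call are the same operation.
    merged = torch.cat(patches, dim=0)
    assert merged.shape == (5, 3, 448, 448)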
@@ -496,8 +482,15 @@ class InternVLPlugin(BasePlugin):
         **kwargs,
     ) -> dict[str, "torch.Tensor"]:
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
-        attributes = ["crop_to_patches", "min_patches", "max_patches"]  # need for image processor
-        image_kwargs = {attr: getattr(image_processor, attr, None) for attr in attributes}
+        image_processor_kwargs = {}
+        if getattr(processor, "crop_to_patches", False):
+            image_processor_kwargs.update(
+                {
+                    "crop_to_patches": True,
+                    "max_patches": 12,
+                    "min_patches": 1,
+                }
+            )
 
         mm_inputs = {}
         image_video_patches = []
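Instead of unconditionally copying crop_to_patches, min_patches, and max_patches off the image processor (which could pick up None values), the rewrite builds the kwargs only when the user opts in through the processor-level crop_to_patches flag, and otherwise leaves the image processor's own defaults untouched. A sketch of the new behavior, using a hypothetical build_image_processor_kwargs helper and SimpleNamespace as a stand-in for the patched processor:

    from types import SimpleNamespace

    def build_image_processor_kwargs(processor) -> dict:
        # Mirrors the new logic: tiling is opt-in, with fixed patch bounds.
        kwargs = {}
        if getattr(processor, "crop_to_patches", False):
            kwargs.update({"crop_to_patches": True, "max_patches": 12, "min_patches": 1})
        return kwargs

    assert build_image_processor_kwargs(SimpleNamespace()) == {}
    assert build_image_processor_kwargs(SimpleNamespace(crop_to_patches=True))["max_patches"] == 12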
@@ -520,7 +513,7 @@
 
         if len(images) != 0:
             images = make_flat_list_of_images(images)
-            image_inputs = image_processor(images=images, **image_kwargs)
+            image_inputs = image_processor(images=images, return_tensors="pt", **image_processor_kwargs)
             image_num_patches = image_inputs.pop("num_patches")
             image_pixel_values = image_inputs.pop("pixel_values")
             image_num_patches_indices = np.cumsum(image_num_patches)
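Passing return_tensors="pt" explicitly pins the output type: without it, the result depends on the processor class and transformers version (numpy arrays for slow processors, tensors for fast ones; the comment deleted further down alludes to exactly this drift), while "pt" always yields torch tensors, which the later torch.cat relies on. A small offline sketch, with ViTImageProcessor as an illustrative stand-in for InternVL's image processor:

    import numpy as np
    from transformers import ViTImageProcessor  # stand-in; the contract shown is generic

    processor = ViTImageProcessor()  # default config, no download required
    image = np.zeros((224, 224, 3), dtype=np.uint8)

    out_default = processor(images=[image])                  # pixel_values: list of np.ndarray
    out_pt = processor(images=[image], return_tensors="pt")  # pixel_values: torch.Tensor
    print(type(out_default["pixel_values"][0]), type(out_pt["pixel_values"]))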
@@ -529,8 +522,8 @@
             videos = make_batched_videos(videos)
             num_frames_per_video = [len(video) for video in videos]
             patch_indices = np.cumsum(num_frames_per_video)
-            image_kwargs["crop_to_patches"] = False
-            video_inputs = image_processor(images=videos, **image_kwargs)
+            image_processor_kwargs["crop_to_patches"] = False
+            video_inputs = image_processor(images=videos, return_tensors="pt", **image_processor_kwargs)
             video_num_patches = video_inputs.pop("num_patches")
             video_pixel_values = video_inputs.pop("pixel_values")
             video_num_patches_indices = np.cumsum(video_num_patches)
@@ -543,18 +536,16 @@
                 image_video_patches.append(image_pixel_values[start_index:end_index])
 
         if len(videos) != 0 and video_pixel_values is not None:
+            patch_indices_with_prefix = [0] + list(patch_indices)
             for i in range(len(videos)):
-                current_patch_index = patch_indices[i - 1] if i > 0 else 0
-                end_patch_index = patch_indices[i]
-                start_index = video_num_patches_indices[current_patch_index] if i > 0 else 0
+                current_patch_index = patch_indices_with_prefix[i]
+                end_patch_index = patch_indices_with_prefix[i + 1]
+                start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0
                 end_index = video_num_patches_indices[end_patch_index - 1]
                 image_video_patches.append(video_pixel_values[start_index:end_index])
 
         if len(images) != 0 or len(videos) != 0:
-            pixel_values_list = _concatenate_list(image_video_patches)
-            # in the latest version of transformers,
-            # the pixel_values is a list of tensors not ndarray
-            mm_inputs["pixel_values"] = torch.stack(pixel_values_list)
+            mm_inputs["pixel_values"] = torch.cat(image_video_patches, dim=0)
 
         if len(images) != 0:
             mm_inputs.update({"image_num_patches": image_num_patches})
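The core fix: the old loop used patch_indices[i - 1], a cumulative frame count, directly as an index into video_num_patches_indices, which is itself a cumulative sum, so for every video after the first the start index landed one patch too far and a patch was silently dropped. Prepending 0 makes each video's frame range explicit: video i owns frames [patch_indices_with_prefix[i], patch_indices_with_prefix[i + 1]). A worked example with two videos of 3 and 2 frames, one patch per frame:

    import numpy as np

    num_frames_per_video = [3, 2]
    video_num_patches = [1, 1, 1, 1, 1]  # one patch per frame (crop_to_patches is off for videos)

    patch_indices = np.cumsum(num_frames_per_video)           # [3, 5]
    video_num_patches_indices = np.cumsum(video_num_patches)  # [1, 2, 3, 4, 5]
    patch_indices_with_prefix = [0] + list(patch_indices)     # [0, 3, 5]

    for i in range(len(num_frames_per_video)):
        current_patch_index = patch_indices_with_prefix[i]
        end_patch_index = patch_indices_with_prefix[i + 1]
        start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0
        end_index = video_num_patches_indices[end_patch_index - 1]
        print(i, start_index, end_index)  # (0, 0, 3) then (1, 3, 5): contiguous, nothing lost

    # Old indexing for i == 1: start_index = video_num_patches_indices[3] = 4,
    # so the slice [4:5] dropped the second video's first patch.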
src/llamafactory/hparams/model_args.py
@@ -231,6 +231,10 @@ class ProcessorArguments:
         default=False,
         metadata={"help": "Use pan and scan to process image for gemma3."},
     )
+    crop_to_patches: bool = field(
+        default=False,
+        metadata={"help": "Whether to crop the image to patches for internvl."},
+    )
     use_audio_in_video: bool = field(
         default=False,
         metadata={"help": "Whether or not to use audio in video inputs."},
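The new field follows the pattern of the surrounding options, so it is picked up by the HfArgumentParser-based CLI like any other ProcessorArguments entry. A self-contained sketch with a trimmed stand-in dataclass (not the full class):

    from dataclasses import dataclass, field

    from transformers import HfArgumentParser

    @dataclass
    class ProcessorArguments:  # trimmed stand-in for the real class
        crop_to_patches: bool = field(
            default=False,
            metadata={"help": "Whether to crop the image to patches for internvl."},
        )

    parser = HfArgumentParser(ProcessorArguments)
    (args,) = parser.parse_args_into_dataclasses(["--crop_to_patches", "true"])
    assert args.crop_to_patches is True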
src/llamafactory/model/patcher.py
@@ -80,6 +80,7 @@ def patch_processor(
     setattr(processor, "image_max_pixels", model_args.image_max_pixels)
     setattr(processor, "image_min_pixels", model_args.image_min_pixels)
     setattr(processor, "image_do_pan_and_scan", model_args.image_do_pan_and_scan)
+    setattr(processor, "crop_to_patches", model_args.crop_to_patches)
     setattr(processor, "video_max_pixels", model_args.video_max_pixels)
     setattr(processor, "video_min_pixels", model_args.video_min_pixels)
     setattr(processor, "video_fps", model_args.video_fps)
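This setattr completes the plumbing: patch_processor copies the CLI argument onto the processor object, and InternVLPlugin reads it back with getattr(processor, "crop_to_patches", False) in the hunk above. A two-line sketch of the round trip, with SimpleNamespace standing in for both objects:

    from types import SimpleNamespace

    processor = SimpleNamespace()
    model_args = SimpleNamespace(crop_to_patches=True)

    setattr(processor, "crop_to_patches", model_args.crop_to_patches)  # patcher side
    assert getattr(processor, "crop_to_patches", False) is True        # plugin side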
tests/version.txt
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.3.103
+0.9.3.104