From fa0eb91f1f73cb6cd7543441b9b35d2028ce4c73 Mon Sep 17 00:00:00 2001
From: Kingsley <82590017+Kuangdd01@users.noreply.github.com>
Date: Wed, 23 Apr 2025 00:58:22 +0800
Subject: [PATCH] [data] fix internvl plugin (#7817)

---
 src/llamafactory/data/mm_plugin.py     | 45 +++++++++++---------
 src/llamafactory/hparams/model_args.py |  4 +++
 src/llamafactory/model/patcher.py      |  1 +
 tests/version.txt                      |  2 +-
 4 files changed, 24 insertions(+), 28 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 324b857f..5e32fab4 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -21,7 +21,7 @@ import re
 from copy import deepcopy
 from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Optional, TypedDict, Union
+from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
 
 import numpy as np
 import torch
@@ -86,20 +86,6 @@ if TYPE_CHECKING:
     pass
 
 
-def _concatenate_list(input_list: list[Any]) -> Union[list[Any], "NDArray", "torch.Tensor"]:
-    r"""Concatenate a list of lists, numpy arrays or torch tensors.
-
-    Returns:
-        a list of numpy arrays or torch tensors.
-    """
-    if isinstance(input_list[0], list):
-        return [item for sublist in input_list for item in sublist]
-    elif isinstance(input_list[0], np.ndarray):
-        return np.concatenate(input_list, axis=0)
-    elif isinstance(input_list[0], torch.Tensor):
-        return torch.cat(input_list, dim=0)
-
-
 def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
     r"""Get paligemma token type ids for computing loss.
 
@@ -496,8 +482,15 @@ class InternVLPlugin(BasePlugin):
         **kwargs,
     ) -> dict[str, "torch.Tensor"]:
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
-        attributes = ["crop_to_patches", "min_patches", "max_patches"]  # need for image processor
-        image_kwargs = {attr: getattr(image_processor, attr, None) for attr in attributes}
+        image_processor_kwargs = {}
+        if getattr(processor, "crop_to_patches", False):
+            image_processor_kwargs.update(
+                {
+                    "crop_to_patches": True,
+                    "max_patches": 12,
+                    "min_patches": 1,
+                }
+            )
 
         mm_inputs = {}
         image_video_patches = []
@@ -520,7 +513,7 @@ class InternVLPlugin(BasePlugin):
 
         if len(images) != 0:
             images = make_flat_list_of_images(images)
-            image_inputs = image_processor(images=images, **image_kwargs)
+            image_inputs = image_processor(images=images, return_tensors="pt", **image_processor_kwargs)
             image_num_patches = image_inputs.pop("num_patches")
             image_pixel_values = image_inputs.pop("pixel_values")
             image_num_patches_indices = np.cumsum(image_num_patches)
@@ -529,8 +522,8 @@ class InternVLPlugin(BasePlugin):
             videos = make_batched_videos(videos)
             num_frames_per_video = [len(video) for video in videos]
             patch_indices = np.cumsum(num_frames_per_video)
-            image_kwargs["crop_to_patches"] = False
-            video_inputs = image_processor(images=videos, **image_kwargs)
+            image_processor_kwargs["crop_to_patches"] = False
+            video_inputs = image_processor(images=videos, return_tensors="pt", **image_processor_kwargs)
             video_num_patches = video_inputs.pop("num_patches")
             video_pixel_values = video_inputs.pop("pixel_values")
             video_num_patches_indices = np.cumsum(video_num_patches)
@@ -543,18 +536,16 @@ class InternVLPlugin(BasePlugin):
                 image_video_patches.append(image_pixel_values[start_index:end_index])
 
         if len(videos) != 0 and video_pixel_values is not None:
+            patch_indices_with_prefix = [0] + list(patch_indices)
             for i in range(len(videos)):
-                current_patch_index = patch_indices[i - 1] if i > 0 else 0
-                end_patch_index = patch_indices[i]
-                start_index = video_num_patches_indices[current_patch_index] if i > 0 else 0
+                current_patch_index = patch_indices_with_prefix[i]
+                end_patch_index = patch_indices_with_prefix[i + 1]
+                start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0
                 end_index = video_num_patches_indices[end_patch_index - 1]
                 image_video_patches.append(video_pixel_values[start_index:end_index])
 
         if len(images) != 0 or len(videos) != 0:
-            pixel_values_list = _concatenate_list(image_video_patches)
-            # in the latest version of transformers,
-            # the pixel_values is a list of tensors not ndarray
-            mm_inputs["pixel_values"] = torch.stack(pixel_values_list)
+            mm_inputs["pixel_values"] = torch.cat(image_video_patches, dim=0)
 
         if len(images) != 0:
             mm_inputs.update({"image_num_patches": image_num_patches})
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 07319ca8..e7a74046 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -231,6 +231,10 @@ class ProcessorArguments:
         default=False,
         metadata={"help": "Use pan and scan to process image for gemma3."},
     )
+    crop_to_patches: bool = field(
+        default=False,
+        metadata={"help": "Whether to crop the image to patches for internvl."},
+    )
     use_audio_in_video: bool = field(
         default=False,
         metadata={"help": "Whether or not to use audio in video inputs."},
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index add6320a..ce1a5a7d 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -80,6 +80,7 @@ def patch_processor(
     setattr(processor, "image_max_pixels", model_args.image_max_pixels)
     setattr(processor, "image_min_pixels", model_args.image_min_pixels)
     setattr(processor, "image_do_pan_and_scan", model_args.image_do_pan_and_scan)
+    setattr(processor, "crop_to_patches", model_args.crop_to_patches)
     setattr(processor, "video_max_pixels", model_args.video_max_pixels)
     setattr(processor, "video_min_pixels", model_args.video_min_pixels)
     setattr(processor, "video_fps", model_args.video_fps)
diff --git a/tests/version.txt b/tests/version.txt
index aee61b44..bceb4d5c 100644
--- a/tests/version.txt
+++ b/tests/version.txt
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.3.103
+0.9.3.104
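
Reviewer note: a minimal, self-contained sketch of the video-patch indexing that the mm_plugin.py hunk fixes. The two-video setup, frame counts, and tensor shapes below are invented for illustration; only the slicing arithmetic mirrors the patched code.

# Illustrative sketch only -- not part of the patch. Shapes and counts are made up.
import numpy as np
import torch

num_frames_per_video = [2, 3]  # two videos with 2 and 3 frames
# crop_to_patches is forced off for videos in the hunk, so each frame yields one patch
num_patches_per_frame = [1, 1, 1, 1, 1]
video_pixel_values = torch.randn(5, 3, 448, 448)  # all frames stacked along dim 0

patch_indices = np.cumsum(num_frames_per_video)  # [2, 5]
video_num_patches_indices = np.cumsum(num_patches_per_frame)  # [1, 2, 3, 4, 5]

# The old loop indexed patch_indices[i - 1] with an i > 0 special case; prepending
# a zero turns every video into a uniform [start, end) range of frame positions.
patch_indices_with_prefix = [0] + list(patch_indices)  # [0, 2, 5]

image_video_patches = []
for i in range(len(num_frames_per_video)):
    current_patch_index = patch_indices_with_prefix[i]  # first frame of video i
    end_patch_index = patch_indices_with_prefix[i + 1]  # one past its last frame
    start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0
    end_index = video_num_patches_indices[end_patch_index - 1]
    image_video_patches.append(video_pixel_values[start_index:end_index])

# Every element is already a (num_patches, C, H, W) tensor, which is why the patch
# can replace the removed _concatenate_list helper plus torch.stack with a single cat.
pixel_values = torch.cat(image_video_patches, dim=0)
assert pixel_values.shape == (5, 3, 448, 448)

With the old indexing, i = 1 here would have computed start_index = video_num_patches_indices[patch_indices[0]] = video_num_patches_indices[2] = 3 instead of 2, silently dropping a frame from the second video; the zero-prefixed list together with the shifted [current_patch_index - 1] lookup makes the per-video ranges line up.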