Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
[data] fix internvl plugin (#7817)

parent 49f9ed0232
commit fa0eb91f1f
@@ -21,7 +21,7 @@ import re
 from copy import deepcopy
 from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, Any, BinaryIO, Literal, Optional, TypedDict, Union
+from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
 
 import numpy as np
 import torch
@@ -86,20 +86,6 @@ if TYPE_CHECKING:
             pass
 
 
-def _concatenate_list(input_list: list[Any]) -> Union[list[Any], "NDArray", "torch.Tensor"]:
-    r"""Concatenate a list of lists, numpy arrays or torch tensors.
-
-    Returns:
-        a list of numpy arrays or torch tensors.
-    """
-    if isinstance(input_list[0], list):
-        return [item for sublist in input_list for item in sublist]
-    elif isinstance(input_list[0], np.ndarray):
-        return np.concatenate(input_list, axis=0)
-    elif isinstance(input_list[0], torch.Tensor):
-        return torch.cat(input_list, dim=0)
-
-
 def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
     r"""Get paligemma token type ids for computing loss.
 
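Note: `_concatenate_list` becomes dead code with this commit. The processor calls below pass `return_tensors="pt"`, so every entry appended to `image_video_patches` is already a `torch.Tensor` slice and the list/ndarray branches can never fire. A minimal sketch of the surviving tensor case, with made-up patch shapes:

import torch

# Hypothetical patch stacks: two images yielding 3 and 2 patches of 3x448x448 each.
image_video_patches = [torch.randn(3, 3, 448, 448), torch.randn(2, 3, 448, 448)]

pixel_values = torch.cat(image_video_patches, dim=0)  # what the new code does directly
assert pixel_values.shape == (5, 3, 448, 448)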
@@ -496,8 +482,15 @@ class InternVLPlugin(BasePlugin):
         **kwargs,
     ) -> dict[str, "torch.Tensor"]:
         image_processor: BaseImageProcessor = getattr(processor, "image_processor")
-        attributes = ["crop_to_patches", "min_patches", "max_patches"]  # need for image processor
-        image_kwargs = {attr: getattr(image_processor, attr, None) for attr in attributes}
+        image_processor_kwargs = {}
+        if getattr(processor, "crop_to_patches", False):
+            image_processor_kwargs.update(
+                {
+                    "crop_to_patches": True,
+                    "max_patches": 12,
+                    "min_patches": 1,
+                }
+            )
 
         mm_inputs = {}
         image_video_patches = []
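Note: the removed lines scraped `crop_to_patches` / `min_patches` / `max_patches` off the image processor with a `None` fallback, so whatever the checkpoint defined (or `None` for missing attributes) was forwarded on every call. The new code makes patch cropping an explicit user opt-in, read from the processor flag that `patch_processor` sets (see the hunk further down). What the old scrape produced when attributes were absent, illustrated with a dummy stand-in rather than the real class:

# Dummy stand-in for an image processor that lacks the min/max patch attributes.
class DummyImageProcessor:
    crop_to_patches = True

attributes = ["crop_to_patches", "min_patches", "max_patches"]
old_kwargs = {attr: getattr(DummyImageProcessor(), attr, None) for attr in attributes}
print(old_kwargs)  # {'crop_to_patches': True, 'min_patches': None, 'max_patches': None}
# These None values went straight into the image processor call as kwargs.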
@@ -520,7 +513,7 @@ class InternVLPlugin(BasePlugin):
 
         if len(images) != 0:
             images = make_flat_list_of_images(images)
-            image_inputs = image_processor(images=images, **image_kwargs)
+            image_inputs = image_processor(images=images, return_tensors="pt", **image_processor_kwargs)
             image_num_patches = image_inputs.pop("num_patches")
             image_pixel_values = image_inputs.pop("pixel_values")
             image_num_patches_indices = np.cumsum(image_num_patches)
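Note: with `return_tensors="pt"`, `pixel_values` comes back as a torch tensor, and `np.cumsum` over the per-image patch counts gives each image's exclusive end row in that flattened tensor; the slicing loop (visible as context in the later hunks) reads `[indices[i - 1]:indices[i]]`. A worked example with hypothetical patch counts:

import numpy as np

image_num_patches = [4, 1, 6]  # hypothetical per-image patch counts
image_num_patches_indices = np.cumsum(image_num_patches)
print(image_num_patches_indices)  # [ 4  5 11]
# image 0 -> pixel_values[0:4], image 1 -> [4:5], image 2 -> [5:11]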
@@ -529,8 +522,8 @@ class InternVLPlugin(BasePlugin):
             videos = make_batched_videos(videos)
             num_frames_per_video = [len(video) for video in videos]
             patch_indices = np.cumsum(num_frames_per_video)
-            image_kwargs["crop_to_patches"] = False
-            video_inputs = image_processor(images=videos, **image_kwargs)
+            image_processor_kwargs["crop_to_patches"] = False
+            video_inputs = image_processor(images=videos, return_tensors="pt", **image_processor_kwargs)
             video_num_patches = video_inputs.pop("num_patches")
             video_pixel_values = video_inputs.pop("pixel_values")
             video_num_patches_indices = np.cumsum(video_num_patches)
@@ -543,18 +536,16 @@ class InternVLPlugin(BasePlugin):
                 image_video_patches.append(image_pixel_values[start_index:end_index])
 
         if len(videos) != 0 and video_pixel_values is not None:
+            patch_indices_with_prefix = [0] + list(patch_indices)
             for i in range(len(videos)):
-                current_patch_index = patch_indices[i - 1] if i > 0 else 0
-                end_patch_index = patch_indices[i]
-                start_index = video_num_patches_indices[current_patch_index] if i > 0 else 0
+                current_patch_index = patch_indices_with_prefix[i]
+                end_patch_index = patch_indices_with_prefix[i + 1]
+                start_index = video_num_patches_indices[current_patch_index - 1] if i > 0 else 0
                 end_index = video_num_patches_indices[end_patch_index - 1]
                 image_video_patches.append(video_pixel_values[start_index:end_index])
 
         if len(images) != 0 or len(videos) != 0:
-            pixel_values_list = _concatenate_list(image_video_patches)
-            # in the latest version of transformers,
-            # the pixel_values is a list of tensors not ndarray
-            mm_inputs["pixel_values"] = torch.stack(pixel_values_list)
+            mm_inputs["pixel_values"] = torch.cat(image_video_patches, dim=0)
 
         if len(images) != 0:
             mm_inputs.update({"image_num_patches": image_num_patches})
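Note: this is the core fix. `patch_indices` is a cumulative frame count, so `patch_indices[i - 1]` is already the index of the first frame of video i; the old code indexed `video_num_patches_indices` at that value without the `- 1` shift, which starts the slice one frame too late and drops the first frame's patches of every video after the first. Prepending a 0 and shifting by one aligns the boundaries. A worked example with two hypothetical videos (one patch per frame, since `crop_to_patches` is forced off for videos above):

import numpy as np

num_frames_per_video = [3, 2]
patch_indices = np.cumsum(num_frames_per_video)         # [3, 5]
video_num_patches_indices = np.cumsum([1, 1, 1, 1, 1])  # [1, 2, 3, 4, 5]

# Old code at i = 1: start = video_num_patches_indices[patch_indices[0]] = 4,
# so video 1 was sliced as rows [4:5] and lost its first frame.
patch_indices_with_prefix = [0] + list(patch_indices)   # [0, 3, 5]
current_patch_index = patch_indices_with_prefix[1]      # 3
end_patch_index = patch_indices_with_prefix[2]          # 5
start_index = video_num_patches_indices[current_patch_index - 1]  # 3
end_index = video_num_patches_indices[end_patch_index - 1]        # 5
assert (start_index, end_index) == (3, 5)  # rows [3:5]: both frames of video 1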
@@ -231,6 +231,10 @@ class ProcessorArguments:
         default=False,
         metadata={"help": "Use pan and scan to process image for gemma3."},
     )
+    crop_to_patches: bool = field(
+        default=False,
+        metadata={"help": "Whether to crop the image to patches for internvl."},
+    )
     use_audio_in_video: bool = field(
         default=False,
         metadata={"help": "Whether or not to use audio in video inputs."},
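Note: the new field follows the existing `ProcessorArguments` pattern, so it surfaces as a `--crop_to_patches` CLI flag through the usual HfArgumentParser plumbing. A minimal sketch using a hypothetical subset dataclass rather than the real one:

from dataclasses import dataclass, field

from transformers import HfArgumentParser

@dataclass
class ProcessorArgumentsSubset:  # hypothetical stand-in for ProcessorArguments
    crop_to_patches: bool = field(
        default=False,
        metadata={"help": "Whether to crop the image to patches for internvl."},
    )

(args,) = HfArgumentParser(ProcessorArgumentsSubset).parse_args_into_dataclasses(
    ["--crop_to_patches", "true"]
)
assert args.crop_to_patches is True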
@@ -80,6 +80,7 @@ def patch_processor(
     setattr(processor, "image_max_pixels", model_args.image_max_pixels)
     setattr(processor, "image_min_pixels", model_args.image_min_pixels)
     setattr(processor, "image_do_pan_and_scan", model_args.image_do_pan_and_scan)
+    setattr(processor, "crop_to_patches", model_args.crop_to_patches)
     setattr(processor, "video_max_pixels", model_args.video_max_pixels)
     setattr(processor, "video_min_pixels", model_args.video_min_pixels)
     setattr(processor, "video_fps", model_args.video_fps)
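Note: this `setattr` is the other half of the plugin change above; `InternVLPlugin` reads the flag back with `getattr(processor, "crop_to_patches", False)`. A toy round trip with dummy stand-ins for the processor and model args:

class DummyProcessor:
    pass

class DummyModelArgs:
    crop_to_patches = True

processor = DummyProcessor()
setattr(processor, "crop_to_patches", DummyModelArgs.crop_to_patches)
assert getattr(processor, "crop_to_patches", False) is True  # what the plugin sees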
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.3.103
+0.9.3.104