import math
from copy import deepcopy
from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union

import numpy as np
import torch
from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override

from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than

if is_pillow_available():
    from PIL import Image
    from PIL.Image import Image as ImageObject

if is_pyav_available():
    import av

if is_transformers_version_greater_than("4.45.0"):
    from transformers.models.mllama.processing_mllama import (
        convert_sparse_cross_attention_mask_to_dense,
        get_cross_attention_token_mask,
    )

if TYPE_CHECKING:
    from av.stream import Stream
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.image_processing_utils import BaseImageProcessor

class EncodedImage(TypedDict):
    path: Optional[str]
    bytes: Optional[bytes]


ImageInput = Union[str, bytes, EncodedImage, ImageObject]
VideoInput = str

def _get_paligemma_token_type_ids(
    imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin"
) -> List[List[int]]:
    r"""
    Gets paligemma token type ids for computing loss.

    Returns:
        batch_token_type_ids: shape (batch_size, sequence_length)
    """
    batch_token_type_ids = []
    for imglen, seqlen in zip(imglens, seqlens):
        image_seqlen = imglen * getattr(processor, "image_seqlen")
        batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))

    return batch_token_type_ids
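
# A minimal sketch of the expected output, assuming a (hypothetical) processor
# whose `image_seqlen` attribute equals 4:
#
#   >>> _get_paligemma_token_type_ids(imglens=[1], seqlens=[10], processor=processor)
#   [[0, 0, 0, 0, 1, 1, 1, 1, 1, 1]]
#
# Positions holding image tokens get type 0 and text positions get type 1,
# so the loss can be masked over the image prefix.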


class BasePlugin:
    def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
        self.image_token = image_token
        self.video_token = video_token
        self.expand_mm_tokens = True

    def _validate_input(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
    ) -> None:
        r"""
        Validates whether this model accepts the input modalities.
        """
        if len(images) != 0 and self.image_token is None:
            raise ValueError("This model does not support image input.")

        if len(videos) != 0 and self.video_token is None:
            raise ValueError("This model does not support video input.")

    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
        r"""
        Pre-processes a single image.
        """
        image_resolution: int = kwargs.get("image_resolution")
        if (image.width * image.height) > image_resolution:
            resize_factor = math.sqrt(image_resolution / (image.width * image.height))
            width, height = int(image.width * resize_factor), int(image.height * resize_factor)
            image = image.resize((width, height), resample=Image.NEAREST)

        if image.mode != "RGB":
            image = image.convert("RGB")

        return image
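
    # A worked example of the area-based downscaling above (values assumed for
    # illustration): a 2048x1024 image with image_resolution = 512 * 512 gives
    # resize_factor = sqrt(262144 / 2097152) ~= 0.3536, so the image is resized
    # to 724x362, whose area (262088) just fits under the budget.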

    def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
        r"""
        Computes the number of video frames to sample according to fps.
        """
        video_fps: float = kwargs.get("video_fps")
        video_maxlen: int = kwargs.get("video_maxlen")
        total_frames = video_stream.frames
        sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
        sample_frames = min(total_frames, video_maxlen, sample_frames)
        return math.floor(sample_frames)
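
    # A quick check of the sampling arithmetic (assumed values): a 10-second
    # stream decoded at 30 fps has total_frames = 300; with video_fps = 2.0 and
    # video_maxlen = 64, the duration-based target is 10 * 2.0 = 20 frames, so
    # min(300, 64, 20.0) -> 20 frames are sampled.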

    def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
        r"""
        Regularizes images to avoid errors, including reading and pre-processing.
        """
        results = []
        for image in images:
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, bytes):
                image = Image.open(BytesIO(image))
            elif isinstance(image, dict):
                if image["bytes"] is not None:
                    image = Image.open(BytesIO(image["bytes"]))
                else:
                    image = Image.open(image["path"])

            if not isinstance(image, ImageObject):
                raise ValueError(f"Expected a list of images, but got {type(image)}.")

            results.append(self._preprocess_image(image, **kwargs))

        return results

    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
        r"""
        Regularizes videos to avoid errors, including reading, resizing and converting.
        """
        results = []
        for video in videos:
            container = av.open(video, "r")
            video_stream = next(stream for stream in container.streams if stream.type == "video")
            total_frames = video_stream.frames
            sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
            sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
            frames: List["ImageObject"] = []
            container.seek(0)
            for frame_idx, frame in enumerate(container.decode(video_stream)):
                if frame_idx in sample_indices:
                    frames.append(frame.to_image())

            frames = self._regularize_images(frames, **kwargs)
            results.append(frames)

        return results

    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: "ProcessorMixin",
    ) -> Dict[str, "torch.Tensor"]:
        r"""
        Processes visual inputs.

        Returns: (llava and paligemma)
            pixel_values: tensor with shape (B, C, H, W)

        Returns: (qwen2-vl)
            pixel_values: tensor with shape (num_patches, patch_dim)
            image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, height and width

        It holds that num_patches == torch.prod(image_grid_thw, dim=-1).sum()
        """
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
        input_dict = {"images": None}  # default key
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_resolution=getattr(processor, "image_resolution", 512 * 512),
            )
            input_dict["images"] = images

        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
                image_resolution=getattr(processor, "video_resolution", 128 * 128),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 64),
            )
            input_dict["videos"] = videos

        mm_inputs = {}
        if image_processor != video_processor:
            if input_dict.get("images") is not None:
                mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt"))

            if input_dict.get("videos") is not None:
                mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt"))

        elif input_dict.get("images") is not None or input_dict.get("videos") is not None:  # same processor (qwen2-vl)
            mm_inputs.update(image_processor(**input_dict, return_tensors="pt"))

        return mm_inputs
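
    # Illustrative shapes for the qwen2-vl branch (assumed values): one 28x28
    # image patched at 14x14 yields image_grid_thw = [[1, 2, 2]], hence
    # pixel_values has 1 * 2 * 2 = 4 patch rows, matching
    # torch.prod(image_grid_thw, dim=-1).sum() == 4.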

    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Pre-processes input messages before tokenization for VLMs.
        """
        self._validate_input(images, videos)
        return messages

    def process_token_ids(
        self,
        input_ids: List[int],
        labels: Optional[List[int]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
    ) -> Tuple[List[int], Optional[List[int]]]:
        r"""
        Pre-processes token ids after tokenization for VLMs.
        """
        self._validate_input(images, videos)
        return input_ids, labels

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Builds batched multimodal inputs for VLMs.

        Arguments:
            images: a list of image inputs, shape (num_images,)
            videos: a list of video inputs, shape (num_videos,)
            imglens: number of images in each sample, shape (batch_size,)
            vidlens: number of videos in each sample, shape (batch_size,)
            batch_ids: token ids of input samples, shape (batch_size, seq_len)
            processor: a processor for pre-processing images and videos
        """
        self._validate_input(images, videos)
        return {}


class LlavaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(images, videos)
        num_image_tokens = 0
        image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                num_image_tokens += 1
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

            message["content"] = content.replace("{{image}}", self.image_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(images) != num_image_tokens:
            pass  # unreachable duplicate guard removed below

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
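
# A sketch of the placeholder expansion above, assuming IMAGE_PLACEHOLDER is
# "<image>" and a (hypothetical) processor with image_seqlen = 3:
#
#   "Describe <image>" -> "Describe " + image_token * 3
#
# i.e. each user-facing placeholder is expanded to `image_seqlen` copies of
# the model's image token.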


class LlavaNextPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(images, videos)
        num_image_tokens = 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        if "image_sizes" in mm_inputs:
            image_sizes = iter(mm_inputs["image_sizes"])

        if "pixel_values" in mm_inputs:
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    orig_height, orig_width = next(image_sizes)
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                    if getattr(processor, "vision_feature_select_strategy") == "default":
                        image_seqlen -= 1
                else:
                    image_seqlen = 1

                num_image_tokens += 1
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

            message["content"] = content.replace("{{image}}", self.image_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)


class LlavaNextVideoPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(images, videos)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        if "pixel_values" in mm_inputs:
            image_sizes = iter(mm_inputs["image_sizes"])
            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
            for message in messages:
                content = message["content"]
                while IMAGE_PLACEHOLDER in content:
                    if self.expand_mm_tokens:
                        orig_height, orig_width = next(image_sizes)
                        image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                        if getattr(processor, "vision_feature_select_strategy") == "default":
                            image_seqlen -= 1
                    else:
                        image_seqlen = 1

                    num_image_tokens += 1
                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

                message["content"] = content.replace("{{image}}", self.image_token)

        if "pixel_values_videos" in mm_inputs:
            pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
            height, width = get_image_size(pixel_values_video[0])
            num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
            video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
            video_seqlen = video_seqlen if self.expand_mm_tokens else 1
            for message in messages:
                content = message["content"]
                while VIDEO_PLACEHOLDER in content:
                    num_video_tokens += 1
                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)

                message["content"] = content.replace("{{video}}", self.video_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)


class PaliGemmaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(images, videos)
        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                num_image_tokens += 1
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)

            message["content"] = content.replace("{{image}}", "")

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def process_token_ids(
        self,
        input_ids: List[int],
        labels: Optional[List[int]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
    ) -> Tuple[List[int], Optional[List[int]]]:
        self._validate_input(images, videos)
        num_images = len(images)
        image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        input_ids = [image_token_id] * image_seqlen + input_ids
        if labels is not None:
            labels = [IGNORE_INDEX] * image_seqlen + labels

        return input_ids, labels

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(images, videos)
        seqlens = [len(input_ids) for input_ids in batch_ids]
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
        return mm_inputs
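
# PaliGemma places image tokens as a prefix rather than expanding an in-text
# placeholder. A sketch with assumed values (image_seqlen = 2, image token id
# 257152): process_token_ids([5, 6], [5, 6], [img], [], tokenizer, processor)
# returns
#
#   input_ids -> [257152, 257152, 5, 6]
#   labels    -> [IGNORE_INDEX, IGNORE_INDEX, 5, 6]
#
# so the prepended image positions never contribute to the loss.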


class PixtralPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(images, videos)
        patch_size = getattr(processor, "patch_size")
        image_token = getattr(processor, "image_token")
        image_break_token = getattr(processor, "image_break_token")
        image_end_token = getattr(processor, "image_end_token")

        num_image_tokens = 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        image_input_sizes = mm_inputs.get("image_sizes", None)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if image_input_sizes is None:
                    raise ValueError("Cannot get image input sizes.")

                if self.expand_mm_tokens:
                    image_size = image_input_sizes[0][num_image_tokens]
                    height, width = image_size
                    num_height_tokens = height // patch_size
                    num_width_tokens = width // patch_size
                    replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
                    replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
                    replace_tokens[-1] = image_end_token
                    replace_str = "".join(replace_tokens)
                else:
                    replace_str = image_token

                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
                num_image_tokens += 1

            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(images, videos)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        if mm_inputs.get("pixel_values"):
            mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]

        mm_inputs.pop("image_sizes", None)
        return mm_inputs
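
# A sketch of the Pixtral grid layout built above, with assumed values:
# patch_size = 16, a 32x48 (height x width) image, and token strings "[IMG]",
# "[IMG_BREAK]" and "[IMG_END]". Then num_height_tokens = 2 and
# num_width_tokens = 3, so the placeholder expands to
#
#   [IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END]
#
# one row of image tokens per patch row, rows separated by the break token,
# and the final break replaced by the end token.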


class Qwen2vlPlugin(BasePlugin):
    @override
    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
        image = super()._preprocess_image(image, **kwargs)
        if min(image.width, image.height) < 28:
            width, height = max(image.width, 28), max(image.height, 28)
            image = image.resize((width, height), resample=Image.NEAREST)

        if image.width / image.height > 200:
            width, height = image.height * 180, image.height
            image = image.resize((width, height), resample=Image.NEAREST)

        if image.height / image.width > 200:
            width, height = image.width, image.width * 180
            image = image.resize((width, height), resample=Image.NEAREST)

        return image

    @override
    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
        results = []
        for video in videos:
            container = av.open(video, "r")
            video_stream = next(stream for stream in container.streams if stream.type == "video")
            total_frames = video_stream.frames
            sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
            sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
            frames: List["ImageObject"] = []
            container.seek(0)
            for frame_idx, frame in enumerate(container.decode(video_stream)):
                if frame_idx in sample_indices:
                    frames.append(frame.to_image())

            if len(frames) % 2 != 0:  # qwen2-vl requires an even number of frames
                frames.append(frames[-1])

            frames = self._regularize_images(frames, **kwargs)
            results.append(frames)

        return results

    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(images, videos)
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        merge_length: int = getattr(image_processor, "merge_size") ** 2
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        image_grid_thw = mm_inputs.get("image_grid_thw", [])
        video_grid_thw = mm_inputs.get("video_grid_thw", [])

        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if num_image_tokens >= len(image_grid_thw):
                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                content = content.replace(
                    IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
                )
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                if num_video_tokens >= len(video_grid_thw):
                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

                video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                content = content.replace(
                    VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
                )
                num_video_tokens += 1

            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)
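
# A sketch of the qwen2-vl expansion above, with assumed values: merge_size = 2
# (so merge_length = 4), image_grid_thw[0] = [1, 4, 4], and "<|image_pad|>" as
# the image token. Then image_seqlen = 16 // 4 = 4, and the placeholder becomes
#
#   <|vision_start|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|vision_end|>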


class VideoLlavaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(images, videos)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        num_frames = 0
        has_images = "pixel_values_images" in mm_inputs
        has_videos = "pixel_values_videos" in mm_inputs
        if has_images or has_videos:
            if self.expand_mm_tokens:
                if has_images:
                    height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
                    num_frames = 1

                if has_videos:
                    pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                    height, width = get_image_size(pixel_values_video[0])
                    num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim

                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
                video_seqlen = image_seqlen * num_frames
                if getattr(processor, "vision_feature_select_strategy") == "default":
                    image_seqlen -= 1
            else:
                image_seqlen, video_seqlen = 1, 1

            for message in messages:
                content = message["content"]
                while IMAGE_PLACEHOLDER in content:
                    num_image_tokens += 1
                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

                while VIDEO_PLACEHOLDER in content:
                    num_video_tokens += 1
                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)

                content = content.replace("{{image}}", self.image_token)
                message["content"] = content.replace("{{video}}", self.video_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(images, videos)
        return self._get_mm_inputs(images, videos, processor)


class MllamaPlugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        self._validate_input(images, videos)
        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: "ProcessorMixin",
    ) -> Dict[str, "torch.Tensor"]:
        r"""
        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].

        Returns:
            pixel_values: tensor with shape
                (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
                For example, (2, 1, 4, 3, 560, 560).
            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
        """
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
        return image_processor([[image] for image in images], return_tensors="pt")

    @override
    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(images, videos)
        if len(images) != len(batch_ids):
            raise ValueError("Mllama only supports one image per sample.")

        mm_inputs = self._get_mm_inputs(images, videos, processor)
        num_tiles = mm_inputs.pop("num_tiles")
        image_token_id = getattr(processor, "image_token_id")
        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
        cross_attention_token_mask = [
            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
        ]
        mm_inputs["cross_attention_mask"] = torch.from_numpy(
            convert_sparse_cross_attention_mask_to_dense(
                cross_attention_token_mask,
                num_tiles=num_tiles,
                max_num_tiles=max_image_tiles,
                length=max(len(input_ids) for input_ids in batch_ids),
            )
        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
        return mm_inputs


PLUGINS = {
    "base": BasePlugin,
    "llava": LlavaPlugin,
    "llava_next": LlavaNextPlugin,
    "llava_next_video": LlavaNextVideoPlugin,
    "paligemma": PaliGemmaPlugin,
    "pixtral": PixtralPlugin,
    "qwen2_vl": Qwen2vlPlugin,
    "video_llava": VideoLlavaPlugin,
    "mllama": MllamaPlugin,
}


def get_mm_plugin(
    name: str,
    image_token: Optional[str] = None,
    video_token: Optional[str] = None,
) -> "BasePlugin":
    plugin_class = PLUGINS.get(name, None)
    if plugin_class is None:
        raise ValueError(f"Multimodal plugin `{name}` not found.")

    return plugin_class(image_token, video_token)
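
# A minimal usage sketch (the token strings below are assumptions for
# illustration; in practice they come from each model's template definition):
#
#   >>> plugin = get_mm_plugin("qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>")
#   >>> messages = plugin.process_messages(messages, images, videos, processor)  # expand placeholders
#   >>> mm_inputs = plugin.get_mm_inputs(images, videos, imglens, vidlens, batch_ids, processor)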