mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
* preserve image_sizes * preserve image_sizes * init plugin * support audio-text2text lora * nit * support image/video-text2text, audio-text2text * remove args * remove lines * add docs && nit * remove some comments * fix && add merge part script * add license
1542 lines
63 KiB
Python
1542 lines
63 KiB
Python
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
|
|
#
|
|
# This code is inspired by the HuggingFace's Transformers library.
|
|
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import inspect
|
|
import math
|
|
import re
|
|
from copy import deepcopy
|
|
from dataclasses import dataclass
|
|
from io import BytesIO
|
|
from typing import TYPE_CHECKING, Literal, Optional, TypedDict, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
from transformers.image_utils import get_image_size, to_numpy_array
|
|
from typing_extensions import override
|
|
|
|
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
|
from ..extras.packages import (
|
|
is_librosa_available,
|
|
is_pillow_available,
|
|
is_pyav_available,
|
|
is_transformers_version_greater_than,
|
|
)
|
|
|
|
|
|
if is_librosa_available():
|
|
import librosa
|
|
|
|
|
|
if is_pillow_available():
|
|
from PIL import Image
|
|
from PIL.Image import Image as ImageObject
|
|
|
|
|
|
if is_pyav_available():
|
|
import av
|
|
|
|
|
|
if is_transformers_version_greater_than("4.45.0"):
|
|
from transformers.models.mllama.processing_mllama import (
|
|
convert_sparse_cross_attention_mask_to_dense,
|
|
get_cross_attention_token_mask,
|
|
)
|
|
|
|
|
|
if TYPE_CHECKING:
    # Everything in this block is only seen by static type checkers; it is never executed at runtime,
    # so the names below may only be used inside (string or local-variable) annotations.
    from av.stream import Stream
    from numpy.typing import NDArray
    from transformers import PreTrainedTokenizer, ProcessorMixin
    from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
    from transformers.image_processing_utils import BaseImageProcessor

    class EncodedImage(TypedDict):
        # An image referenced either by file path or by raw encoded bytes.
        # `_regularize_images` opens `bytes` when it is not None and falls back to `path` otherwise.
        path: Optional[str]
        bytes: Optional[bytes]

    # Raw media inputs accepted by the plugins:
    # - images: a file path, encoded bytes, an EncodedImage dict, or an already-opened PIL image;
    # - videos: a `str` handed to `av.open`;
    # - audios: a path (loaded with librosa) or an already-decoded numpy array.
    ImageInput = Union[str, bytes, EncodedImage, ImageObject]
    VideoInput = str
    AudioInput = Union[str, NDArray]

    class MMProcessor(ProcessorMixin):
        # Typing-only view of the processor attributes read by the plugins in this file.
        patch_size: int
        image_seq_length: int
        num_additional_image_tokens: int
        vision_feature_select_strategy: Literal["default", "full"]

        def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
            # Stub for type checking only; presumably implemented by the concrete HF processor.
            # Called by the LLaVA-NeXT plugins to size the image token run of each image.
            pass
|
|
|
|
|
|
def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
|
|
r"""Get paligemma token type ids for computing loss.
|
|
|
|
It is slightly different with the original token type ids where the prompt part is 0.
|
|
|
|
Returns:
|
|
batch_token_type_ids: shape (batch_size, seq_length)
|
|
|
|
"""
|
|
batch_token_type_ids = []
|
|
for imglen, seqlen in zip(imglens, seqlens):
|
|
image_seqlen = imglen * processor.image_seq_length
|
|
batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))
|
|
|
|
return batch_token_type_ids
|
|
|
|
|
|
def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"):
|
|
r"""Get gemma3 token type ids for computing loss.
|
|
|
|
Returns:
|
|
batch_token_type_ids: shape (batch_size, seq_length)
|
|
|
|
"""
|
|
image_token_id: int = getattr(processor, "image_token_id")
|
|
batch_token_type_ids = []
|
|
for token_ids in batch_ids:
|
|
token_ids = np.array(token_ids)
|
|
token_type_ids = np.zeros_like(token_ids)
|
|
token_type_ids[token_ids == image_token_id] = 1
|
|
batch_token_type_ids.append(token_type_ids.tolist())
|
|
|
|
return batch_token_type_ids
|
|
|
|
|
|
def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
|
|
r"""Make nested list of images."""
|
|
batch_images = []
|
|
for imglen in imglens:
|
|
batch_images.append(images[:imglen])
|
|
images = images[imglen:]
|
|
|
|
return batch_images
|
|
|
|
|
|
@dataclass
class MMPluginMixin:
    r"""Shared multimodal helpers: modality validation, media loading/resizing and feature extraction.

    Attributes:
        image_token: token the template uses for images, or ``None`` if the model takes no images.
        video_token: token the template uses for videos, or ``None`` if the model takes no videos.
        audio_token: token the template uses for audios, or ``None`` if the model takes no audios.
        expand_mm_tokens: read by subclasses to decide whether each placeholder is expanded into
            the model's full token sequence (``True``) or kept as a single token (``False``).
    """

    image_token: Optional[str]
    video_token: Optional[str]
    audio_token: Optional[str]
    expand_mm_tokens: bool = True

    def _validate_input(
        self,
        processor: Optional["MMProcessor"],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
    ) -> None:
        r"""Validate if this model accepts the input modalities.

        Raises:
            ValueError: if inputs of a modality are given but the model has no token for it, or if a
                token is configured but the processor / image processor / video processor / audio
                feature extractor needed for it is missing.
        """
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        video_processor: BaseImageProcessor = getattr(
            processor, "video_processor", getattr(processor, "image_processor", None)
        )
        if image_processor is None and video_processor is None:  # hack for qwen2_5_omni
            image_processor, video_processor = (
                getattr(processor, "omni_processor", None),
                getattr(processor, "omni_processor", None),
            )

        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
        if len(images) != 0 and self.image_token is None:
            raise ValueError(
                "This model does not support image input. Please check whether the correct `template` is used."
            )

        if len(videos) != 0 and self.video_token is None:
            raise ValueError(
                "This model does not support video input. Please check whether the correct `template` is used."
            )

        if len(audios) != 0 and self.audio_token is None:
            raise ValueError(
                "This model does not support audio input. Please check whether the correct `template` is used."
            )

        if self.image_token is not None and processor is None:
            raise ValueError("Processor was not found, please check and update your processor config.")

        if self.image_token is not None and image_processor is None:
            raise ValueError("Image processor was not found, please check and update your processor config.")

        if self.video_token is not None and video_processor is None:
            raise ValueError("Video processor was not found, please check and update your processor config.")

        if self.audio_token is not None and feature_extractor is None:
            raise ValueError("Audio feature extractor was not found, please check and update your processor config.")

    def _preprocess_image(
        self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
    ) -> "ImageObject":
        r"""Pre-process a single image.

        Down-scales the image (keeping its aspect ratio) when it has more than ``image_max_pixels``
        pixels, up-scales it when it has fewer than ``image_min_pixels`` pixels, and converts it to RGB.
        Each side is clamped to at least 1 pixel so extreme aspect ratios never yield an empty image.
        """
        if (image.width * image.height) > image_max_pixels:
            resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
            width = max(1, int(image.width * resize_factor))
            height = max(1, int(image.height * resize_factor))
            image = image.resize((width, height))

        if (image.width * image.height) < image_min_pixels:
            resize_factor = math.sqrt(image_min_pixels / (image.width * image.height))
            width = max(1, int(image.width * resize_factor))
            height = max(1, int(image.height * resize_factor))
            image = image.resize((width, height))

        if image.mode != "RGB":
            image = image.convert("RGB")

        return image

    def _get_video_sample_indices(
        self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
    ) -> list[int]:
        r"""Compute video sample indices according to fps.

        Returns an int32 numpy array of frame indices spread evenly over the stream. The number of
        samples is ``duration * video_fps``, capped by the frame count and ``video_maxlen``, and is at
        least one so very short clips are not dropped. When the stream reports no duration, up to
        ``video_maxlen`` frames are sampled evenly instead.
        """
        total_frames = video_stream.frames
        if total_frames == 0:  # infinite video
            return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)

        if video_stream.duration is None:  # duration metadata missing: cannot apply the fps rate
            sample_frames = min(total_frames, video_maxlen)
        else:
            sample_frames = math.floor(float(video_stream.duration * video_stream.time_base) * video_fps)
            # a clip shorter than 1 / video_fps seconds would otherwise produce no frames at all
            sample_frames = min(total_frames, video_maxlen, max(1, sample_frames))

        return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)

    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> list["ImageObject"]:
        r"""Regularize images to avoid error. Including reading and pre-processing.

        Raises:
            ValueError: if an input cannot be turned into a PIL image.
        """
        results = []
        for image in images:
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, bytes):
                image = Image.open(BytesIO(image))
            elif isinstance(image, dict):
                if image["bytes"] is not None:
                    image = Image.open(BytesIO(image["bytes"]))
                else:
                    image = Image.open(image["path"])

            if not isinstance(image, ImageObject):
                raise ValueError(f"Expect input is a list of images, but got {type(image)}.")

            results.append(self._preprocess_image(image, **kwargs))

        return results

    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
        r"""Regularizes videos to avoid error. Including reading, resizing and converting.

        Each video is decoded once; only the frames selected by ``_get_video_sample_indices`` are kept,
        then pre-processed like images. The container is closed after decoding.
        """
        results = []
        for video in videos:
            # close the container after decoding so file handles are not leaked across many samples
            with av.open(video, "r") as container:
                video_stream = next(stream for stream in container.streams if stream.type == "video")
                sample_indices = {int(idx) for idx in self._get_video_sample_indices(video_stream, **kwargs)}
                frames: list["ImageObject"] = []
                container.seek(0)
                for frame_idx, frame in enumerate(container.decode(video_stream)):
                    if frame_idx in sample_indices:
                        frames.append(frame.to_image())

            frames = self._regularize_images(frames, **kwargs)
            results.append(frames)

        return results

    def _regularize_audios(self, audios: list["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
        r"""Regularizes audios to avoid error. Including reading and resampling.

        Raises:
            ValueError: if an input is neither a path nor a numpy array.
        """
        results = []
        for audio in audios:
            if isinstance(audio, str):
                audio = librosa.load(audio, sr=sampling_rate)[0]

            if not isinstance(audio, np.ndarray):
                raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.")

            results.append(audio)

        return results

    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
        imglens: Optional[list[int]] = None,
    ) -> dict[str, "torch.Tensor"]:
        r"""Process visual inputs.

        Returns: (llava and paligemma)
            pixel_values: tensor with shape (B, C, H, W)

        Returns: (qwen2-vl)
            pixel_values: tensor with shape (num_patches, patch_dim)
            image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
                            where num_patches == torch.prod(image_grid_thw)

        Returns: (mllama)
            pixel_values: tensor with shape
                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
                          For example, (2, 1, 4, 3, 560, 560).
            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).

        Audio inputs add whatever the feature extractor returns, with its ``attention_mask`` renamed to
        ``feature_attention_mask`` so it does not clash with the text attention mask.
        """
        mm_inputs = {}
        if len(images) != 0:
            image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
            if imglens is not None:
                images = _make_batched_images(images, imglens)

            image_processor_kwargs = {}
            if getattr(processor, "image_do_pan_and_scan", False):  # gemma3 image processor
                image_processor_kwargs.update(
                    {
                        "do_pan_and_scan": True,
                        "pan_and_scan_min_crop_size": 256,
                        "pan_and_scan_max_num_crops": 4,
                        "pan_and_scan_min_ratio_to_activate": 1.2,
                    }
                )

            mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))

        if len(videos) != 0:
            video_processor: BaseImageProcessor = getattr(
                processor, "video_processor", getattr(processor, "image_processor", None)
            )
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
                mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
            else:  # for llava_next_video
                mm_inputs.update(video_processor(videos, return_tensors="pt"))

        if len(audios) != 0:
            feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
            sampling_rate = getattr(feature_extractor, "sampling_rate", 16000)
            audios = self._regularize_audios(audios, sampling_rate=sampling_rate)
            mm_inputs.update(
                feature_extractor(
                    audios,
                    sampling_rate=sampling_rate,
                    return_attention_mask=True,
                    padding="max_length",
                    return_tensors="pt",
                )
            )
            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts

        return mm_inputs
|
|
|
|
|
|
@dataclass
class BasePlugin(MMPluginMixin):
    r"""Fallback plugin: checks modality support and otherwise passes everything through unchanged."""

    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Pre-process input messages before tokenization for VLMs.

        The default only validates the inputs and returns ``messages`` as-is (the same object, not a
        copy); model-specific subclasses override it to expand multimodal placeholders.
        """
        self._validate_input(processor, images, videos, audios)
        processed_messages = messages
        return processed_messages

    def process_token_ids(
        self,
        input_ids: list[int],
        labels: Optional[list[int]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["MMProcessor"],
    ) -> tuple[list[int], Optional[list[int]]]:
        r"""Pre-process token ids after tokenization for VLMs.

        The default only validates the inputs and hands ``input_ids`` and ``labels`` back untouched.
        """
        self._validate_input(processor, images, videos, audios)
        processed = (input_ids, labels)
        return processed

    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Build batched multimodal inputs for VLMs.

        Arguments:
            images: a list of image inputs, shape (num_images,)
            videos: a list of video inputs, shape (num_videos,)
            audios: a list of audio inputs, shape (num_audios,)
            imglens: number of images in each sample, shape (batch_size,)
            vidlens: number of videos in each sample, shape (batch_size,)
            audlens: number of audios in each sample, shape (batch_size,)
            batch_ids: token ids of input samples, shape (batch_size, seq_len)
            processor: a processor for pre-processing images and videos

        The default ignores the per-sample lengths and ``batch_ids`` and simply runs the shared
        ``_get_mm_inputs`` over the flat media lists.
        """
        self._validate_input(processor, images, videos, audios)
        batched_inputs = self._get_mm_inputs(images, videos, audios, processor)
        return batched_inputs
|
|
|
|
|
|
@dataclass
class Gemma3Plugin(BasePlugin):
    r"""Plugin for Gemma3: substitutes the processor's image strings, with optional pan-and-scan crops."""

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Replace image placeholders with Gemma3 image strings.

        Each image becomes ``processor.full_image_sequence`` when token expansion is on, otherwise
        ``processor.boi_token``. With ``image_do_pan_and_scan`` set on the processor, every placeholder
        is followed by one extra image string per crop reported by the image processor.

        Raises:
            ValueError: if the number of images differs from the number of image placeholders.
        """
        self._validate_input(processor, images, videos, audios)
        messages = deepcopy(messages)
        boi_token: str = getattr(processor, "boi_token")
        full_image_sequence: str = getattr(processor, "full_image_sequence")
        if self.expand_mm_tokens:
            image_str = full_image_sequence
        else:
            image_str = boi_token

        use_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False)
        # crop counts are only needed (and only computed) in pan-and-scan mode
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor) if use_pan_and_scan else {}

        seen_images = 0
        for message in messages:
            text = message["content"]
            while IMAGE_PLACEHOLDER in text:
                if not use_pan_and_scan:
                    replacement = "{{image}}"
                else:
                    crop_slots = ["{{image}}"] * mm_inputs["num_crops"][0][seen_images]
                    replacement = (
                        "Here is the original image {{image}} and here are some crops to help you see better "
                        + " ".join(crop_slots)
                    )

                text = text.replace(IMAGE_PLACEHOLDER, replacement, 1)
                seen_images += 1

            message["content"] = text.replace("{{image}}", image_str)

        if seen_images != len(images):
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Build Gemma3 batch inputs: processor outputs (minus ``num_crops``) plus ``token_type_ids``."""
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        # num_crops is only used while building prompts, so it is dropped from the returned inputs
        mm_inputs.pop("num_crops", None)
        token_type_ids = _get_gemma3_token_type_ids(batch_ids, processor)
        mm_inputs["token_type_ids"] = token_type_ids
        return mm_inputs
|
|
|
|
|
|
@dataclass
class LlavaPlugin(BasePlugin):
    r"""Plugin for LLaVA-style models that use the same number of image tokens for every image."""

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Replace each image placeholder with the image token, repeated once per visual feature.

        The run length is derived from the processed pixel size, the processor patch size and its
        additional-token settings; it is 1 when token expansion is disabled.

        Raises:
            ValueError: if the number of images differs from the number of image placeholders.
        """
        self._validate_input(processor, images, videos, audios)
        messages = deepcopy(messages)
        # Validate before expanding: a placeholder without any image would otherwise leave the
        # sequence length below uncomputed and crash with an unrelated error.
        num_image_tokens = sum(message["content"].count(IMAGE_PLACEHOLDER) for message in messages)
        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        image_seqlen = 1
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0]))
                image_seqlen = (height // processor.patch_size) * (
                    width // processor.patch_size
                ) + processor.num_additional_image_tokens
                if processor.vision_feature_select_strategy == "default":
                    image_seqlen -= 1

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

            message["content"] = content.replace("{{image}}", self.image_token)

        return messages
|
|
|
|
|
|
@dataclass
class LlavaNextPlugin(BasePlugin):
    r"""Plugin for LLaVA-NeXT models, whose image token count depends on each image's original size."""

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Replace each image placeholder with a per-image run of image tokens.

        When token expansion is on, the run length for each image comes from
        ``processor._get_number_of_features`` using that image's original size; otherwise it is 1.

        Raises:
            ValueError: if the number of images differs from the number of image placeholders.
        """
        self._validate_input(processor, images, videos, audios)
        messages = deepcopy(messages)
        # Validate before expanding: extra placeholders would otherwise exhaust (or never create)
        # the per-image size iterator and fail with an unrelated error.
        num_image_tokens = sum(message["content"].count(IMAGE_PLACEHOLDER) for message in messages)
        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_sizes = iter(mm_inputs["image_sizes"].tolist())
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    orig_height, orig_width = next(image_sizes)
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                    if processor.vision_feature_select_strategy == "default":
                        image_seqlen -= 1
                else:
                    image_seqlen = 1

                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

            message["content"] = content.replace("{{image}}", self.image_token)

        return messages
|
|
|
|
|
|
@dataclass
class LlavaNextVideoPlugin(BasePlugin):
    r"""Plugin for LLaVA-NeXT-Video models: size-dependent image token runs plus pooled video token runs."""

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Replace image and video placeholders with runs of image / video tokens.

        Image runs are sized per image via ``processor._get_number_of_features``. Video runs are
        ``(patches per frame) // 4 * num_frames``, taken from the first processed video. Both are 1
        when token expansion is disabled.

        Raises:
            ValueError: if the number of images or videos differs from the number of matching placeholders.
        """
        self._validate_input(processor, images, videos, audios)
        messages = deepcopy(messages)
        # Validate before expanding: extra placeholders would otherwise hit uncomputed sizes or an
        # exhausted size iterator and fail with an unrelated error.
        num_image_tokens = sum(message["content"].count(IMAGE_PLACEHOLDER) for message in messages)
        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        num_video_tokens = sum(message["content"].count(VIDEO_PLACEHOLDER) for message in messages)
        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_sizes = iter(mm_inputs["image_sizes"].tolist())
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if self.expand_mm_tokens:
                    orig_height, orig_width = next(image_sizes)
                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
                    if processor.vision_feature_select_strategy == "default":
                        image_seqlen -= 1
                else:
                    image_seqlen = 1

                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

            message["content"] = content.replace("{{image}}", self.image_token)

        video_seqlen = 1
        if self.expand_mm_tokens and "pixel_values_videos" in mm_inputs:
            one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
            frame_height, frame_width = get_image_size(one_video[0])
            num_frames = one_video.shape[0]  # frame dim is always after batch dim
            frame_seqlen = (frame_height // processor.patch_size) * (frame_width // processor.patch_size)
            video_seqlen = frame_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer

        for message in messages:
            content = message["content"]
            while VIDEO_PLACEHOLDER in content:
                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)

            message["content"] = content.replace("{{video}}", self.video_token)

        return messages
|
|
|
|
|
|
@dataclass
class MiniCPMVPlugin(BasePlugin):
    r"""Plugin for MiniCPM-V / MiniCPM-o models (images or videos, plus optional audio)."""

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Expand image, video and audio placeholders into MiniCPM-V slice and audio placeholder strings.

        Raises:
            ValueError: if images and videos are given together, or if the number of inputs of any
                modality differs from the number of its placeholders.
        """
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
        messages = deepcopy(messages)
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        mm_inputs = {}
        audio_inputs = {}
        if len(images) != 0 and len(videos) != 0:
            raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")

        if len(videos) != 0:
            # video frames are sliced with a fixed budget and without image ids
            max_slice_nums = 2
            use_image_id = False
            mm_inputs = self._get_mm_inputs([], videos, [], processor)
        else:
            max_slice_nums = image_processor.max_slice_nums
            use_image_id = image_processor.use_image_id

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                # when expanding, every processed frame of the video gets its own image slot
                video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
                content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
                num_video_tokens += 1

            while AUDIO_PLACEHOLDER in content:
                content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
                num_audio_tokens += 1

            message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
                "{{audio}}", "(<audio>./</audio>)"
            )

        if num_image_tokens > 0:
            mm_inputs = self._get_mm_inputs(images, [], [], processor)

        if num_audio_tokens > 0:
            audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)

        if mm_inputs:
            # The tags are matched literally: as a regex, "(" and "." are metacharacters, so counting
            # with re.findall could disagree with the literal split below.
            pattern = "(<image>./</image>)"
            image_sizes = mm_inputs["image_sizes"]
            idx = 0
            for message in messages:
                text_chunks = message["content"].split(pattern)
                final_text = ""
                for chunk in text_chunks[:-1]:
                    final_text += chunk + image_processor.get_slice_image_placeholder(
                        image_sizes[0][idx], idx, max_slice_nums, use_image_id
                    )
                    idx += 1

                final_text += text_chunks[-1]
                message["content"] = final_text

        if audio_inputs:
            pattern = "(<audio>./</audio>)"
            idx = 0
            for message in messages:
                text_chunks = message["content"].split(pattern)
                final_text = ""
                for chunk in text_chunks[:-1]:
                    final_text += chunk + audio_inputs["audio_phs"][0][idx]
                    idx += 1

                final_text += text_chunks[-1]
                message["content"] = final_text

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        if len(audios) != num_audio_tokens:
            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")

        return messages

    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
        **kwargs,
    ) -> dict[str, "torch.Tensor"]:
        r"""Run MiniCPM-V's image processor and audio feature extraction.

        Keyword Args:
            valid_image_nums_ls: optional per-sample image counts used to regroup the flat image list.
            valid_audio_nums_ls: optional per-sample audio counts used to regroup the flat audio list.
            ret_phs: when true, also return the audio placeholder strings under ``audio_phs``.
        """
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
            if "valid_image_nums_ls" in kwargs:
                valid_image_nums_ls = kwargs["valid_image_nums_ls"]
                new_images = []
                idx = 0
                for valid_image_nums in valid_image_nums_ls:
                    new_images.append(images[idx : idx + valid_image_nums])
                    idx += valid_image_nums

                images = new_images

            image_inputs = image_processor(
                images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
            )
            mm_inputs.update(image_inputs)

        if len(videos) != 0:
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
            mm_inputs.update(video_inputs)

        if len(audios) != 0:
            # use one rate for both resampling and feature extraction so they always agree
            sampling_rate = getattr(feature_extractor, "sampling_rate", 16000)
            audios = self._regularize_audios(audios, sampling_rate=sampling_rate)
            if "valid_audio_nums_ls" in kwargs:
                valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
                audios_ls = []
                idx = 0
                for valid_audio_nums in valid_audio_nums_ls:
                    audios_ls.append(audios[idx : idx + valid_audio_nums])
                    idx += valid_audio_nums
            else:
                audios_ls = [audios]

            audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
                audios_ls,
                chunk_input=True,
                sampling_rate=sampling_rate,
            )
            audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
            mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
            if kwargs.get("ret_phs", False):
                mm_inputs.update({"audio_phs": audio_phs})

        return mm_inputs

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Build MiniCPM-V batch inputs, including image/audio/speaker bounds located in ``batch_ids``.

        Raises:
            ValueError: if a sample has unbalanced audio or speaker start/end tokens.
        """
        self._validate_input(processor, images, videos, audios)
        # image bound
        image_bounds_list = []
        valid_image_nums_ls = []
        for i, input_ids in enumerate(batch_ids):
            input_ids_ = torch.tensor(input_ids)
            start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
                input_ids_ == processor.tokenizer.slice_start_id
            )
            end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
            image_start_tokens = torch.where(start_cond)[0]
            image_start_tokens += 1  # bounds start right after the start token
            image_end_tokens = torch.where(end_cond)[0]
            valid_image_nums_ls.append(imglens[i])
            image_bounds = torch.hstack(
                [
                    image_start_tokens.unsqueeze(-1),
                    image_end_tokens.unsqueeze(-1),
                ]
            )
            image_bounds_list.append(image_bounds)

        mm_inputs = self._get_mm_inputs(images, videos, [], processor, valid_image_nums_ls=valid_image_nums_ls)
        if "tgt_sizes" not in mm_inputs:
            dummy_data = [torch.empty(0) for _ in range(len(batch_ids))]
            mm_inputs.update({"tgt_sizes": dummy_data, "pixel_values": dummy_data, "image_sizes": dummy_data})

        mm_inputs.update({"image_bound": image_bounds_list})

        if len(audios) > 0:
            # audio bound
            audio_bounds_ls = []
            spk_bounds_ls = []
            valid_audio_nums_ls = []

            for input_ids, audiolen in zip(batch_ids, audlens):
                input_ids_ = torch.tensor(input_ids)
                audio_start_idx = torch.where(input_ids_ == processor.tokenizer.audio_start_id)[0]
                audio_end_idx = torch.where(input_ids_ == processor.tokenizer.audio_end_id)[0]
                if len(audio_start_idx) != len(audio_end_idx):
                    raise ValueError("The number of audio start tokens does not match the number of audio end tokens.")

                audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
                audio_bounds_ls.append(audio_bounds)
                valid_audio_nums_ls.append(audiolen)

                spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
                spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
                if len(spk_start_idx) != len(spk_end_idx):
                    raise ValueError(
                        "The number of speaker start tokens does not match the number of speaker end tokens."
                    )

                spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
                spk_bounds_ls.append(spk_bounds)

            audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls)
            mm_inputs.update(audio_inputs)
            mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})

        return mm_inputs
|
|
|
|
|
|
@dataclass
class MllamaPlugin(BasePlugin):
    r"""Plugin for Mllama, which attends to images through a cross-attention mask."""

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Swap every image placeholder for a single image token; no expansion happens here.

        Raises:
            ValueError: if the number of images differs from the number of image placeholders.
        """
        self._validate_input(processor, images, videos, audios)
        messages = deepcopy(messages)
        placeholder_count = sum(msg["content"].count(IMAGE_PLACEHOLDER) for msg in messages)
        for msg in messages:
            msg["content"] = msg["content"].replace(IMAGE_PLACEHOLDER, self.image_token)

        if placeholder_count != len(images):
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Build Mllama batch inputs and a dense cross-attention mask from the image token positions."""
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
        if not mm_inputs:
            return mm_inputs

        num_tiles = mm_inputs.pop("num_tiles")
        image_token_id: int = getattr(processor, "image_token_id")
        max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles")
        sparse_masks = [get_cross_attention_token_mask(ids, image_token_id) for ids in batch_ids]
        longest = max(len(ids) for ids in batch_ids)
        dense_mask = convert_sparse_cross_attention_mask_to_dense(
            sparse_masks,
            num_tiles=num_tiles,
            max_num_tiles=max_image_tiles,
            length=longest,
        )
        # shape: (batch_size, length, max_num_images, max_num_tiles)
        mm_inputs["cross_attention_mask"] = torch.from_numpy(dense_mask)
        return mm_inputs
|
|
|
|
|
|
@dataclass
class PaliGemmaPlugin(BasePlugin):
    r"""Plugin that strips image placeholders from the text and prepends image tokens to the ids."""

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Remove every image placeholder from the message contents.

        No image tokens are inserted into the text; ``process_token_ids`` prepends them instead.

        Raises:
            ValueError: if the number of removed placeholders differs from ``len(images)``.
        """
        self._validate_input(processor, images, videos, audios)
        messages = deepcopy(messages)
        removed = 0
        placeholder_len = len(IMAGE_PLACEHOLDER)
        for msg in messages:
            text = msg["content"]
            # drop the leftmost occurrence on each pass, so text re-joined by a removal is caught too
            while (pos := text.find(IMAGE_PLACEHOLDER)) != -1:
                text = text[:pos] + text[pos + placeholder_len :]
                removed += 1

            msg["content"] = text

        if len(images) != removed:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def process_token_ids(
        self,
        input_ids: list[int],
        labels: Optional[list[int]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["MMProcessor"],
    ) -> tuple[list[int], Optional[list[int]]]:
        r"""Prepend the image tokens of all images to ``input_ids`` (and ``IGNORE_INDEX`` to ``labels``).

        Each image contributes ``processor.image_seq_length`` tokens, or none when
        ``expand_mm_tokens`` is disabled. New lists are returned; the inputs are not mutated.
        """
        self._validate_input(processor, images, videos, audios)
        per_image = processor.image_seq_length if self.expand_mm_tokens else 0  # skip mm token
        prefix_len = len(images) * per_image
        image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
        input_ids = [image_token_id] * prefix_len + input_ids
        if labels is not None:
            # the prefix positions are labelled with IGNORE_INDEX
            labels = [IGNORE_INDEX] * prefix_len + labels

        return input_ids, labels

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Return processor inputs plus ``token_type_ids`` built from per-sample image counts and lengths."""
        self._validate_input(processor, images, videos, audios)
        seqlens = list(map(len, batch_ids))
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
        return mm_inputs
|
|
|
|
|
|
@dataclass
class PixtralPlugin(BasePlugin):
    r"""Plugin that expands each image into Pixtral's row-by-row image token layout."""

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Replace each image placeholder with its image token sequence.

        With ``expand_mm_tokens``, an image of ``height x width`` becomes ``height // patch_size``
        rows of ``width // patch_size`` image tokens, each row followed by the image-break token,
        and the final token of the sequence replaced by the image-end token. Without expansion a
        single image token is used.

        Raises:
            ValueError: if there are more placeholders than images, or the counts differ.
        """
        self._validate_input(processor, images, videos, audios)
        num_image_tokens = 0
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            image_sizes = []
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                # BC for transformers < 4.49.0
                if isinstance(mm_inputs["image_sizes"], list):
                    image_sizes = list(mm_inputs["image_sizes"][0])
                else:
                    image_sizes = mm_inputs["image_sizes"].tolist()

                image_break_token: str = getattr(processor, "image_break_token")
                image_end_token: str = getattr(processor, "image_end_token")
        else:
            image_sizes = [None] * len(images)

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                # report a clear error instead of StopIteration/NameError on missing image sizes
                if num_image_tokens >= len(image_sizes):
                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

                if self.expand_mm_tokens:
                    height, width = image_sizes[num_image_tokens]
                    num_height_tokens = height // processor.patch_size
                    num_width_tokens = width // processor.patch_size
                    row = [self.image_token] * num_width_tokens + [image_break_token]
                    replace_tokens = row * num_height_tokens
                    replace_tokens[-1] = image_end_token  # last row ends with image-end instead of break
                    replace_str = "".join(replace_tokens)
                else:
                    replace_str = self.image_token

                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
                num_image_tokens += 1

            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Return processor inputs, dropping ``image_sizes`` on transformers <= 4.49.0."""
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        # ref to this commit https://github.com/huggingface/transformers/pull/35122
        # after transformers 4.49.0, the `image_sizes` is mandatory as an input parameter for Pixtral VisionEncoder forwarding.
        # it can be passed into `LlavaConditionalGeneration` as a parameter.
        if not is_transformers_version_greater_than("4.49.0"):
            mm_inputs.pop("image_sizes", None)

        return mm_inputs
|
|
|
|
|
|
@dataclass
class Qwen2AudioPlugin(BasePlugin):
    r"""Plugin that expands audio placeholders between the processor's audio bos/eos tokens."""

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Replace each audio placeholder with ``audio_bos_token + audio_token * n + audio_eos_token``.

        With ``expand_mm_tokens``, ``n`` is derived from the number of valid frames in
        ``feature_attention_mask``; otherwise ``n`` is 1.

        Raises:
            ValueError: if there are more placeholders than audios, or the counts differ.
        """
        self._validate_input(processor, images, videos, audios)
        bos_token: str = getattr(processor, "audio_bos_token")
        eos_token: str = getattr(processor, "audio_eos_token")
        num_audio_tokens = 0
        messages = deepcopy(messages)
        if self.expand_mm_tokens:
            audio_lengths = []
            mm_inputs = self._get_mm_inputs([], [], audios, processor)
            if "feature_attention_mask" in mm_inputs:
                audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()
        else:
            audio_lengths = [None] * len(audios)

        for message in messages:
            content = message["content"]
            while AUDIO_PLACEHOLDER in content:
                # report a clear error instead of failing on a missing/exhausted `audio_lengths`
                if num_audio_tokens >= len(audio_lengths):
                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")

                if self.expand_mm_tokens:
                    # downsample the frame count twice; assumed to mirror the audio encoder's
                    # output-length computation -- confirm against the model
                    audio_length = audio_lengths[num_audio_tokens]
                    input_length = (audio_length - 1) // 2 + 1
                    audio_seqlen = (input_length - 2) // 2 + 1
                else:
                    audio_seqlen = 1

                content = content.replace(
                    AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
                )
                num_audio_tokens += 1

            message["content"] = content

        if len(audios) != num_audio_tokens:
            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Return the processor inputs for the given media."""
        self._validate_input(processor, images, videos, audios)
        return self._get_mm_inputs(images, videos, audios, processor)
|
|
|
|
|
|
@dataclass
class Qwen2OmniPlugin(BasePlugin):
    r"""Plugin for the omni (image + video + audio) processor registered as ``qwen2_omni``."""

    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
        imglens: Optional[list[int]] = None,
    ) -> dict[str, "torch.Tensor"]:
        r"""Preprocess images, videos and audios with the processor's sub-components.

        Images use ``processor.omni_processor``; videos use ``processor.video_processor`` (falling back
        to ``omni_processor``); audios use ``processor.feature_extractor``.

        Args:
            imglens: optional per-sample image counts used to batch images before preprocessing.

        Returns:
            The merged processor outputs. The audio attention mask is stored as
            ``feature_attention_mask``; videos additionally get ``video_second_per_grid``.

        Raises:
            NotImplementedError: if the video processor's ``preprocess`` takes no ``videos`` argument.
        """
        mm_inputs = {}
        if len(images) != 0:
            image_processor: BaseImageProcessor = getattr(processor, "omni_processor", None)  # FIXME
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
            if imglens is not None:
                images = _make_batched_images(images, imglens)

            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            video_processor: BaseImageProcessor = getattr(
                processor, "video_processor", getattr(processor, "omni_processor", None)
            )
            videos = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            if "videos" not in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
                raise NotImplementedError

            mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
            fps = [2.0] * len(videos)  # FIXME hardcode
            temporal_patch_size = video_processor.temporal_patch_size
            # one temporal grid spans `temporal_patch_size` frames, i.e. temporal_patch_size / fps seconds
            # (same formula as Qwen2VLPlugin's `second_per_grid_ts`)
            mm_inputs["video_second_per_grid"] = torch.tensor([temporal_patch_size / f for f in fps])

        if len(audios) != 0:
            feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
            )
            mm_inputs.update(
                feature_extractor(
                    audios,
                    sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
                    return_attention_mask=True,
                    padding="max_length",
                    return_tensors="pt",
                )
            )
            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Expand image, audio and video placeholders into token spans.

        Images become ``<|vision_bos|>`` + image tokens + ``<|vision_eos|>``, audios
        ``<|audio_bos|>`` + audio tokens + ``<|audio_eos|>``, and videos ``<|vision_bos|>`` +
        video tokens + ``<|vision_eos|>``. Without ``expand_mm_tokens`` each span holds one token.

        When ``processor.use_audio_in_video`` is set, each video placeholder consumes the next audio
        and is rendered as ``<|vision_bos|><|audio_bos|>`` followed by interleaved video/audio chunks and
        ``<|audio_eos|><|vision_eos|>``; one audio placeholder is removed per video.

        Raises:
            ValueError: on placeholder/media count mismatches, or when a video must be combined with
                its audio but ``expand_mm_tokens`` is disabled.
        """
        self._validate_input(processor, images, videos, audios)
        messages = deepcopy(messages)
        num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
        use_audio_in_video = getattr(processor, "use_audio_in_video", False)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor) if self.expand_mm_tokens else {}

        # per-item token counts derived from the processor outputs (only present when expanding)
        audio_lengths = None
        if "feature_attention_mask" in mm_inputs:
            input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
            audio_lengths = (input_lengths - 2) // 2 + 1

        image_grid_thw = mm_inputs.get("image_grid_thw", None)
        video_grid_thw = mm_inputs.get("video_grid_thw", None)
        merge_size, merge_length = 1, 1
        if image_grid_thw is not None or video_grid_thw is not None:
            merge_size = processor.omni_processor.merge_size
            merge_length = merge_size**2

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if num_image_tokens >= len(images):
                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

                image_seqlen = (
                    int(image_grid_thw[num_image_tokens].prod() // merge_length) if self.expand_mm_tokens else 1
                )
                content = content.replace(
                    IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
                )
                num_image_tokens += 1

            if not use_audio_in_video:
                while AUDIO_PLACEHOLDER in content:
                    if num_audio_tokens >= len(audios):
                        raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")

                    audio_seqlen = int(audio_lengths[num_audio_tokens]) if self.expand_mm_tokens else 1
                    content = content.replace(
                        AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
                    )
                    num_audio_tokens += 1

                while VIDEO_PLACEHOLDER in content:
                    if num_video_tokens >= len(videos):
                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

                    video_seqlen = (
                        int(video_grid_thw[num_video_tokens].prod() // merge_length) if self.expand_mm_tokens else 1
                    )
                    content = content.replace(
                        VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
                    )
                    num_video_tokens += 1
            else:  # the audio track of each video is merged into the video span
                while VIDEO_PLACEHOLDER in content:
                    if not self.expand_mm_tokens:
                        raise ValueError("`use_audio_in_video` requires `expand_mm_tokens` to be enabled.")

                    if num_video_tokens >= len(videos):
                        raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

                    if num_audio_tokens >= len(audios):
                        raise ValueError(
                            f"`len(audios)` is less than the number of {VIDEO_PLACEHOLDER} tokens "
                            "when `use_audio_in_video` is `True`."
                        )

                    grid_t, grid_h, grid_w = (int(x) for x in video_grid_thw[num_video_tokens])
                    audio_t_index = torch.arange(int(audio_lengths[num_audio_tokens]))
                    video_t_index = (
                        torch.arange(grid_t)
                        .view(-1, 1, 1)
                        .expand(-1, grid_h // merge_size, grid_w // merge_size)
                        .flatten()
                        * mm_inputs["video_second_per_grid"][num_video_tokens]
                        * 25  # FIXME hardcode of position_id_per_seconds=25
                    ).long()
                    t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
                    video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
                    audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
                    # open the span once, append every chunk pair, then close it once
                    placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
                    for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
                        if j < len(video_chunk_indices):
                            start, end = video_chunk_indices[j]
                            placeholder_string += self.video_token * (end - start)

                        if j < len(audio_chunk_indices):
                            start, end = audio_chunk_indices[j]
                            placeholder_string += self.audio_token * (end - start)

                    placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
                    content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
                    content = content.replace(AUDIO_PLACEHOLDER, "", 1)
                    num_audio_tokens += 1
                    num_video_tokens += 1

            message["content"] = content

        if len(audios) != num_audio_tokens:
            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages
|
|
|
|
|
|
@dataclass
class Qwen2VLPlugin(BasePlugin):
    r"""Plugin whose image/video token counts come from the processor's ``*_grid_thw`` outputs."""

    @override
    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
        r"""Apply base preprocessing, then fix images that are too small or too elongated.

        A side shorter than 28 px is enlarged to 28 px. An aspect ratio above 200 is resized to
        180:1 (or 1:180), keeping the shorter side.
        """
        image = super()._preprocess_image(image, **kwargs)
        if min(image.width, image.height) < 28:
            width, height = max(image.width, 28), max(image.height, 28)
            image = image.resize((width, height))

        if image.width / image.height > 200:
            width, height = image.height * 180, image.height
            image = image.resize((width, height))

        if image.height / image.width > 200:
            width, height = image.width, image.width * 180
            image = image.resize((width, height))

        return image

    @override
    def _regularize_videos(
        self, videos: list["VideoInput"], **kwargs
    ) -> tuple[list[list["ImageObject"]], list[float]]:
        r"""Decode the sampled frames of every video and estimate the sampling fps.

        Returns:
            ``(frames_per_video, fps_per_video)``. Each frame list is padded to an even length by
            repeating its last frame. The fps is ``len(sample_indices) / duration_in_seconds``, or
            2.0 when the stream reports no duration.
        """
        results, fps_per_video = [], []
        for video in videos:
            # the context manager closes the container (and its file handle) after decoding
            with av.open(video, "r") as container:
                video_stream = next(stream for stream in container.streams if stream.type == "video")
                sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
                frames: list[ImageObject] = []
                container.seek(0)
                for frame_idx, frame in enumerate(container.decode(video_stream)):
                    if frame_idx in sample_indices:
                        frames.append(frame.to_image())

                if video_stream.duration is None:
                    fps_per_video.append(2.0)
                else:
                    fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))

            if len(frames) % 2 != 0:  # qwen2-vl requires even number of frames
                frames.append(frames[-1])

            frames = self._regularize_images(frames, **kwargs)
            results.append(frames)

        return results, fps_per_video

    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
    ) -> dict[str, "torch.Tensor"]:
        r"""Run ``processor.image_processor`` on images and videos.

        The per-video fps list is stored under ``fps_per_video`` for ``get_mm_inputs`` to consume.
        """
        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
        mm_inputs = {}
        if len(images) != 0:
            images = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )
            mm_inputs.update(image_processor(images, return_tensors="pt"))

        if len(videos) != 0:
            videos, fps_per_video = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt"))
            mm_inputs["fps_per_video"] = fps_per_video

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Wrap each image/video placeholder in ``<|vision_start|>``/``<|vision_end|>``.

        With ``expand_mm_tokens`` the span holds ``prod(grid_thw) // merge_size**2`` tokens,
        otherwise a single token.

        Raises:
            ValueError: if there are more placeholders than media, or the counts differ.
        """
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")

        merge_length: int = getattr(image_processor, "merge_size") ** 2
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            image_grid_thw = mm_inputs.get("image_grid_thw", [])
            video_grid_thw = mm_inputs.get("video_grid_thw", [])
        else:
            image_grid_thw = [None] * len(images)
            video_grid_thw = [None] * len(videos)

        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if num_image_tokens >= len(image_grid_thw):
                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                content = content.replace(
                    IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
                )
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                if num_video_tokens >= len(video_grid_thw):
                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

                video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                content = content.replace(
                    VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
                )
                num_video_tokens += 1

            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Return processor inputs; add ``second_per_grid_ts`` when the processor expects it."""
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        fps_per_video = mm_inputs.pop("fps_per_video", [])
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
        if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
            # seconds covered by one temporal grid of `temporal_patch_size` frames
            mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in fps_per_video]

        return mm_inputs
|
|
|
|
|
|
@dataclass
class VideoLlavaPlugin(BasePlugin):
    r"""Plugin that expands image/video placeholders using patch counts from the processed pixels."""

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Replace image/video placeholders with repeated image/video tokens.

        With ``expand_mm_tokens`` and processed pixels available, an image expands to
        ``(height // patch_size) * (width // patch_size) + num_additional_image_tokens`` tokens
        (minus one for the ``"default"`` feature-select strategy), and a video to that count
        (before the minus-one) times its number of frames. Otherwise one token is used per item.

        Raises:
            ValueError: if the number of images/videos differs from the placeholder counts.
        """
        self._validate_input(processor, images, videos, audios)
        num_image_tokens, num_video_tokens = 0, 0
        messages = deepcopy(messages)
        num_frames = 0
        # single-token defaults: used without expansion, and keep the names bound when no pixels
        # were produced so a count mismatch ends in the ValueError below rather than a NameError
        image_seqlen, video_seqlen = 1, 1
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values_images" in mm_inputs:
                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0]))
                num_frames = 1

            if "pixel_values_videos" in mm_inputs:
                one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0])
                height, width = get_image_size(one_video[0])
                num_frames = one_video.shape[0]  # frame dim is always after batch dim

            if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs:
                image_seqlen = (height // processor.patch_size) * (
                    width // processor.patch_size
                ) + processor.num_additional_image_tokens
                video_seqlen = image_seqlen * num_frames
                if processor.vision_feature_select_strategy == "default":
                    image_seqlen -= 1

        for message in messages:
            content = message["content"]
            # expand into intermediate markers first, then map them to the real tokens
            while IMAGE_PLACEHOLDER in content:
                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
                num_video_tokens += 1

            content = content.replace("{{image}}", self.image_token)
            message["content"] = content.replace("{{video}}", self.video_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        return messages
|
|
|
|
|
|
# Registry of multimodal plugin classes keyed by name. `get_mm_plugin` looks names up here and
# `register_mm_plugin` adds new entries at runtime.
PLUGINS: dict[str, type["BasePlugin"]] = {
    "base": BasePlugin,
    "gemma3": Gemma3Plugin,
    "llava": LlavaPlugin,
    "llava_next": LlavaNextPlugin,
    "llava_next_video": LlavaNextVideoPlugin,
    "minicpm_v": MiniCPMVPlugin,
    "mllama": MllamaPlugin,
    "paligemma": PaliGemmaPlugin,
    "pixtral": PixtralPlugin,
    "qwen2_audio": Qwen2AudioPlugin,
    "qwen2_omni": Qwen2OmniPlugin,
    "qwen2_vl": Qwen2VLPlugin,
    "video_llava": VideoLlavaPlugin,
}
|
|
|
|
|
|
def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
    r"""Register a multimodal plugin.

    Args:
        name: key under which the plugin class is stored in ``PLUGINS``.
        plugin_class: the plugin class to register.

    Raises:
        ValueError: if ``name`` is already registered (existing entries are never overwritten).
    """
    already_registered = name in PLUGINS
    if already_registered:
        raise ValueError(f"Multimodal plugin {name} already exists.")

    PLUGINS[name] = plugin_class
|
|
|
|
|
|
def get_mm_plugin(
    name: str,
    image_token: Optional[str] = None,
    video_token: Optional[str] = None,
    audio_token: Optional[str] = None,
) -> "BasePlugin":
    r"""Get plugin for multimodal inputs.

    Args:
        name: registry key in ``PLUGINS``.
        image_token / video_token / audio_token: passed positionally to the plugin constructor.

    Returns:
        A new instance of the registered plugin class.

    Raises:
        ValueError: if ``name`` is not registered.
    """
    try:
        plugin_class = PLUGINS[name]
    except KeyError:
        raise ValueError(f"Multimodal plugin `{name}` not found.") from None

    return plugin_class(image_token, video_token, audio_token)
|