[data] refactor mm plugin (#6895)

* refactor plugin

* lint

Former-commit-id: aca63bfcca02ecd95b57cd8949a50e26a913f716
Author: hoshi-hiyouga (committed by GitHub)
Date:   2025-02-11 16:34:49 +08:00
Parent: 188f22d8a7
Commit: 593acca556


@@ -2,6 +2,7 @@ import inspect
 import math
 import re
 from copy import deepcopy
+from dataclasses import dataclass
 from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
@@ -72,12 +73,12 @@ def _get_paligemma_token_type_ids(
     return batch_token_type_ids


-class BasePlugin:
-    def __init__(self, image_token: Optional[str], video_token: Optional[str], audio_token: Optional[str]) -> None:
-        self.image_token = image_token
-        self.video_token = video_token
-        self.audio_token = audio_token
-        self.expand_mm_tokens = True
+@dataclass
+class MMPluginMixin:
+    image_token: Optional[str]
+    video_token: Optional[str]
+    audio_token: Optional[str]
+    expand_mm_tokens: bool = True

     def _validate_input(
         self,
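
Note: the dataclass mixin replaces the hand-written `__init__` with generated field-based construction, so call sites keep passing the same three tokens positionally. A minimal sketch of the equivalence (the token strings below are placeholders, not the real template tokens):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class MMPluginMixin:
        image_token: Optional[str]
        video_token: Optional[str]
        audio_token: Optional[str]
        expand_mm_tokens: bool = True

    # Same positional signature the old __init__ exposed.
    plugin = MMPluginMixin("<image>", "<video>", "<audio>")
    assert plugin.expand_mm_tokens  # default preserved from the old attribute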
@@ -103,11 +104,10 @@ class BasePlugin:
                 "This model does not support audio input. Please check whether the correct `template` is used."
             )

-    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+    def _preprocess_image(self, image: "ImageObject", image_resolution: int, **kwargs) -> "ImageObject":
         r"""
         Pre-processes a single image.
         """
-        image_resolution: int = kwargs["image_resolution"]
         if (image.width * image.height) > image_resolution:
             resize_factor = math.sqrt(image_resolution / (image.width * image.height))
             width, height = int(image.width * resize_factor), int(image.height * resize_factor)
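
Note: the resize keeps the aspect ratio while capping the pixel count at `image_resolution`. A worked example, assuming the 768 * 768 default that appears later in this diff:

    import math

    image_resolution = 768 * 768  # 589,824-pixel cap (default)
    width, height = 4032, 3024    # hypothetical 12 MP photo
    if width * height > image_resolution:
        resize_factor = math.sqrt(image_resolution / (width * height))
        width, height = int(width * resize_factor), int(height * resize_factor)
    print(width, height)  # 886 665 -- area now just under the cap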
@@ -118,12 +118,12 @@ class BasePlugin:

         return image

-    def _get_video_sample_indices(self, video_stream: "Stream", **kwargs) -> List[int]:
+    def _get_video_sample_indices(
+        self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
+    ) -> List[int]:
         r"""
         Computes video sample indices according to fps.
         """
-        video_fps: float = kwargs["video_fps"]
-        video_maxlen: int = kwargs["video_maxlen"]
         total_frames = video_stream.frames
         if total_frames == 0:  # infinite video
             return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
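
Note: with `video_fps` and `video_maxlen` as explicit parameters, the sampling policy is easy to exercise in isolation. The body past this hunk is not shown, so the following is only a sketch of fps-based sampling under assumed names:

    import numpy as np

    def sample_indices(total_frames: int, duration_s: float, video_fps: float, video_maxlen: int) -> "np.ndarray":
        # Aim for `video_fps` samples per second, bounded by the frame count
        # and the `video_maxlen` hard cap.
        num_samples = min(total_frames, video_maxlen, int(duration_s * video_fps))
        return np.linspace(0, total_frames - 1, max(num_samples, 1)).astype(np.int32)

    print(sample_indices(300, 10.0, 2.0, 64))  # 20 evenly spaced indices in [0, 299]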
@@ -175,12 +175,11 @@ class BasePlugin:

         return results

-    def _regularize_audios(self, audios: Sequence["AudioInput"], **kwargs) -> List["NDArray"]:
+    def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]:
         r"""
         Regularizes audios to avoid error. Including reading and resampling.
         """
         results = []
-        sampling_rate = kwargs["sampling_rate"]
         for audio in audios:
             if isinstance(audio, str):
                 audio = librosa.load(audio, sr=sampling_rate)[0]
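
Note: making `sampling_rate` a named parameter surfaces the resampling contract at the call site. For reference, `librosa.load` resamples on read:

    import librosa

    # sr forces a resample, so every clip reaches the feature extractor at
    # the same rate; "clip.wav" is a hypothetical path.
    audio, sr = librosa.load("clip.wav", sr=16_000)
    assert sr == 16_000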
@@ -218,8 +217,7 @@ class BasePlugin:
         if len(images) != 0:
             images = self._regularize_images(
-                images,
-                image_resolution=getattr(processor, "image_resolution", 768 * 768),
+                images, image_resolution=getattr(processor, "image_resolution", 768 * 768)
             )
             mm_inputs.update(image_processor(images, return_tensors="pt"))
@@ -253,6 +251,9 @@ class BasePlugin:
         return mm_inputs


+@dataclass
+class BasePlugin(MMPluginMixin):
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
@@ -310,6 +311,7 @@ class BasePlugin:
         return {}


+@dataclass
 class LlavaPlugin(BasePlugin):
     @override
     def process_messages(
@@ -353,6 +355,7 @@ class LlavaPlugin(BasePlugin):
         return self._get_mm_inputs(images, videos, audios, processor)


+@dataclass
 class LlavaNextPlugin(BasePlugin):
     @override
     def process_messages(
@@ -410,6 +413,7 @@ class LlavaNextPlugin(BasePlugin):
         return self._get_mm_inputs(images, videos, audios, processor)


+@dataclass
 class LlavaNextVideoPlugin(BasePlugin):
     @override
     def process_messages(
@@ -444,12 +448,15 @@ class LlavaNextVideoPlugin(BasePlugin):
             message["content"] = content.replace("{{image}}", self.image_token)

         if "pixel_values_videos" in mm_inputs:
-            pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-            height, width = get_image_size(pixel_values_video[0])
-            num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
-            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
-            video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
-            video_seqlen = video_seqlen if self.expand_mm_tokens else 1
+            if self.expand_mm_tokens:
+                pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                height, width = get_image_size(pixel_values_video[0])
+                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
+                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
+                video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
+            else:
+                video_seqlen = 1

         for message in messages:
             content = message["content"]
             while VIDEO_PLACEHOLDER in content:
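
Note: a worked example of the branch above, assuming hypothetical 336x336 frames, `processor.patch_size = 14`, and 8 sampled frames:

    height = width = 336
    patch_size = 14
    num_frames = 8
    image_seqlen = (height // patch_size) * (width // patch_size)  # 24 * 24 = 576
    video_seqlen = image_seqlen // 4 * num_frames                  # 576 // 4 * 8 = 1152
    # With expand_mm_tokens=False the video placeholder collapses to one token.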
@@ -482,6 +489,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         return self._get_mm_inputs(images, videos, audios, processor)


+@dataclass
 class MiniCPMVPlugin(BasePlugin):
     @override
     def process_messages(
@@ -645,12 +653,7 @@ class MiniCPMVPlugin(BasePlugin):
                 chunk_input=True,
                 sampling_rate=16000,
             )
-            audio_feature_lens = [
-                torch.tensor(audio_feature_len)
-                if not isinstance(audio_feature_len, torch.Tensor)
-                else audio_feature_len
-                for audio_feature_len in audio_feature_lens
-            ]
+            audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
             mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
             if kwargs.get("ret_phs", False):
                 mm_inputs.update({"audio_phs": audio_phs})
@@ -670,7 +673,6 @@ class MiniCPMVPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos, audios)
-
         # image bound
         image_bounds_list = []
         valid_image_nums_ls = []
@@ -727,6 +729,7 @@ class MiniCPMVPlugin(BasePlugin):
         return mm_inputs


+@dataclass
 class MllamaPlugin(BasePlugin):
     @override
     def process_messages(
@@ -757,7 +760,7 @@ class MllamaPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         audios: Sequence["AudioInput"],
         processor: "ProcessorMixin",
-        **kwargs,
+        imglens: List[int],
     ) -> Dict[str, "torch.Tensor"]:
         r"""
         Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
@@ -771,7 +774,6 @@ class MllamaPlugin(BasePlugin):
             num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
         """
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        imglens: List[int] = kwargs["imglens"]
         images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 768 * 768))
         batch_images = []
         for image_length in imglens:
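
Note: `imglens` holds the image count per sample, which is how the flat `images` list gets re-nested into the List[List[ImageInput]] shape mllama's image processor requires. A minimal sketch of that grouping (the loop body past this hunk is not shown):

    images = ["img_a", "img_b", "img_c"]  # stand-ins for ImageInput objects
    imglens = [2, 1]                      # sample 0 has two images, sample 1 has one

    batch_images, start = [], 0
    for image_length in imglens:
        batch_images.append(images[start : start + image_length])
        start += image_length

    print(batch_images)  # [['img_a', 'img_b'], ['img_c']]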
@@ -793,7 +795,7 @@ class MllamaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos, audios)
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens=imglens)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
         num_tiles = mm_inputs.pop("num_tiles")
         image_token_id = getattr(processor, "image_token_id")
         max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
@@ -811,6 +813,7 @@ class MllamaPlugin(BasePlugin):
         return mm_inputs


+@dataclass
 class PaliGemmaPlugin(BasePlugin):
     @override
     def process_messages(
@@ -877,6 +880,7 @@ class PaliGemmaPlugin(BasePlugin):
         return mm_inputs


+@dataclass
 class PixtralPlugin(BasePlugin):
     @override
     def process_messages(
@@ -946,6 +950,7 @@ class PixtralPlugin(BasePlugin):
         return mm_inputs


+@dataclass
 class Qwen2AudioPlugin(BasePlugin):
     @override
     def process_messages(
@@ -967,9 +972,13 @@ class Qwen2AudioPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while AUDIO_PLACEHOLDER in content:
-                audio_length = audio_lengths.pop(0)
-                input_length = (audio_length - 1) // 2 + 1
-                audio_seqlen = (input_length - 2) // 2 + 1 if self.expand_mm_tokens else 1
+                if self.expand_mm_tokens:
+                    audio_length = audio_lengths.pop(0)
+                    input_length = (audio_length - 1) // 2 + 1
+                    audio_seqlen = (input_length - 2) // 2 + 1
+                else:
+                    audio_seqlen = 1
+
                 content = content.replace(
                     AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
                 )
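
Note: the two integer divisions model successive halving of the audio feature sequence. A worked example with a hypothetical 3,000-frame mel input (roughly 30 s):

    audio_length = 3000                          # hypothetical mel frame count
    input_length = (audio_length - 1) // 2 + 1   # 1500
    audio_seqlen = (input_length - 2) // 2 + 1   # 750 audio tokens
    # With expand_mm_tokens=False the placeholder is a single audio token.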
@@ -998,6 +1007,7 @@ class Qwen2AudioPlugin(BasePlugin):
         return self._get_mm_inputs(images, videos, audios, processor)


+@dataclass
 class Qwen2vlPlugin(BasePlugin):
     @override
     def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
@@ -1017,7 +1027,9 @@ class Qwen2vlPlugin(BasePlugin):
         return image

     @override
-    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> Tuple[List[List["ImageObject"]], List[float]]:
+    def _regularize_videos(
+        self, videos: Sequence["VideoInput"], **kwargs
+    ) -> Tuple[List[List["ImageObject"]], List[float]]:
         results, fps_per_video = [], []
         for video in videos:
             container = av.open(video, "r")
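
Note: `_regularize_videos` decodes with PyAV (`av.open` above). The rest of the body is not shown, so this is only a sketch of how the sampled indices might be consumed:

    import av  # PyAV

    def read_frames(path: str, indices: set) -> list:
        # Decode sequentially and keep only the sampled frame positions;
        # to_image() converts each av.VideoFrame to a PIL image.
        container = av.open(path, "r")
        video_stream = next(s for s in container.streams if s.type == "video")
        return [f.to_image() for i, f in enumerate(container.decode(video_stream)) if i in indices]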
@@ -1053,8 +1065,7 @@ class Qwen2vlPlugin(BasePlugin):
         mm_inputs = {}
         if len(images) != 0:
             images = self._regularize_images(
-                images,
-                image_resolution=getattr(processor, "image_resolution", 768 * 768),
+                images, image_resolution=getattr(processor, "image_resolution", 768 * 768)
             )
             mm_inputs.update(image_processor(images, return_tensors="pt"))
@@ -1142,6 +1153,7 @@ class Qwen2vlPlugin(BasePlugin):
         return mm_inputs


+@dataclass
 class VideoLlavaPlugin(BasePlugin):
     @override
     def process_messages(