From 593acca55688296f0eec710be30ac4613541e162 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Tue, 11 Feb 2025 16:34:49 +0800 Subject: [PATCH] [data] refactor mm plugin (#6895) * refactor plugin * lint Former-commit-id: aca63bfcca02ecd95b57cd8949a50e26a913f716 --- src/llamafactory/data/mm_plugin.py | 86 +++++++++++++++++------------- 1 file changed, 49 insertions(+), 37 deletions(-) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index f9fc2b72..2430a74d 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -2,6 +2,7 @@ import inspect import math import re from copy import deepcopy +from dataclasses import dataclass from io import BytesIO from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union @@ -72,12 +73,12 @@ def _get_paligemma_token_type_ids( return batch_token_type_ids -class BasePlugin: - def __init__(self, image_token: Optional[str], video_token: Optional[str], audio_token: Optional[str]) -> None: - self.image_token = image_token - self.video_token = video_token - self.audio_token = audio_token - self.expand_mm_tokens = True +@dataclass +class MMPluginMixin: + image_token: Optional[str] + video_token: Optional[str] + audio_token: Optional[str] + expand_mm_tokens: bool = True def _validate_input( self, @@ -103,11 +104,10 @@ class BasePlugin: "This model does not support audio input. Please check whether the correct `template` is used." ) - def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject": + def _preprocess_image(self, image: "ImageObject", image_resolution: int, **kwargs) -> "ImageObject": r""" Pre-processes a single image. """ - image_resolution: int = kwargs["image_resolution"] if (image.width * image.height) > image_resolution: resize_factor = math.sqrt(image_resolution / (image.width * image.height)) width, height = int(image.width * resize_factor), int(image.height * resize_factor) @@ -118,12 +118,12 @@ class BasePlugin: return image - def _get_video_sample_indices(self, video_stream: "Stream", **kwargs) -> List[int]: + def _get_video_sample_indices( + self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs + ) -> List[int]: r""" Computes video sample indices according to fps. """ - video_fps: float = kwargs["video_fps"] - video_maxlen: int = kwargs["video_maxlen"] total_frames = video_stream.frames if total_frames == 0: # infinite video return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32) @@ -175,12 +175,11 @@ class BasePlugin: return results - def _regularize_audios(self, audios: Sequence["AudioInput"], **kwargs) -> List["NDArray"]: + def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]: r""" Regularizes audios to avoid error. Including reading and resampling. 
""" results = [] - sampling_rate = kwargs["sampling_rate"] for audio in audios: if isinstance(audio, str): audio = librosa.load(audio, sr=sampling_rate)[0] @@ -218,8 +217,7 @@ class BasePlugin: if len(images) != 0: images = self._regularize_images( - images, - image_resolution=getattr(processor, "image_resolution", 768 * 768), + images, image_resolution=getattr(processor, "image_resolution", 768 * 768) ) mm_inputs.update(image_processor(images, return_tensors="pt")) @@ -253,6 +251,9 @@ class BasePlugin: return mm_inputs + +@dataclass +class BasePlugin(MMPluginMixin): def process_messages( self, messages: Sequence[Dict[str, str]], @@ -310,6 +311,7 @@ class BasePlugin: return {} +@dataclass class LlavaPlugin(BasePlugin): @override def process_messages( @@ -353,6 +355,7 @@ class LlavaPlugin(BasePlugin): return self._get_mm_inputs(images, videos, audios, processor) +@dataclass class LlavaNextPlugin(BasePlugin): @override def process_messages( @@ -410,6 +413,7 @@ class LlavaNextPlugin(BasePlugin): return self._get_mm_inputs(images, videos, audios, processor) +@dataclass class LlavaNextVideoPlugin(BasePlugin): @override def process_messages( @@ -444,12 +448,15 @@ class LlavaNextVideoPlugin(BasePlugin): message["content"] = content.replace("{{image}}", self.image_token) if "pixel_values_videos" in mm_inputs: - pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) - height, width = get_image_size(pixel_values_video[0]) - num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim - image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) - video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer - video_seqlen = video_seqlen if self.expand_mm_tokens else 1 + if self.expand_mm_tokens: + pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0]) + height, width = get_image_size(pixel_values_video[0]) + num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim + image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer + else: + video_seqlen = 1 + for message in messages: content = message["content"] while VIDEO_PLACEHOLDER in content: @@ -482,6 +489,7 @@ class LlavaNextVideoPlugin(BasePlugin): return self._get_mm_inputs(images, videos, audios, processor) +@dataclass class MiniCPMVPlugin(BasePlugin): @override def process_messages( @@ -645,12 +653,7 @@ class MiniCPMVPlugin(BasePlugin): chunk_input=True, sampling_rate=16000, ) - audio_feature_lens = [ - torch.tensor(audio_feature_len) - if not isinstance(audio_feature_len, torch.Tensor) - else audio_feature_len - for audio_feature_len in audio_feature_lens - ] + audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens] mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens}) if kwargs.get("ret_phs", False): mm_inputs.update({"audio_phs": audio_phs}) @@ -670,7 +673,6 @@ class MiniCPMVPlugin(BasePlugin): processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: self._validate_input(images, videos, audios) - # image bound image_bounds_list = [] valid_image_nums_ls = [] @@ -727,6 +729,7 @@ class MiniCPMVPlugin(BasePlugin): return mm_inputs +@dataclass class MllamaPlugin(BasePlugin): @override def process_messages( @@ -757,7 +760,7 @@ class MllamaPlugin(BasePlugin): videos: Sequence["VideoInput"], 
audios: Sequence["AudioInput"], processor: "ProcessorMixin", - **kwargs, + imglens: List[int], ) -> Dict[str, "torch.Tensor"]: r""" Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]]. @@ -771,7 +774,6 @@ class MllamaPlugin(BasePlugin): num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1). """ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - imglens: List[int] = kwargs["imglens"] images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 768 * 768)) batch_images = [] for image_length in imglens: @@ -793,7 +795,7 @@ class MllamaPlugin(BasePlugin): processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: self._validate_input(images, videos, audios) - mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens=imglens) + mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens) num_tiles = mm_inputs.pop("num_tiles") image_token_id = getattr(processor, "image_token_id") max_image_tiles = getattr(processor.image_processor, "max_image_tiles") @@ -811,6 +813,7 @@ class MllamaPlugin(BasePlugin): return mm_inputs +@dataclass class PaliGemmaPlugin(BasePlugin): @override def process_messages( @@ -877,6 +880,7 @@ class PaliGemmaPlugin(BasePlugin): return mm_inputs +@dataclass class PixtralPlugin(BasePlugin): @override def process_messages( @@ -946,6 +950,7 @@ class PixtralPlugin(BasePlugin): return mm_inputs +@dataclass class Qwen2AudioPlugin(BasePlugin): @override def process_messages( @@ -967,9 +972,13 @@ class Qwen2AudioPlugin(BasePlugin): for message in messages: content = message["content"] while AUDIO_PLACEHOLDER in content: - audio_length = audio_lengths.pop(0) - input_length = (audio_length - 1) // 2 + 1 - audio_seqlen = (input_length - 2) // 2 + 1 if self.expand_mm_tokens else 1 + if self.expand_mm_tokens: + audio_length = audio_lengths.pop(0) + input_length = (audio_length - 1) // 2 + 1 + audio_seqlen = (input_length - 2) // 2 + 1 + else: + audio_seqlen = 1 + content = content.replace( AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1 ) @@ -998,6 +1007,7 @@ class Qwen2AudioPlugin(BasePlugin): return self._get_mm_inputs(images, videos, audios, processor) +@dataclass class Qwen2vlPlugin(BasePlugin): @override def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject": @@ -1017,7 +1027,9 @@ class Qwen2vlPlugin(BasePlugin): return image @override - def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> Tuple[List[List["ImageObject"]], List[float]]: + def _regularize_videos( + self, videos: Sequence["VideoInput"], **kwargs + ) -> Tuple[List[List["ImageObject"]], List[float]]: results, fps_per_video = [], [] for video in videos: container = av.open(video, "r") @@ -1053,8 +1065,7 @@ class Qwen2vlPlugin(BasePlugin): mm_inputs = {} if len(images) != 0: images = self._regularize_images( - images, - image_resolution=getattr(processor, "image_resolution", 768 * 768), + images, image_resolution=getattr(processor, "image_resolution", 768 * 768) ) mm_inputs.update(image_processor(images, return_tensors="pt")) @@ -1142,6 +1153,7 @@ class Qwen2vlPlugin(BasePlugin): return mm_inputs +@dataclass class VideoLlavaPlugin(BasePlugin): @override def process_messages(