mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 11:42:49 +08:00
[data] refactor mm plugin (#6895)
* refactor plugin * lint Former-commit-id: aca63bfcca02ecd95b57cd8949a50e26a913f716
This commit is contained in:
parent
188f22d8a7
commit
593acca556
@ -2,6 +2,7 @@ import inspect
|
|||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from dataclasses import dataclass
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||||
|
|
||||||
@ -72,12 +73,12 @@ def _get_paligemma_token_type_ids(
|
|||||||
return batch_token_type_ids
|
return batch_token_type_ids
|
||||||
|
|
||||||
|
|
||||||
class BasePlugin:
|
@dataclass
|
||||||
def __init__(self, image_token: Optional[str], video_token: Optional[str], audio_token: Optional[str]) -> None:
|
class MMPluginMixin:
|
||||||
self.image_token = image_token
|
image_token: Optional[str]
|
||||||
self.video_token = video_token
|
video_token: Optional[str]
|
||||||
self.audio_token = audio_token
|
audio_token: Optional[str]
|
||||||
self.expand_mm_tokens = True
|
expand_mm_tokens: bool = True
|
||||||
|
|
||||||
def _validate_input(
|
def _validate_input(
|
||||||
self,
|
self,
|
||||||
@ -103,11 +104,10 @@ class BasePlugin:
|
|||||||
"This model does not support audio input. Please check whether the correct `template` is used."
|
"This model does not support audio input. Please check whether the correct `template` is used."
|
||||||
)
|
)
|
||||||
|
|
||||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
def _preprocess_image(self, image: "ImageObject", image_resolution: int, **kwargs) -> "ImageObject":
|
||||||
r"""
|
r"""
|
||||||
Pre-processes a single image.
|
Pre-processes a single image.
|
||||||
"""
|
"""
|
||||||
image_resolution: int = kwargs["image_resolution"]
|
|
||||||
if (image.width * image.height) > image_resolution:
|
if (image.width * image.height) > image_resolution:
|
||||||
resize_factor = math.sqrt(image_resolution / (image.width * image.height))
|
resize_factor = math.sqrt(image_resolution / (image.width * image.height))
|
||||||
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
|
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
|
||||||
@ -118,12 +118,12 @@ class BasePlugin:
|
|||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def _get_video_sample_indices(self, video_stream: "Stream", **kwargs) -> List[int]:
|
def _get_video_sample_indices(
|
||||||
|
self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
|
||||||
|
) -> List[int]:
|
||||||
r"""
|
r"""
|
||||||
Computes video sample indices according to fps.
|
Computes video sample indices according to fps.
|
||||||
"""
|
"""
|
||||||
video_fps: float = kwargs["video_fps"]
|
|
||||||
video_maxlen: int = kwargs["video_maxlen"]
|
|
||||||
total_frames = video_stream.frames
|
total_frames = video_stream.frames
|
||||||
if total_frames == 0: # infinite video
|
if total_frames == 0: # infinite video
|
||||||
return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
|
return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
|
||||||
@ -175,12 +175,11 @@ class BasePlugin:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _regularize_audios(self, audios: Sequence["AudioInput"], **kwargs) -> List["NDArray"]:
|
def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]:
|
||||||
r"""
|
r"""
|
||||||
Regularizes audios to avoid error. Including reading and resampling.
|
Regularizes audios to avoid error. Including reading and resampling.
|
||||||
"""
|
"""
|
||||||
results = []
|
results = []
|
||||||
sampling_rate = kwargs["sampling_rate"]
|
|
||||||
for audio in audios:
|
for audio in audios:
|
||||||
if isinstance(audio, str):
|
if isinstance(audio, str):
|
||||||
audio = librosa.load(audio, sr=sampling_rate)[0]
|
audio = librosa.load(audio, sr=sampling_rate)[0]
|
||||||
@ -218,8 +217,7 @@ class BasePlugin:
|
|||||||
|
|
||||||
if len(images) != 0:
|
if len(images) != 0:
|
||||||
images = self._regularize_images(
|
images = self._regularize_images(
|
||||||
images,
|
images, image_resolution=getattr(processor, "image_resolution", 768 * 768)
|
||||||
image_resolution=getattr(processor, "image_resolution", 768 * 768),
|
|
||||||
)
|
)
|
||||||
mm_inputs.update(image_processor(images, return_tensors="pt"))
|
mm_inputs.update(image_processor(images, return_tensors="pt"))
|
||||||
|
|
||||||
@ -253,6 +251,9 @@ class BasePlugin:
|
|||||||
|
|
||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BasePlugin(MMPluginMixin):
|
||||||
def process_messages(
|
def process_messages(
|
||||||
self,
|
self,
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
@ -310,6 +311,7 @@ class BasePlugin:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class LlavaPlugin(BasePlugin):
|
class LlavaPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -353,6 +355,7 @@ class LlavaPlugin(BasePlugin):
|
|||||||
return self._get_mm_inputs(images, videos, audios, processor)
|
return self._get_mm_inputs(images, videos, audios, processor)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class LlavaNextPlugin(BasePlugin):
|
class LlavaNextPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -410,6 +413,7 @@ class LlavaNextPlugin(BasePlugin):
|
|||||||
return self._get_mm_inputs(images, videos, audios, processor)
|
return self._get_mm_inputs(images, videos, audios, processor)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class LlavaNextVideoPlugin(BasePlugin):
|
class LlavaNextVideoPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -444,12 +448,15 @@ class LlavaNextVideoPlugin(BasePlugin):
|
|||||||
message["content"] = content.replace("{{image}}", self.image_token)
|
message["content"] = content.replace("{{image}}", self.image_token)
|
||||||
|
|
||||||
if "pixel_values_videos" in mm_inputs:
|
if "pixel_values_videos" in mm_inputs:
|
||||||
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
|
if self.expand_mm_tokens:
|
||||||
height, width = get_image_size(pixel_values_video[0])
|
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
|
||||||
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
|
height, width = get_image_size(pixel_values_video[0])
|
||||||
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
|
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
|
||||||
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
|
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
|
||||||
video_seqlen = video_seqlen if self.expand_mm_tokens else 1
|
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
|
||||||
|
else:
|
||||||
|
video_seqlen = 1
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
@ -482,6 +489,7 @@ class LlavaNextVideoPlugin(BasePlugin):
|
|||||||
return self._get_mm_inputs(images, videos, audios, processor)
|
return self._get_mm_inputs(images, videos, audios, processor)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class MiniCPMVPlugin(BasePlugin):
|
class MiniCPMVPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -645,12 +653,7 @@ class MiniCPMVPlugin(BasePlugin):
|
|||||||
chunk_input=True,
|
chunk_input=True,
|
||||||
sampling_rate=16000,
|
sampling_rate=16000,
|
||||||
)
|
)
|
||||||
audio_feature_lens = [
|
audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
|
||||||
torch.tensor(audio_feature_len)
|
|
||||||
if not isinstance(audio_feature_len, torch.Tensor)
|
|
||||||
else audio_feature_len
|
|
||||||
for audio_feature_len in audio_feature_lens
|
|
||||||
]
|
|
||||||
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
|
mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
|
||||||
if kwargs.get("ret_phs", False):
|
if kwargs.get("ret_phs", False):
|
||||||
mm_inputs.update({"audio_phs": audio_phs})
|
mm_inputs.update({"audio_phs": audio_phs})
|
||||||
@ -670,7 +673,6 @@ class MiniCPMVPlugin(BasePlugin):
|
|||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
self._validate_input(images, videos, audios)
|
self._validate_input(images, videos, audios)
|
||||||
|
|
||||||
# image bound
|
# image bound
|
||||||
image_bounds_list = []
|
image_bounds_list = []
|
||||||
valid_image_nums_ls = []
|
valid_image_nums_ls = []
|
||||||
@ -727,6 +729,7 @@ class MiniCPMVPlugin(BasePlugin):
|
|||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class MllamaPlugin(BasePlugin):
|
class MllamaPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -757,7 +760,7 @@ class MllamaPlugin(BasePlugin):
|
|||||||
videos: Sequence["VideoInput"],
|
videos: Sequence["VideoInput"],
|
||||||
audios: Sequence["AudioInput"],
|
audios: Sequence["AudioInput"],
|
||||||
processor: "ProcessorMixin",
|
processor: "ProcessorMixin",
|
||||||
**kwargs,
|
imglens: List[int],
|
||||||
) -> Dict[str, "torch.Tensor"]:
|
) -> Dict[str, "torch.Tensor"]:
|
||||||
r"""
|
r"""
|
||||||
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
|
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
|
||||||
@ -771,7 +774,6 @@ class MllamaPlugin(BasePlugin):
|
|||||||
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
|
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
|
||||||
"""
|
"""
|
||||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||||
imglens: List[int] = kwargs["imglens"]
|
|
||||||
images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 768 * 768))
|
images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 768 * 768))
|
||||||
batch_images = []
|
batch_images = []
|
||||||
for image_length in imglens:
|
for image_length in imglens:
|
||||||
@ -793,7 +795,7 @@ class MllamaPlugin(BasePlugin):
|
|||||||
processor: Optional["ProcessorMixin"],
|
processor: Optional["ProcessorMixin"],
|
||||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
self._validate_input(images, videos, audios)
|
self._validate_input(images, videos, audios)
|
||||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens=imglens)
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
|
||||||
num_tiles = mm_inputs.pop("num_tiles")
|
num_tiles = mm_inputs.pop("num_tiles")
|
||||||
image_token_id = getattr(processor, "image_token_id")
|
image_token_id = getattr(processor, "image_token_id")
|
||||||
max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
|
max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
|
||||||
@ -811,6 +813,7 @@ class MllamaPlugin(BasePlugin):
|
|||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class PaliGemmaPlugin(BasePlugin):
|
class PaliGemmaPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -877,6 +880,7 @@ class PaliGemmaPlugin(BasePlugin):
|
|||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class PixtralPlugin(BasePlugin):
|
class PixtralPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -946,6 +950,7 @@ class PixtralPlugin(BasePlugin):
|
|||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class Qwen2AudioPlugin(BasePlugin):
|
class Qwen2AudioPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -967,9 +972,13 @@ class Qwen2AudioPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while AUDIO_PLACEHOLDER in content:
|
while AUDIO_PLACEHOLDER in content:
|
||||||
audio_length = audio_lengths.pop(0)
|
if self.expand_mm_tokens:
|
||||||
input_length = (audio_length - 1) // 2 + 1
|
audio_length = audio_lengths.pop(0)
|
||||||
audio_seqlen = (input_length - 2) // 2 + 1 if self.expand_mm_tokens else 1
|
input_length = (audio_length - 1) // 2 + 1
|
||||||
|
audio_seqlen = (input_length - 2) // 2 + 1
|
||||||
|
else:
|
||||||
|
audio_seqlen = 1
|
||||||
|
|
||||||
content = content.replace(
|
content = content.replace(
|
||||||
AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
|
AUDIO_PLACEHOLDER, f"{bos_token}{self.audio_token * audio_seqlen}{eos_token}", 1
|
||||||
)
|
)
|
||||||
@ -998,6 +1007,7 @@ class Qwen2AudioPlugin(BasePlugin):
|
|||||||
return self._get_mm_inputs(images, videos, audios, processor)
|
return self._get_mm_inputs(images, videos, audios, processor)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class Qwen2vlPlugin(BasePlugin):
|
class Qwen2vlPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||||
@ -1017,7 +1027,9 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> Tuple[List[List["ImageObject"]], List[float]]:
|
def _regularize_videos(
|
||||||
|
self, videos: Sequence["VideoInput"], **kwargs
|
||||||
|
) -> Tuple[List[List["ImageObject"]], List[float]]:
|
||||||
results, fps_per_video = [], []
|
results, fps_per_video = [], []
|
||||||
for video in videos:
|
for video in videos:
|
||||||
container = av.open(video, "r")
|
container = av.open(video, "r")
|
||||||
@ -1053,8 +1065,7 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
mm_inputs = {}
|
mm_inputs = {}
|
||||||
if len(images) != 0:
|
if len(images) != 0:
|
||||||
images = self._regularize_images(
|
images = self._regularize_images(
|
||||||
images,
|
images, image_resolution=getattr(processor, "image_resolution", 768 * 768)
|
||||||
image_resolution=getattr(processor, "image_resolution", 768 * 768),
|
|
||||||
)
|
)
|
||||||
mm_inputs.update(image_processor(images, return_tensors="pt"))
|
mm_inputs.update(image_processor(images, return_tensors="pt"))
|
||||||
|
|
||||||
@ -1142,6 +1153,7 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class VideoLlavaPlugin(BasePlugin):
|
class VideoLlavaPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def process_messages(
|
def process_messages(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user