[data] fix mllama (#7053)

* fix mllama

* fix test

Former-commit-id: 76314e6ad1ecaa44fcae4375dd0abf4ebaf1f924
hoshi-hiyouga authored 2025-02-24 22:05:38 +08:00, committed by GitHub
parent ca78ba964d
commit dca5fe14c2
2 changed files with 85 additions and 66 deletions
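
In short: `MMPluginMixin._validate_input` now receives the `processor` and fails fast when the processor, its image processor, or its audio feature extractor is missing; every `_validate_input` call site gains `processor` as its first argument; and `MllamaPlugin` no longer assumes each batch contains images. The two diffs below are the multimodal plugin module and its test. A minimal sketch of the pre-fix failure mode for text-only batches (illustrative, not project code):

```python
# Before the `if mm_inputs:` guard added below, a text-only batch produced
# an empty mm_inputs dict, and the unconditional pop crashed:
mm_inputs: dict = {}
mm_inputs.pop("num_tiles")  # KeyError: 'num_tiles'
```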


@@ -82,6 +82,7 @@ class MMPluginMixin:
     def _validate_input(
         self,
+        processor: Optional["ProcessorMixin"],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         audios: Sequence["AudioInput"],
@@ -89,6 +90,8 @@ class MMPluginMixin:
         r"""
         Validates if this model accepts the input modalities.
         """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
+        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
         if len(images) != 0 and self.image_token is None:
             raise ValueError(
                 "This model does not support image input. Please check whether the correct `template` is used."
@@ -104,6 +107,15 @@ class MMPluginMixin:
                 "This model does not support audio input. Please check whether the correct `template` is used."
             )
 
+        if self.image_token is not None and processor is None:
+            raise ValueError("Processor was not found, please check and update your processor config.")
+
+        if self.image_token is not None and image_processor is None:
+            raise ValueError("Image processor was not found, please check and update your processor config.")
+
+        if self.audio_token is not None and feature_extractor is None:
+            raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
+
     def _preprocess_image(
         self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
     ) -> "ImageObject":
@@ -275,7 +287,7 @@ class BasePlugin(MMPluginMixin):
         r"""
         Pre-processes input messages before tokenization for VLMs.
         """
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return messages
 
     def process_token_ids(
@@ -291,7 +303,7 @@ class BasePlugin(MMPluginMixin):
         r"""
         Pre-processes token ids after tokenization for VLMs.
         """
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return input_ids, labels
 
     def get_mm_inputs(
@@ -317,7 +329,7 @@ class BasePlugin(MMPluginMixin):
             batch_ids: token ids of input samples, shape (batch_size, seq_len)
             processor: a processor for pre-processing images and videos
         """
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return {}
@@ -332,7 +344,7 @@ class LlavaPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
         messages = deepcopy(messages)
@@ -361,7 +373,7 @@ class LlavaPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
@@ -376,7 +388,7 @@ class LlavaNextPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -417,7 +429,7 @@ class LlavaNextPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
@@ -432,7 +444,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -493,7 +505,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
@@ -508,7 +520,7 @@ class MiniCPMVPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         num_video_tokens = 0
         num_audio_tokens = 0
@@ -611,6 +623,7 @@ class MiniCPMVPlugin(BasePlugin):
         **kwargs,
     ) -> Dict[str, "torch.Tensor"]:
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
         mm_inputs = {}
         if len(images) != 0:
             images = self._regularize_images(
@@ -645,21 +658,19 @@ class MiniCPMVPlugin(BasePlugin):
             mm_inputs.update(video_inputs)
 
         if len(audios) != 0:
-            new_audios = []
-            for audio in audios:
-                if not isinstance(audio, np.ndarray):
-                    audio = librosa.load(audio, sr=processor.feature_extractor.sampling_rate)[0]
-                new_audios.append(audio)
+            audios = self._regularize_audios(
+                audios,
+                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
+            )
             if "valid_audio_nums_ls" in kwargs:
                 valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
                 audios_ls = []
                 idx = 0
                 for valid_audio_nums in valid_audio_nums_ls:
-                    audios_ls.append(new_audios[idx : idx + valid_audio_nums])
+                    audios_ls.append(audios[idx : idx + valid_audio_nums])
                     idx += valid_audio_nums
             else:
-                audios_ls = [new_audios]
+                audios_ls = [audios]
 
             audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
                 audios_ls,
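
The deleted librosa loop moves into the shared `_regularize_audios` helper, and the sampling rate now comes from `getattr(feature_extractor, "sampling_rate", 16000)` with a 16 kHz fallback instead of dereferencing `processor.feature_extractor` directly. The helper's exact body lives elsewhere in this module; a sketch of the assumed behavior, mirroring the removed lines:

```python
import librosa
import numpy as np

def _regularize_audios(audios, sampling_rate):
    # Assumed behavior, mirroring the loop deleted above: decode anything
    # that is not already a waveform array at the extractor's sampling rate.
    results = []
    for audio in audios:
        if not isinstance(audio, np.ndarray):
            audio = librosa.load(audio, sr=sampling_rate)[0]
        results.append(audio)
    return results
```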
@@ -685,7 +696,7 @@ class MiniCPMVPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         # image bound
         image_bounds_list = []
         valid_image_nums_ls = []
@@ -753,7 +764,7 @@ class MllamaPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
@@ -787,17 +798,21 @@ class MllamaPlugin(BasePlugin):
             num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
         """
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        images = self._regularize_images(
-            images,
-            image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
-            image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
-        )
-        batch_images = []
-        for image_length in imglens:
-            batch_images.append(images[:image_length])
-            images = images[image_length:]
-        return image_processor(batch_images, return_tensors="pt")
+        mm_inputs = {}
+        if len(images) > 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )
+            batch_images = []
+            for image_length in imglens:
+                batch_images.append(images[:image_length])
+                images = images[image_length:]
+            mm_inputs.update(image_processor(batch_images, return_tensors="pt"))
+
+        return mm_inputs
 
     @override
     def get_mm_inputs(
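
The guard means a batch with no images now returns an empty dict instead of invoking the image processor on an empty list. The `imglens` loop splits the flat image list back into per-sample sublists, which is the nested layout Mllama's image processor expects; for example:

```python
# How the imglens loop above regroups a flat image list (toy values):
images = ["img0", "img1", "img2"]
imglens = [2, 1]  # sample 0 has two images, sample 1 has one
batch_images = []
for image_length in imglens:
    batch_images.append(images[:image_length])
    images = images[image_length:]
assert batch_images == [["img0", "img1"], ["img2"]]
```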
@@ -811,22 +826,24 @@ class MllamaPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
-        num_tiles = mm_inputs.pop("num_tiles")
-        image_token_id = getattr(processor, "image_token_id")
-        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
-        cross_attention_token_mask = [
-            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
-        ]
-        mm_inputs["cross_attention_mask"] = torch.from_numpy(
-            convert_sparse_cross_attention_mask_to_dense(
-                cross_attention_token_mask,
-                num_tiles=num_tiles,
-                max_num_tiles=max_image_tiles,
-                length=max(len(input_ids) for input_ids in batch_ids),
-            )
-        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
+        if mm_inputs:
+            num_tiles = mm_inputs.pop("num_tiles")
+            image_token_id = getattr(processor, "image_token_id")
+            max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
+            cross_attention_token_mask = [
+                get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
+            ]
+            mm_inputs["cross_attention_mask"] = torch.from_numpy(
+                convert_sparse_cross_attention_mask_to_dense(
+                    cross_attention_token_mask,
+                    num_tiles=num_tiles,
+                    max_num_tiles=max_image_tiles,
+                    length=max(len(input_ids) for input_ids in batch_ids),
+                )
+            )  # shape: (batch_size, length, max_num_images, max_num_tiles)
+
         return mm_inputs
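
For reference, `get_cross_attention_token_mask` and `convert_sparse_cross_attention_mask_to_dense` come from `transformers.models.mllama.processing_mllama`. A small usage sketch with made-up token ids and tile counts:

```python
from transformers.models.mllama.processing_mllama import (
    convert_sparse_cross_attention_mask_to_dense,
    get_cross_attention_token_mask,
)

input_ids = [1, 99, 10, 11, 99, 12]  # one sample; 99 is a made-up <|image|> id
spans = get_cross_attention_token_mask(input_ids, image_token_id=99)
# spans lists, per image, the range of positions allowed to attend to it.
dense = convert_sparse_cross_attention_mask_to_dense(
    [spans],             # batch of one sample
    num_tiles=[[4, 4]],  # each image was split into four tiles
    max_num_tiles=4,
    length=len(input_ids),
)
print(dense.shape)  # (1, 6, 2, 4): (batch_size, length, max_num_images, max_num_tiles)
```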
@@ -841,7 +858,7 @@ class PaliGemmaPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
@@ -868,7 +885,7 @@ class PaliGemmaPlugin(BasePlugin):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
     ) -> Tuple[List[int], Optional[List[int]]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_images = len(images)
         image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
         image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
@@ -890,7 +907,7 @@ class PaliGemmaPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         seqlens = [len(input_ids) for input_ids in batch_ids]
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
@@ -908,7 +925,7 @@ class PixtralPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         patch_size = getattr(processor, "patch_size")
         image_token = getattr(processor, "image_token")
         image_break_token = getattr(processor, "image_break_token")
@@ -956,7 +973,7 @@ class PixtralPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         mm_inputs.pop("image_sizes", None)
         return mm_inputs
@@ -973,7 +990,7 @@ class Qwen2AudioPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         bos_token: str = getattr(processor, "audio_bos_token")
         eos_token: str = getattr(processor, "audio_eos_token")
         mm_inputs = self._get_mm_inputs([], [], audios, processor)
@@ -1015,7 +1032,7 @@ class Qwen2AudioPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
@@ -1105,7 +1122,7 @@ class Qwen2vlPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         merge_length: int = getattr(image_processor, "merge_size") ** 2
         if self.expand_mm_tokens:
@@ -1162,7 +1179,7 @@ class Qwen2vlPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         fps_per_video = mm_inputs.pop("fps_per_video", [])
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
@@ -1183,7 +1200,7 @@ class VideoLlavaPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -1241,7 +1258,7 @@ class VideoLlavaPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)


@@ -103,15 +103,17 @@ def _check_plugin(
     expected_no_mm_inputs: Dict[str, Any] = {},
 ) -> None:
     # test mm_messages
-    assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
-    assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
-        expected_input_ids,
-        expected_labels,
-    )
-    _is_close(
-        plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
-        expected_mm_inputs,
-    )
+    if plugin.__class__.__name__ != "BasePlugin":
+        assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
+        assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
+            expected_input_ids,
+            expected_labels,
+        )
+        _is_close(
+            plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
+            expected_mm_inputs,
+        )
+
     # test text_messages
     assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == TEXT_MESSAGES
     assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
@@ -128,7 +130,7 @@ def _check_plugin(
 def test_base_plugin():
     tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
-    base_plugin = get_mm_plugin(name="base", image_token="<image>")
+    base_plugin = get_mm_plugin(name="base")
     check_inputs = {"plugin": base_plugin, **tokenizer_module}
     _check_plugin(**check_inputs)
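
Since `BasePlugin` is now constructed without an `image_token`, the strengthened `_validate_input` rejects image input for it, which is why `_check_plugin` skips the mm-path assertions above. A sketch of the behavior that skip sidesteps (constant names as in this test module; the test function itself is hypothetical):

```python
import pytest

def test_base_plugin_rejects_images():
    # Without an image_token, validation fails before any processing happens.
    base_plugin = get_mm_plugin(name="base")
    with pytest.raises(ValueError, match="does not support image input"):
        base_plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor=None)
```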