From dca5fe14c2426176b769bd0c3b4f1a2765ff3b55 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 24 Feb 2025 22:05:38 +0800
Subject: [PATCH] [data] fix mllama (#7053)

* fix mllama

* fix test

Former-commit-id: 76314e6ad1ecaa44fcae4375dd0abf4ebaf1f924
---
 src/llamafactory/data/mm_plugin.py | 129 ++++++++++++++++-------------
 tests/data/test_mm_plugin.py       |  22 ++---
 2 files changed, 85 insertions(+), 66 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 4947ff41..133984cd 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -82,6 +82,7 @@ class MMPluginMixin:

     def _validate_input(
         self,
+        processor: Optional["ProcessorMixin"],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         audios: Sequence["AudioInput"],
@@ -89,6 +90,8 @@ class MMPluginMixin:
         r"""
         Validates if this model accepts the input modalities.
         """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
+        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
         if len(images) != 0 and self.image_token is None:
             raise ValueError(
                 "This model does not support image input. Please check whether the correct `template` is used."
@@ -104,6 +107,15 @@ class MMPluginMixin:
                 "This model does not support audio input. Please check whether the correct `template` is used."
             )

+        if self.image_token is not None and processor is None:
+            raise ValueError("Processor was not found, please check and update your processor config.")
+
+        if self.image_token is not None and image_processor is None:
+            raise ValueError("Image processor was not found, please check and update your processor config.")
+
+        if self.audio_token is not None and feature_extractor is None:
+            raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
+
     def _preprocess_image(
         self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
     ) -> "ImageObject":
@@ -275,7 +287,7 @@ class BasePlugin(MMPluginMixin):
         r"""
         Pre-processes input messages before tokenization for VLMs.
         """
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return messages

     def process_token_ids(
@@ -291,7 +303,7 @@ class BasePlugin(MMPluginMixin):
         r"""
         Pre-processes token ids after tokenization for VLMs.
""" - self._validate_input(images, videos, audios) + self._validate_input(processor, images, videos, audios) return input_ids, labels def get_mm_inputs( @@ -317,7 +329,7 @@ class BasePlugin(MMPluginMixin): batch_ids: token ids of input samples, shape (batch_size, seq_len) processor: a processor for pre-processing images and videos """ - self._validate_input(images, videos, audios) + self._validate_input(processor, images, videos, audios) return {} @@ -332,7 +344,7 @@ class LlavaPlugin(BasePlugin): audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: - self._validate_input(images, videos, audios) + self._validate_input(processor, images, videos, audios) num_image_tokens = 0 image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1 messages = deepcopy(messages) @@ -361,7 +373,7 @@ class LlavaPlugin(BasePlugin): batch_ids: Sequence[List[int]], processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: - self._validate_input(images, videos, audios) + self._validate_input(processor, images, videos, audios) return self._get_mm_inputs(images, videos, audios, processor) @@ -376,7 +388,7 @@ class LlavaNextPlugin(BasePlugin): audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: - self._validate_input(images, videos, audios) + self._validate_input(processor, images, videos, audios) num_image_tokens = 0 messages = deepcopy(messages) mm_inputs = self._get_mm_inputs(images, videos, audios, processor) @@ -417,7 +429,7 @@ class LlavaNextPlugin(BasePlugin): batch_ids: Sequence[List[int]], processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: - self._validate_input(images, videos, audios) + self._validate_input(processor, images, videos, audios) return self._get_mm_inputs(images, videos, audios, processor) @@ -432,7 +444,7 @@ class LlavaNextVideoPlugin(BasePlugin): audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: - self._validate_input(images, videos, audios) + self._validate_input(processor, images, videos, audios) num_image_tokens, num_video_tokens = 0, 0 messages = deepcopy(messages) mm_inputs = self._get_mm_inputs(images, videos, audios, processor) @@ -493,7 +505,7 @@ class LlavaNextVideoPlugin(BasePlugin): batch_ids: Sequence[List[int]], processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: - self._validate_input(images, videos, audios) + self._validate_input(processor, images, videos, audios) return self._get_mm_inputs(images, videos, audios, processor) @@ -508,7 +520,7 @@ class MiniCPMVPlugin(BasePlugin): audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: - self._validate_input(images, videos, audios) + self._validate_input(processor, images, videos, audios) num_image_tokens = 0 num_video_tokens = 0 num_audio_tokens = 0 @@ -611,6 +623,7 @@ class MiniCPMVPlugin(BasePlugin): **kwargs, ) -> Dict[str, "torch.Tensor"]: image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None) mm_inputs = {} if len(images) != 0: images = self._regularize_images( @@ -645,21 +658,19 @@ class MiniCPMVPlugin(BasePlugin): mm_inputs.update(video_inputs) if len(audios) != 0: - new_audios = [] - for audio in audios: - if not isinstance(audio, np.ndarray): - audio = librosa.load(audio, 
sr=processor.feature_extractor.sampling_rate)[0] - new_audios.append(audio) - + audios = self._regularize_audios( + audios, + sampling_rate=getattr(feature_extractor, "sampling_rate", 16000), + ) if "valid_audio_nums_ls" in kwargs: valid_audio_nums_ls = kwargs["valid_audio_nums_ls"] audios_ls = [] idx = 0 for valid_audio_nums in valid_audio_nums_ls: - audios_ls.append(new_audios[idx : idx + valid_audio_nums]) + audios_ls.append(audios[idx : idx + valid_audio_nums]) idx += valid_audio_nums else: - audios_ls = [new_audios] + audios_ls = [audios] audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract( audios_ls, @@ -685,7 +696,7 @@ class MiniCPMVPlugin(BasePlugin): batch_ids: Sequence[List[int]], processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: - self._validate_input(images, videos, audios) + self._validate_input(processor, images, videos, audios) # image bound image_bounds_list = [] valid_image_nums_ls = [] @@ -753,7 +764,7 @@ class MllamaPlugin(BasePlugin): audios: Sequence["AudioInput"], processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: - self._validate_input(images, videos, audios) + self._validate_input(processor, images, videos, audios) num_image_tokens = 0 messages = deepcopy(messages) for message in messages: @@ -787,17 +798,21 @@ class MllamaPlugin(BasePlugin): num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1). """ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") - images = self._regularize_images( - images, - image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), - image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), - ) - batch_images = [] - for image_length in imglens: - batch_images.append(images[:image_length]) - images = images[image_length:] + mm_inputs = {} + if len(images) > 0: + images = self._regularize_images( + images, + image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768), + image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32), + ) + batch_images = [] + for image_length in imglens: + batch_images.append(images[:image_length]) + images = images[image_length:] - return image_processor(batch_images, return_tensors="pt") + mm_inputs.update(image_processor(batch_images, return_tensors="pt")) + + return mm_inputs @override def get_mm_inputs( @@ -811,22 +826,24 @@ class MllamaPlugin(BasePlugin): batch_ids: Sequence[List[int]], processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: - self._validate_input(images, videos, audios) + self._validate_input(processor, images, videos, audios) mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens) - num_tiles = mm_inputs.pop("num_tiles") - image_token_id = getattr(processor, "image_token_id") - max_image_tiles = getattr(processor.image_processor, "max_image_tiles") - cross_attention_token_mask = [ - get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids - ] - mm_inputs["cross_attention_mask"] = torch.from_numpy( - convert_sparse_cross_attention_mask_to_dense( - cross_attention_token_mask, - num_tiles=num_tiles, - max_num_tiles=max_image_tiles, - length=max(len(input_ids) for input_ids in batch_ids), - ) - ) # shape: (batch_size, length, max_num_images, max_num_tiles) + if mm_inputs: + num_tiles = mm_inputs.pop("num_tiles") + image_token_id = getattr(processor, "image_token_id") + max_image_tiles = getattr(processor.image_processor, "max_image_tiles") + 
+            cross_attention_token_mask = [
+                get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
+            ]
+            mm_inputs["cross_attention_mask"] = torch.from_numpy(
+                convert_sparse_cross_attention_mask_to_dense(
+                    cross_attention_token_mask,
+                    num_tiles=num_tiles,
+                    max_num_tiles=max_image_tiles,
+                    length=max(len(input_ids) for input_ids in batch_ids),
+                )
+            )  # shape: (batch_size, length, max_num_images, max_num_tiles)
+
         return mm_inputs


@@ -841,7 +858,7 @@ class PaliGemmaPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
@@ -868,7 +885,7 @@ class PaliGemmaPlugin(BasePlugin):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
     ) -> Tuple[List[int], Optional[List[int]]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_images = len(images)
         image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
         image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
@@ -890,7 +907,7 @@ class PaliGemmaPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         seqlens = [len(input_ids) for input_ids in batch_ids]
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
@@ -908,7 +925,7 @@ class PixtralPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         patch_size = getattr(processor, "patch_size")
         image_token = getattr(processor, "image_token")
         image_break_token = getattr(processor, "image_break_token")
@@ -956,7 +973,7 @@ class PixtralPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         mm_inputs.pop("image_sizes", None)
         return mm_inputs
@@ -973,7 +990,7 @@ class Qwen2AudioPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         bos_token: str = getattr(processor, "audio_bos_token")
         eos_token: str = getattr(processor, "audio_eos_token")
         mm_inputs = self._get_mm_inputs([], [], audios, processor)
@@ -1015,7 +1032,7 @@ class Qwen2AudioPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)


@@ -1105,7 +1122,7 @@ class Qwen2vlPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         merge_length: int = getattr(image_processor, "merge_size") ** 2
         if self.expand_mm_tokens:
@@ -1162,7 +1179,7 @@ class Qwen2vlPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         fps_per_video = mm_inputs.pop("fps_per_video", [])
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
@@ -1183,7 +1200,7 @@ class VideoLlavaPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -1241,7 +1258,7 @@ class VideoLlavaPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)


diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py
index 5e0ec660..0d26acd1 100644
--- a/tests/data/test_mm_plugin.py
+++ b/tests/data/test_mm_plugin.py
@@ -103,15 +103,17 @@ def _check_plugin(
     expected_no_mm_inputs: Dict[str, Any] = {},
 ) -> None:
     # test mm_messages
-    assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
-    assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
-        expected_input_ids,
-        expected_labels,
-    )
-    _is_close(
-        plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
-        expected_mm_inputs,
-    )
+    if plugin.__class__.__name__ != "BasePlugin":
+        assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
+        assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
+            expected_input_ids,
+            expected_labels,
+        )
+        _is_close(
+            plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
+            expected_mm_inputs,
+        )
+
     # test text_messages
     assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == TEXT_MESSAGES
     assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
@@ -128,7 +130,7 @@ def _check_plugin(

 def test_base_plugin():
     tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
-    base_plugin = get_mm_plugin(name="base", image_token="<image>")
+    base_plugin = get_mm_plugin(name="base")
     check_inputs = {"plugin": base_plugin, **tokenizer_module}
     _check_plugin(**check_inputs)
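
Editor's note (illustration, not part of the patch): the core change threads the processor into MMPluginMixin._validate_input, so a missing processor, image processor, or audio feature extractor now fails fast with a clear ValueError instead of surfacing later as an opaque AttributeError inside a plugin's _get_mm_inputs. Relatedly, MllamaPlugin now tolerates text-only batches: _get_mm_inputs returns an empty dict when no images are given, and get_mm_inputs only builds the cross-attention mask when there is something to build it from, which is what the relaxed _check_plugin test exercises. A minimal sketch of the new fail-fast contract, assuming llamafactory with this patch applied is installed; the image token, file name, and message content below are illustrative placeholders:

    # Sketch only: demonstrates the fail-fast validation added by this patch.
    # Assumes llamafactory (with this patch) is installed; the token, file name,
    # and message below are illustrative, not taken from the patch.
    from llamafactory.data.mm_plugin import get_mm_plugin

    plugin = get_mm_plugin(name="mllama", image_token="<|image|>")
    messages = [{"role": "user", "content": "<image>Describe this picture."}]

    try:
        # Passing image inputs with processor=None now raises immediately,
        # because _validate_input receives the processor and checks it first.
        plugin.process_messages(messages, ["example.jpg"], [], [], None)
    except ValueError as err:
        print(err)  # "Processor was not found, please check and update your processor config."

The exception is raised before any image file is opened, since _validate_input runs at the top of process_messages, process_token_ids, and get_mm_inputs in every plugin touched by this patch.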