Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Synced 2025-11-04 18:02:19 +08:00
[data] fix mllama (#7053)

* fix mllama
* fix test

Former-commit-id: f5af20a63f3d59a6a68d323a7c6f68e551edb3a3
This commit is contained in:
parent c1d5073bd3
commit 065f7fb5da
src/llamafactory/data/mm_plugin.py
@@ -82,6 +82,7 @@ class MMPluginMixin:
 
     def _validate_input(
         self,
+        processor: Optional["ProcessorMixin"],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         audios: Sequence["AudioInput"],
@@ -89,6 +90,8 @@ class MMPluginMixin:
         r"""
         Validates if this model accepts the input modalities.
         """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
+        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
         if len(images) != 0 and self.image_token is None:
             raise ValueError(
                 "This model does not support image input. Please check whether the correct `template` is used."
@@ -104,6 +107,15 @@ class MMPluginMixin:
                 "This model does not support audio input. Please check whether the correct `template` is used."
             )
 
+        if self.image_token is not None and processor is None:
+            raise ValueError("Processor was not found, please check and update your processor config.")
+
+        if self.image_token is not None and image_processor is None:
+            raise ValueError("Image processor was not found, please check and update your processor config.")
+
+        if self.audio_token is not None and feature_extractor is None:
+            raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
+
     def _preprocess_image(
         self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
     ) -> "ImageObject":
@@ -275,7 +287,7 @@ class BasePlugin(MMPluginMixin):
         r"""
         Pre-processes input messages before tokenization for VLMs.
         """
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return messages
 
     def process_token_ids(
@@ -291,7 +303,7 @@ class BasePlugin(MMPluginMixin):
         r"""
         Pre-processes token ids after tokenization for VLMs.
         """
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return input_ids, labels
 
     def get_mm_inputs(
@@ -317,7 +329,7 @@ class BasePlugin(MMPluginMixin):
             batch_ids: token ids of input samples, shape (batch_size, seq_len)
             processor: a processor for pre-processing images and videos
         """
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return {}
 
 
@@ -332,7 +344,7 @@ class LlavaPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
         messages = deepcopy(messages)
@@ -361,7 +373,7 @@ class LlavaPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
 
 
@@ -376,7 +388,7 @@ class LlavaNextPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -417,7 +429,7 @@ class LlavaNextPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
 
 
@@ -432,7 +444,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -493,7 +505,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
 
 
@@ -508,7 +520,7 @@ class MiniCPMVPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         num_video_tokens = 0
         num_audio_tokens = 0
@@ -611,6 +623,7 @@ class MiniCPMVPlugin(BasePlugin):
         **kwargs,
     ) -> Dict[str, "torch.Tensor"]:
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
         mm_inputs = {}
         if len(images) != 0:
             images = self._regularize_images(
@@ -645,21 +658,19 @@ class MiniCPMVPlugin(BasePlugin):
             mm_inputs.update(video_inputs)
 
         if len(audios) != 0:
-            new_audios = []
-            for audio in audios:
-                if not isinstance(audio, np.ndarray):
-                    audio = librosa.load(audio, sr=processor.feature_extractor.sampling_rate)[0]
-                new_audios.append(audio)
-
+            audios = self._regularize_audios(
+                audios,
+                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
+            )
             if "valid_audio_nums_ls" in kwargs:
                 valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
                 audios_ls = []
                 idx = 0
                 for valid_audio_nums in valid_audio_nums_ls:
-                    audios_ls.append(new_audios[idx : idx + valid_audio_nums])
+                    audios_ls.append(audios[idx : idx + valid_audio_nums])
                     idx += valid_audio_nums
             else:
-                audios_ls = [new_audios]
+                audios_ls = [audios]
 
             audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
                 audios_ls,
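For MiniCPM-V, the inline librosa decoding loop is replaced by the shared `_regularize_audios` helper, driven by the `feature_extractor` fetched in the previous hunk (sampling rate defaults to 16000). A rough sketch of what the helper is assumed to do, mirroring the deleted loop:

import librosa  # same dependency the deleted inline loop used
import numpy as np

def regularize_audios(audios, sampling_rate: int = 16000):
    # Decode path-like entries into waveforms at a single sampling rate;
    # ndarray inputs pass through unchanged, as in the removed code.
    results = []
    for audio in audios:
        if not isinstance(audio, np.ndarray):
            audio = librosa.load(audio, sr=sampling_rate)[0]
        results.append(audio)
    return results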
@@ -685,7 +696,7 @@ class MiniCPMVPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         # image bound
         image_bounds_list = []
         valid_image_nums_ls = []
@@ -753,7 +764,7 @@ class MllamaPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
@@ -787,17 +798,21 @@ class MllamaPlugin(BasePlugin):
             num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
         """
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        images = self._regularize_images(
-            images,
-            image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
-            image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
-        )
-        batch_images = []
-        for image_length in imglens:
-            batch_images.append(images[:image_length])
-            images = images[image_length:]
+        mm_inputs = {}
+        if len(images) > 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )
+            batch_images = []
+            for image_length in imglens:
+                batch_images.append(images[:image_length])
+                images = images[image_length:]
 
-        return image_processor(batch_images, return_tensors="pt")
+            mm_inputs.update(image_processor(batch_images, return_tensors="pt"))
+
+        return mm_inputs
 
     @override
     def get_mm_inputs(
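The regrouping inside the new `if len(images) > 0:` guard is unchanged: the flat `images` list is sliced back into per-sample lists according to `imglens`. A small worked example of that slicing:

# Two samples: the first contributed two images, the second one image.
images = ["img0", "img1", "img2"]
imglens = [2, 1]
batch_images = []
for image_length in imglens:
    batch_images.append(images[:image_length])
    images = images[image_length:]
assert batch_images == [["img0", "img1"], ["img2"]]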
@@ -811,22 +826,24 @@ class MllamaPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
-        num_tiles = mm_inputs.pop("num_tiles")
-        image_token_id = getattr(processor, "image_token_id")
-        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
-        cross_attention_token_mask = [
-            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
-        ]
-        mm_inputs["cross_attention_mask"] = torch.from_numpy(
-            convert_sparse_cross_attention_mask_to_dense(
-                cross_attention_token_mask,
-                num_tiles=num_tiles,
-                max_num_tiles=max_image_tiles,
-                length=max(len(input_ids) for input_ids in batch_ids),
-            )
-        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
+        if mm_inputs:
+            num_tiles = mm_inputs.pop("num_tiles")
+            image_token_id = getattr(processor, "image_token_id")
+            max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
+            cross_attention_token_mask = [
+                get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
+            ]
+            mm_inputs["cross_attention_mask"] = torch.from_numpy(
+                convert_sparse_cross_attention_mask_to_dense(
+                    cross_attention_token_mask,
+                    num_tiles=num_tiles,
+                    max_num_tiles=max_image_tiles,
+                    length=max(len(input_ids) for input_ids in batch_ids),
+                )
+            )  # shape: (batch_size, length, max_num_images, max_num_tiles)
+
         return mm_inputs
 
 
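This guard is the mllama fix the commit title refers to: on a text-only batch, `_get_mm_inputs` now returns an empty dict, so `get_mm_inputs` no longer crashes popping the missing "num_tiles" entry. A hedged usage sketch (argument order follows the test helper shown below; `processor` is assumed to be an already-loaded mllama processor):

plugin = get_mm_plugin(name="mllama", image_token="<|image|>")
# Text-only batch of one sample: every multimodal argument is empty.
mm_inputs = plugin.get_mm_inputs([], [], [], [0], [0], [0], [[1, 2, 3]], processor)
assert mm_inputs == {}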
@@ -841,7 +858,7 @@ class PaliGemmaPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
@@ -868,7 +885,7 @@ class PaliGemmaPlugin(BasePlugin):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
     ) -> Tuple[List[int], Optional[List[int]]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_images = len(images)
         image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
         image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
@@ -890,7 +907,7 @@ class PaliGemmaPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         seqlens = [len(input_ids) for input_ids in batch_ids]
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
@@ -908,7 +925,7 @@ class PixtralPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         patch_size = getattr(processor, "patch_size")
         image_token = getattr(processor, "image_token")
         image_break_token = getattr(processor, "image_break_token")
@@ -956,7 +973,7 @@ class PixtralPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         mm_inputs.pop("image_sizes", None)
         return mm_inputs
@@ -973,7 +990,7 @@ class Qwen2AudioPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         bos_token: str = getattr(processor, "audio_bos_token")
         eos_token: str = getattr(processor, "audio_eos_token")
         mm_inputs = self._get_mm_inputs([], [], audios, processor)
@@ -1015,7 +1032,7 @@ class Qwen2AudioPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
 
 
@@ -1105,7 +1122,7 @@ class Qwen2vlPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         merge_length: int = getattr(image_processor, "merge_size") ** 2
         if self.expand_mm_tokens:
@@ -1162,7 +1179,7 @@ class Qwen2vlPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         fps_per_video = mm_inputs.pop("fps_per_video", [])
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
@@ -1183,7 +1200,7 @@ class VideoLlavaPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -1241,7 +1258,7 @@ class VideoLlavaPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
 
 
tests/data/test_mm_plugin.py
@@ -103,15 +103,17 @@ def _check_plugin(
     expected_no_mm_inputs: Dict[str, Any] = {},
 ) -> None:
     # test mm_messages
-    assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
-    assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
-        expected_input_ids,
-        expected_labels,
-    )
-    _is_close(
-        plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
-        expected_mm_inputs,
-    )
+    if plugin.__class__.__name__ != "BasePlugin":
+        assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
+        assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
+            expected_input_ids,
+            expected_labels,
+        )
+        _is_close(
+            plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
+            expected_mm_inputs,
+        )
+
     # test text_messages
     assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == TEXT_MESSAGES
     assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
@@ -128,7 +130,7 @@ def _check_plugin(
 
 def test_base_plugin():
     tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
-    base_plugin = get_mm_plugin(name="base", image_token="<image>")
+    base_plugin = get_mm_plugin(name="base")
     check_inputs = {"plugin": base_plugin, **tokenizer_module}
     _check_plugin(**check_inputs)
 
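`_check_plugin` now skips its multimodal assertions for `BasePlugin`, because the base plugin is constructed without an `image_token` and the stricter `_validate_input` rejects image inputs for such a plugin. A hypothetical companion test (not part of this commit) that pins down the rejection:

import pytest  # assuming the suite's existing runner

def test_base_plugin_rejects_images():
    base_plugin = get_mm_plugin(name="base")
    with pytest.raises(ValueError):  # "does not support image input"
        base_plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, None)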