[data] fix mllama (#7053)

* fix mllama

* fix test

Former-commit-id: 76314e6ad1ecaa44fcae4375dd0abf4ebaf1f924
hoshi-hiyouga 2025-02-24 22:05:38 +08:00 committed by GitHub
parent ca78ba964d
commit dca5fe14c2
2 changed files with 85 additions and 66 deletions

src/llamafactory/data/mm_plugin.py

@@ -82,6 +82,7 @@ class MMPluginMixin:
     def _validate_input(
         self,
+        processor: Optional["ProcessorMixin"],
         images: Sequence["ImageInput"],
         videos: Sequence["VideoInput"],
         audios: Sequence["AudioInput"],
@@ -89,6 +90,8 @@ class MMPluginMixin:
         r"""
         Validates if this model accepts the input modalities.
         """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
+        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
         if len(images) != 0 and self.image_token is None:
             raise ValueError(
                 "This model does not support image input. Please check whether the correct `template` is used."
@@ -104,6 +107,15 @@ class MMPluginMixin:
                 "This model does not support audio input. Please check whether the correct `template` is used."
             )

+        if self.image_token is not None and processor is None:
+            raise ValueError("Processor was not found, please check and update your processor config.")
+
+        if self.image_token is not None and image_processor is None:
+            raise ValueError("Image processor was not found, please check and update your processor config.")
+
+        if self.audio_token is not None and feature_extractor is None:
+            raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
+
     def _preprocess_image(
         self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
     ) -> "ImageObject":
@@ -275,7 +287,7 @@ class BasePlugin(MMPluginMixin):
         r"""
         Pre-processes input messages before tokenization for VLMs.
         """
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return messages

     def process_token_ids(
@@ -291,7 +303,7 @@ class BasePlugin(MMPluginMixin):
         r"""
         Pre-processes token ids after tokenization for VLMs.
         """
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return input_ids, labels

     def get_mm_inputs(
@@ -317,7 +329,7 @@ class BasePlugin(MMPluginMixin):
             batch_ids: token ids of input samples, shape (batch_size, seq_len)
             processor: a processor for pre-processing images and videos
         """
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return {}
@@ -332,7 +344,7 @@ class LlavaPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
         messages = deepcopy(messages)
@@ -361,7 +373,7 @@ class LlavaPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
@@ -376,7 +388,7 @@ class LlavaNextPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -417,7 +429,7 @@ class LlavaNextPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
@@ -432,7 +444,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -493,7 +505,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
@@ -508,7 +520,7 @@ class MiniCPMVPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         num_video_tokens = 0
         num_audio_tokens = 0
@@ -611,6 +623,7 @@ class MiniCPMVPlugin(BasePlugin):
         **kwargs,
     ) -> Dict[str, "torch.Tensor"]:
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
         mm_inputs = {}
         if len(images) != 0:
             images = self._regularize_images(
@@ -645,21 +658,19 @@ class MiniCPMVPlugin(BasePlugin):
             mm_inputs.update(video_inputs)

         if len(audios) != 0:
-            new_audios = []
-            for audio in audios:
-                if not isinstance(audio, np.ndarray):
-                    audio = librosa.load(audio, sr=processor.feature_extractor.sampling_rate)[0]
-
-                new_audios.append(audio)
-
+            audios = self._regularize_audios(
+                audios,
+                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
+            )
             if "valid_audio_nums_ls" in kwargs:
                 valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
                 audios_ls = []
                 idx = 0
                 for valid_audio_nums in valid_audio_nums_ls:
-                    audios_ls.append(new_audios[idx : idx + valid_audio_nums])
+                    audios_ls.append(audios[idx : idx + valid_audio_nums])
                     idx += valid_audio_nums
             else:
-                audios_ls = [new_audios]
+                audios_ls = [audios]

             audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
                 audios_ls,
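For reference, the inline loop removed above did roughly the following, which `_regularize_audios` is now expected to cover (the helper name here is illustrative, not the repo's API): inputs that are not already waveform arrays, such as file paths, are decoded with librosa at the feature extractor's sampling rate, falling back to 16 kHz.

import librosa
import numpy as np

def regularize_audios(audios, sampling_rate=16000):
    waveforms = []
    for audio in audios:
        if not isinstance(audio, np.ndarray):  # e.g. a file path
            audio = librosa.load(audio, sr=sampling_rate)[0]  # decode and resample
        waveforms.append(audio)
    return waveforms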
@@ -685,7 +696,7 @@ class MiniCPMVPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         # image bound
         image_bounds_list = []
         valid_image_nums_ls = []
@@ -753,7 +764,7 @@ class MllamaPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
@@ -787,17 +798,21 @@ class MllamaPlugin(BasePlugin):
             num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
         """
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        images = self._regularize_images(
-            images,
-            image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
-            image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
-        )
-        batch_images = []
-        for image_length in imglens:
-            batch_images.append(images[:image_length])
-            images = images[image_length:]
+        mm_inputs = {}
+        if len(images) > 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )
+            batch_images = []
+            for image_length in imglens:
+                batch_images.append(images[:image_length])
+                images = images[image_length:]

-        return image_processor(batch_images, return_tensors="pt")
+            mm_inputs.update(image_processor(batch_images, return_tensors="pt"))
+
+        return mm_inputs

     @override
     def get_mm_inputs(
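The per-sample batching in the loop above, shown in isolation: a flat list of images is split into one sub-list per example according to `imglens` (this helper is illustrative; the plugin inlines the same logic).

def split_by_lengths(items, lengths):
    batches, idx = [], 0
    for length in lengths:
        batches.append(items[idx : idx + length])
        idx += length
    return batches

# two images for the first sample, one for the second
assert split_by_lengths(["img0", "img1", "img2"], [2, 1]) == [["img0", "img1"], ["img2"]]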
@@ -811,22 +826,24 @@ class MllamaPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
-        num_tiles = mm_inputs.pop("num_tiles")
-        image_token_id = getattr(processor, "image_token_id")
-        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
-        cross_attention_token_mask = [
-            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
-        ]
-        mm_inputs["cross_attention_mask"] = torch.from_numpy(
-            convert_sparse_cross_attention_mask_to_dense(
-                cross_attention_token_mask,
-                num_tiles=num_tiles,
-                max_num_tiles=max_image_tiles,
-                length=max(len(input_ids) for input_ids in batch_ids),
-            )
-        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
+        if mm_inputs:
+            num_tiles = mm_inputs.pop("num_tiles")
+            image_token_id = getattr(processor, "image_token_id")
+            max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
+            cross_attention_token_mask = [
+                get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
+            ]
+            mm_inputs["cross_attention_mask"] = torch.from_numpy(
+                convert_sparse_cross_attention_mask_to_dense(
+                    cross_attention_token_mask,
+                    num_tiles=num_tiles,
+                    max_num_tiles=max_image_tiles,
+                    length=max(len(input_ids) for input_ids in batch_ids),
+                )
+            )  # shape: (batch_size, length, max_num_images, max_num_tiles)

         return mm_inputs
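A hedged sketch of the mask construction now guarded by `if mm_inputs:`, using the Mllama processing utilities from transformers; the token id 128256 for <|image|> is an assumption for Llama-3.2-Vision tokenizers, not something this commit specifies.

import torch
from transformers.models.mllama.processing_mllama import (
    convert_sparse_cross_attention_mask_to_dense,
    get_cross_attention_token_mask,
)

batch_ids = [[128256, 10, 11, 12]]  # one sample starting with the image token (assumed id)
token_mask = [get_cross_attention_token_mask(ids, 128256) for ids in batch_ids]
dense = convert_sparse_cross_attention_mask_to_dense(
    token_mask, num_tiles=[[4]], max_num_tiles=4, length=4
)
cross_attention_mask = torch.from_numpy(dense)
print(cross_attention_mask.shape)  # (batch_size, length, max_num_images, max_num_tiles)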
@@ -841,7 +858,7 @@ class PaliGemmaPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
@@ -868,7 +885,7 @@ class PaliGemmaPlugin(BasePlugin):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
     ) -> Tuple[List[int], Optional[List[int]]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_images = len(images)
         image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
         image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
@@ -890,7 +907,7 @@ class PaliGemmaPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         seqlens = [len(input_ids) for input_ids in batch_ids]
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
@@ -908,7 +925,7 @@ class PixtralPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         patch_size = getattr(processor, "patch_size")
         image_token = getattr(processor, "image_token")
         image_break_token = getattr(processor, "image_break_token")
@@ -956,7 +973,7 @@ class PixtralPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         mm_inputs.pop("image_sizes", None)
         return mm_inputs
@@ -973,7 +990,7 @@ class Qwen2AudioPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         bos_token: str = getattr(processor, "audio_bos_token")
         eos_token: str = getattr(processor, "audio_eos_token")
         mm_inputs = self._get_mm_inputs([], [], audios, processor)
@@ -1015,7 +1032,7 @@ class Qwen2AudioPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
@@ -1105,7 +1122,7 @@ class Qwen2vlPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         merge_length: int = getattr(image_processor, "merge_size") ** 2
         if self.expand_mm_tokens:
@@ -1162,7 +1179,7 @@ class Qwen2vlPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         fps_per_video = mm_inputs.pop("fps_per_video", [])
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
@@ -1183,7 +1200,7 @@ class VideoLlavaPlugin(BasePlugin):
         audios: Sequence["AudioInput"],
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
@@ -1241,7 +1258,7 @@ class VideoLlavaPlugin(BasePlugin):
         batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos, audios)
+        self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)

tests/data/test_mm_plugin.py

@@ -103,15 +103,17 @@ def _check_plugin(
     expected_no_mm_inputs: Dict[str, Any] = {},
 ) -> None:
     # test mm_messages
-    assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
-    assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
-        expected_input_ids,
-        expected_labels,
-    )
-    _is_close(
-        plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
-        expected_mm_inputs,
-    )
+    if plugin.__class__.__name__ != "BasePlugin":
+        assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
+        assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
+            expected_input_ids,
+            expected_labels,
+        )
+        _is_close(
+            plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
+            expected_mm_inputs,
+        )
+
     # test text_messages
     assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == TEXT_MESSAGES
     assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
@@ -128,7 +130,7 @@ def _check_plugin(
 def test_base_plugin():
     tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
-    base_plugin = get_mm_plugin(name="base", image_token="<image>")
+    base_plugin = get_mm_plugin(name="base")
     check_inputs = {"plugin": base_plugin, **tokenizer_module}
     _check_plugin(**check_inputs)
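The base plugin no longer declares an image token, so feeding it images trips the stricter validation; that is why the mm assertions in _check_plugin are now skipped for BasePlugin. A sketch of the expected failure (the pytest usage is illustrative, not part of the commit):

import pytest
from llamafactory.data.mm_plugin import get_mm_plugin

base_plugin = get_mm_plugin(name="base")
with pytest.raises(ValueError, match="does not support image input"):
    base_plugin.process_messages([{"role": "user", "content": "hi"}], ["photo.jpg"], [], [], None)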