plugin test & check

Former-commit-id: 2df2be1c47aded0132b5cc86acd3926dca585bc1
This commit is contained in:
KUANGDD 2024-10-15 12:12:46 +08:00 committed by Junhao Zhang
parent df722bf18e
commit 5e440a467d
2 changed files with 16 additions and 2 deletions

View File

@ -697,7 +697,7 @@ def get_mm_plugin(
try:
require_version("transformers==4.46.0.dev0")
except Exception as e:
raise ImportError("PixtralPlugin requires transformers==4.46.0.dev0. Please install it first.")
raise ImportError("PixtralPlugin requires transformers>=4.46.0.dev0. Please install it first.")
if plugin_class is None:
raise ValueError("Multimodal plugin `{}` not found.".format(name))

View File

@ -178,7 +178,18 @@ def test_paligemma_plugin():
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
_check_plugin(**check_inputs)
def test_pixtral_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
image_slice_heigt, image_slice_width = 2, 2
check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_heigt).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]"
for key, value in message.items()} for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
_check_plugin(**check_inputs)
def test_qwen2_vl_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
@ -206,3 +217,6 @@ def test_video_llava_plugin():
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
_check_plugin(**check_inputs)
if __name__ == "__main__":
test_pixtral_plugin()