diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py
index c9084af0..ea92f0df 100644
--- a/tests/data/test_mm_plugin.py
+++ b/tests/data/test_mm_plugin.py
@@ -127,6 +127,27 @@ def test_base_plugin():
     _check_plugin(**check_inputs)
 
 
+def test_cpm_o_plugin():
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="openbmb/MiniCPM-V-2_6")
+    cpm_o_plugin = get_mm_plugin(name="cpm_o", image_token="<image>")
+    check_inputs = {"plugin": cpm_o_plugin, **tokenizer_module}
+    image_seqlen = 64
+    check_inputs["expected_mm_messages"] = [
+        {key: value.replace("<image>", f"<image_id>0</image_id><image>{'<unk>' * image_seqlen}</image>") for key, value in message.items()}
+        for message in MM_MESSAGES
+    ]
+    check_inputs["expected_mm_inputs"] = {
+        "pixel_values": [[]],
+        "image_sizes": [[]],
+        "tgt_sizes": [[]],
+        "image_bound": [torch.tensor([], dtype=torch.int64).reshape(0,2)]
+    }
+    check_inputs["expected_no_mm_inputs"] = {
+        "image_bound": [torch.tensor([], dtype=torch.int64).reshape(0,2)]
+    }
+    _check_plugin(**check_inputs)
+
+
 def test_llava_plugin():
     image_seqlen = 576
     tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")