diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py
index 12b8dfab..f84749ec 100644
--- a/tests/data/test_mm_plugin.py
+++ b/tests/data/test_mm_plugin.py
@@ -151,12 +151,8 @@ def test_cpm_o_plugin():
         }
         for message in MM_MESSAGES
     ]
-    check_inputs["expected_mm_inputs"] = {
-        "pixel_values": [[]],
-        "image_sizes": [[]],
-        "tgt_sizes": [[]],
-        "image_bound": [torch.tensor([], dtype=torch.int64).reshape(0, 2)],
-    }
+    check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
+    check_inputs["expected_mm_inputs"]["image_bound"] = [torch.arange(64)]
     check_inputs["expected_no_mm_inputs"] = {"image_bound": [torch.tensor([], dtype=torch.int64).reshape(0, 2)]}
     _check_plugin(**check_inputs)
 