From 17d32fb5c70f620541958dd5a3a5bc905844d071 Mon Sep 17 00:00:00 2001 From: fzc8578 <1428195643@qq.com> Date: Mon, 13 Jan 2025 15:01:39 +0800 Subject: [PATCH] fix tests Former-commit-id: 582a17a12010943c7ca1cc0e25ebc8d125d10b45 --- tests/data/test_mm_plugin.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index 12b8dfab..f84749ec 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -140,7 +140,7 @@ def test_base_plugin(): @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") def test_cpm_o_plugin(): - tokenizer_module = _load_tokenizer_module(model_name_or_path="openbmb/MiniCPM-V-2_6") + tokenizer_module = _load_tokenizer_module(model_name_or_path="openbmb/MiniCPM-V-2_6") cpm_o_plugin = get_mm_plugin(name="cpm_o", image_token="") check_inputs = {"plugin": cpm_o_plugin, **tokenizer_module} image_seqlen = 64 @@ -151,12 +151,8 @@ def test_cpm_o_plugin(): } for message in MM_MESSAGES ] - check_inputs["expected_mm_inputs"] = { - "pixel_values": [[]], - "image_sizes": [[]], - "tgt_sizes": [[]], - "image_bound": [torch.tensor([], dtype=torch.int64).reshape(0, 2)], - } + check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"]) + check_inputs["expected_mm_inputs"]["image_bound"] = [torch.arange(64)] check_inputs["expected_no_mm_inputs"] = {"image_bound": [torch.tensor([], dtype=torch.int64).reshape(0, 2)]} _check_plugin(**check_inputs)