From ee87d318b8d9f76f5d181a9702d9e82d7505843d Mon Sep 17 00:00:00 2001 From: fzc8578 <1428195643@qq.com> Date: Mon, 13 Jan 2025 15:01:39 +0800 Subject: [PATCH] fix tests Former-commit-id: c2fa4cc7b114ac1a376882022e4b6ef75d288dca --- tests/data/test_mm_plugin.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index 12b8dfab..f84749ec 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -140,7 +140,7 @@ def test_base_plugin(): @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") def test_cpm_o_plugin(): - tokenizer_module = _load_tokenizer_module(model_name_or_path="openbmb/MiniCPM-V-2_6") + tokenizer_module = _load_tokenizer_module(model_name_or_path="openbmb/MiniCPM-V-2_6") cpm_o_plugin = get_mm_plugin(name="cpm_o", image_token="") check_inputs = {"plugin": cpm_o_plugin, **tokenizer_module} image_seqlen = 64 @@ -151,12 +151,8 @@ def test_cpm_o_plugin(): } for message in MM_MESSAGES ] - check_inputs["expected_mm_inputs"] = { - "pixel_values": [[]], - "image_sizes": [[]], - "tgt_sizes": [[]], - "image_bound": [torch.tensor([], dtype=torch.int64).reshape(0, 2)], - } + check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"]) + check_inputs["expected_mm_inputs"]["image_bound"] = [torch.arange(64)] check_inputs["expected_no_mm_inputs"] = {"image_bound": [torch.tensor([], dtype=torch.int64).reshape(0, 2)]} _check_plugin(**check_inputs)