fix tests

Former-commit-id: c2fa4cc7b114ac1a376882022e4b6ef75d288dca
This commit is contained in:
fzc8578 2025-01-13 15:01:39 +08:00
parent 4741eec2d1
commit ee87d318b8

View File

@ -140,7 +140,7 @@ def test_base_plugin():
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
def test_cpm_o_plugin(): def test_cpm_o_plugin():
tokenizer_module = _load_tokenizer_module(model_name_or_path="openbmb/MiniCPM-V-2_6") tokenizer_module = _load_tokenizer_module(model_name_or_path="/data/fengzc/LLM/checkpoints/MiniCPM-V-2_6")
cpm_o_plugin = get_mm_plugin(name="cpm_o", image_token="<image>") cpm_o_plugin = get_mm_plugin(name="cpm_o", image_token="<image>")
check_inputs = {"plugin": cpm_o_plugin, **tokenizer_module} check_inputs = {"plugin": cpm_o_plugin, **tokenizer_module}
image_seqlen = 64 image_seqlen = 64
@ -151,12 +151,8 @@ def test_cpm_o_plugin():
} }
for message in MM_MESSAGES for message in MM_MESSAGES
] ]
check_inputs["expected_mm_inputs"] = { check_inputs["expected_mm_inputs"] = _get_mm_inputs(tokenizer_module["processor"])
"pixel_values": [[]], check_inputs["expected_mm_inputs"]["image_bound"] = [torch.arange(64)]
"image_sizes": [[]],
"tgt_sizes": [[]],
"image_bound": [torch.tensor([], dtype=torch.int64).reshape(0, 2)],
}
check_inputs["expected_no_mm_inputs"] = {"image_bound": [torch.tensor([], dtype=torch.int64).reshape(0, 2)]} check_inputs["expected_no_mm_inputs"] = {"image_bound": [torch.tensor([], dtype=torch.int64).reshape(0, 2)]}
_check_plugin(**check_inputs) _check_plugin(**check_inputs)