add cpm_o test

Former-commit-id: c506f763dff1c1d2c85ac8fe6beb9f40ca4fcde9
This commit is contained in:
fzc8578 2025-01-11 11:49:03 +08:00
parent e7f928adc4
commit 63bb2b7235

View File

@ -127,6 +127,27 @@ def test_base_plugin():
_check_plugin(**check_inputs) _check_plugin(**check_inputs)
def test_cpm_o_plugin():
tokenizer_module = _load_tokenizer_module(model_name_or_path="openbmb/MiniCPM-V-2_6")
cpm_o_plugin = get_mm_plugin(name="cpm_o", image_token="<image>")
check_inputs = {"plugin": cpm_o_plugin, **tokenizer_module}
image_seqlen = 64
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", f"<image_id>0</image_id><image>{'<unk>' * image_seqlen}</image>") for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = {
"pixel_values": [[]],
"image_sizes": [[]],
"tgt_sizes": [[]],
"image_bound": [torch.tensor([], dtype=torch.int64).reshape(0,2)]
}
check_inputs["expected_no_mm_inputs"] = {
"image_bound": [torch.tensor([], dtype=torch.int64).reshape(0,2)]
}
_check_plugin(**check_inputs)
def test_llava_plugin(): def test_llava_plugin():
image_seqlen = 576 image_seqlen = 576
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf") tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")