diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index 70b61444..d3c3f021 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -178,18 +178,20 @@ def test_paligemma_plugin(): check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]} _check_plugin(**check_inputs) + def test_pixtral_plugin(): tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b") pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]") - image_slice_heigt, image_slice_width = 2, 2 + image_slice_height, image_slice_width = 2, 2 check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor} check_inputs["expected_mm_messages"] = [ - {key: value.replace("", "{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_heigt).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]" + {key: value.replace("", "{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]" for key, value in message.items()} for message in MM_MESSAGES ] check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) _check_plugin(**check_inputs) - + + def test_qwen2_vl_plugin(): tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct") qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>") @@ -217,6 +219,3 @@ def test_video_llava_plugin(): ] check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) _check_plugin(**check_inputs) - -if __name__ == "__main__": - test_pixtral_plugin() \ No newline at end of file