diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index 6cce2c4c..6187fa5e 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -139,7 +139,14 @@ def test_llava_next_plugin(): tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/llava-v1.6-vicuna-7b-hf") llava_next_plugin = get_mm_plugin(name="llava_next", image_token="") check_inputs = {"plugin": llava_next_plugin, "tokenizer": tokenizer, "processor": processor} - check_inputs["expected_mm_messages"] = MM_MESSAGES + image_seqlen = 1176 + check_inputs["expected_mm_messages"] = [ + { + key: value.replace("", "" * image_seqlen) + for key, value in message.items() + } + for message in MM_MESSAGES + ] check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) _check_plugin(**check_inputs) @@ -148,7 +155,14 @@ def test_llava_next_video_plugin(): tokenizer, processor = _load_tokenizer_module(model_name_or_path="llava-hf/LLaVA-NeXT-Video-7B-hf") llava_next_video_plugin = get_mm_plugin(name="llava_next_video", image_token="", video_token="