diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index 406307d0..92238b35 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -332,7 +332,9 @@ def test_qwen2_omni_plugin(): image_seqlen, audio_seqlen = 4, 2 tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2.5-Omni-7B") qwen2_omni_plugin = get_mm_plugin( - name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>" + name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>", + vision_bos_token="<|vision_bos|>", vision_eos_token="<|vision_eos|>", + audio_bos_token="<|audio_bos|>", audio_eos_token="<|audio_eos|>" ) check_inputs = {"plugin": qwen2_omni_plugin, **tokenizer_module} check_inputs["expected_mm_messages"] = [