diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index cb8576ed..3ccba7f4 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -88,11 +88,11 @@ def _check_plugin( plugin: "BasePlugin", tokenizer: "PreTrainedTokenizer", processor: "ProcessorMixin", - expected_mm_messages: Sequence[Dict[str, str]], - expected_input_ids: List[int], - expected_labels: List[int], - expected_mm_inputs: Dict[str, Any], - expected_no_mm_inputs: Dict[str, Any], + expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES, + expected_input_ids: List[int] = INPUT_IDS, + expected_labels: List[int] = LABELS, + expected_mm_inputs: Dict[str, Any] = {}, + expected_no_mm_inputs: Dict[str, Any] = {}, ) -> None: # test mm_messages assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor) == expected_mm_messages @@ -120,11 +120,6 @@ def test_base_plugin(): tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA) base_plugin = get_mm_plugin(name="base", image_token="") check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor} - check_inputs["expected_mm_messages"] = MM_MESSAGES - check_inputs["expected_input_ids"] = INPUT_IDS - check_inputs["expected_labels"] = LABELS - check_inputs["expected_mm_inputs"] = {} - check_inputs["expected_no_mm_inputs"] = {} _check_plugin(**check_inputs) @@ -137,10 +132,7 @@ def test_llava_plugin(): {key: value.replace("", "" * image_seqlen) for key, value in message.items()} for message in MM_MESSAGES ] - check_inputs["expected_input_ids"] = INPUT_IDS - check_inputs["expected_labels"] = LABELS check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) - check_inputs["expected_no_mm_inputs"] = {} _check_plugin(**check_inputs) @@ -173,8 +165,5 @@ def test_qwen2_vl_plugin(): } for message in MM_MESSAGES ] - check_inputs["expected_input_ids"] = INPUT_IDS - check_inputs["expected_labels"] = LABELS check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) - check_inputs["expected_no_mm_inputs"] = {} _check_plugin(**check_inputs)