Former-commit-id: c122b9f8657d1ca3032b1b6a6cf9cc61f11aaa82
This commit is contained in:
hiyouga 2024-09-05 02:16:49 +08:00
parent 9df7a26e6b
commit ac33d2f4da

View File

@ -88,11 +88,11 @@ def _check_plugin(
plugin: "BasePlugin",
tokenizer: "PreTrainedTokenizer",
processor: "ProcessorMixin",
expected_mm_messages: Sequence[Dict[str, str]],
expected_input_ids: List[int],
expected_labels: List[int],
expected_mm_inputs: Dict[str, Any],
expected_no_mm_inputs: Dict[str, Any],
expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES,
expected_input_ids: List[int] = INPUT_IDS,
expected_labels: List[int] = LABELS,
expected_mm_inputs: Dict[str, Any] = {},
expected_no_mm_inputs: Dict[str, Any] = {},
) -> None:
# test mm_messages
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor) == expected_mm_messages
@ -120,11 +120,6 @@ def test_base_plugin():
tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
base_plugin = get_mm_plugin(name="base", image_token="<image>")
check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
check_inputs["expected_mm_messages"] = MM_MESSAGES
check_inputs["expected_input_ids"] = INPUT_IDS
check_inputs["expected_labels"] = LABELS
check_inputs["expected_mm_inputs"] = {}
check_inputs["expected_no_mm_inputs"] = {}
_check_plugin(**check_inputs)
@ -137,10 +132,7 @@ def test_llava_plugin():
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
for message in MM_MESSAGES
]
check_inputs["expected_input_ids"] = INPUT_IDS
check_inputs["expected_labels"] = LABELS
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
check_inputs["expected_no_mm_inputs"] = {}
_check_plugin(**check_inputs)
@ -173,8 +165,5 @@ def test_qwen2_vl_plugin():
}
for message in MM_MESSAGES
]
check_inputs["expected_input_ids"] = INPUT_IDS
check_inputs["expected_labels"] = LABELS
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
check_inputs["expected_no_mm_inputs"] = {}
_check_plugin(**check_inputs)