From 5e440a467dcb6e9ebbc1408418611046966be2b1 Mon Sep 17 00:00:00 2001 From: KUANGDD Date: Tue, 15 Oct 2024 12:12:46 +0800 Subject: [PATCH] plugin test & check Former-commit-id: 2df2be1c47aded0132b5cc86acd3926dca585bc1 --- src/llamafactory/data/mm_plugin.py | 2 +- tests/data/test_mm_plugin.py | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 9d81848b..5f128706 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -697,7 +697,7 @@ def get_mm_plugin( try: require_version("transformers==4.46.0.dev0") except Exception as e: - raise ImportError("PixtralPlugin requires transformers==4.46.0.dev0. Please install it first.") + raise ImportError("PixtralPlugin requires transformers>=4.46.0.dev0. Please install it first.") if plugin_class is None: raise ValueError("Multimodal plugin `{}` not found.".format(name)) diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index 75541000..70b61444 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -178,7 +178,18 @@ def test_paligemma_plugin(): check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]} _check_plugin(**check_inputs) - +def test_pixtral_plugin(): tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b") pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]") image_slice_height, image_slice_width = 2, 2 check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor} check_inputs["expected_mm_messages"] = [ {key: value.replace("<image>", "{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]" for key, value in message.items()} for message in MM_MESSAGES ] check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) _check_plugin(**check_inputs) + def 
test_qwen2_vl_plugin(): tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct") qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>") @@ -206,3 +217,6 @@ def test_video_llava_plugin(): ] check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) _check_plugin(**check_inputs) + +if __name__ == "__main__": + test_pixtral_plugin() \ No newline at end of file