mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 04:02:49 +08:00
plugin test & check
Former-commit-id: 2df2be1c47aded0132b5cc86acd3926dca585bc1
This commit is contained in:
parent
df722bf18e
commit
5e440a467d
@ -697,7 +697,7 @@ def get_mm_plugin(
|
||||
try:
|
||||
require_version("transformers==4.46.0.dev0")
|
||||
except Exception as e:
|
||||
raise ImportError("PixtralPlugin requires transformers==4.46.0.dev0. Please install it first.")
|
||||
raise ImportError("PixtralPlugin requires transformers>=4.46.0.dev0. Please install it first.")
|
||||
if plugin_class is None:
|
||||
raise ValueError("Multimodal plugin `{}` not found.".format(name))
|
||||
|
||||
|
@ -178,6 +178,17 @@ def test_paligemma_plugin():
|
||||
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
def test_pixtral_plugin():
    r"""Verify the pixtral multimodal plugin expands ``<image>`` placeholders correctly.

    Each ``<image>`` placeholder in the chat messages should be replaced by a grid of
    ``[IMG]`` tokens: ``image_slice_width`` tokens per row, rows separated by
    ``[IMG_BREAK]``, with the final ``[IMG_BREAK]`` replaced by ``[IMG_END]``.
    """
    tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
    pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
    # NOTE(review): fixed misspelled local `image_slice_heigt` -> `image_slice_height`.
    # Assumes the fixture image is patched into a 2x2 grid — confirm against _get_mm_inputs.
    image_slice_height, image_slice_width = 2, 2
    check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
    check_inputs["expected_mm_messages"] = [
        {
            # Build one row of image tokens, repeat per row, then swap the trailing
            # [IMG_BREAK] for the [IMG_END] terminator via rsplit on the last occurrence.
            key: value.replace(
                "<image>", "{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height
            ).rsplit("[IMG_BREAK]", 1)[0]
            + "[IMG_END]"
            for key, value in message.items()
        }
        for message in MM_MESSAGES
    ]
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
    _check_plugin(**check_inputs)
|
||||
|
||||
def test_qwen2_vl_plugin():
|
||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
|
||||
@ -206,3 +217,6 @@ def test_video_llava_plugin():
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
# Allow running this test module directly (without pytest) for quick manual checks.
if __name__ == "__main__":
    test_pixtral_plugin()
|
Loading…
x
Reference in New Issue
Block a user