mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 20:22:49 +08:00
plugin test & check
Former-commit-id: 2df2be1c47aded0132b5cc86acd3926dca585bc1
This commit is contained in:
parent
df722bf18e
commit
5e440a467d
@ -697,7 +697,7 @@ def get_mm_plugin(
|
|||||||
try:
|
try:
|
||||||
require_version("transformers==4.46.0.dev0")
|
require_version("transformers==4.46.0.dev0")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ImportError("PixtralPlugin requires transformers==4.46.0.dev0. Please install it first.")
|
raise ImportError("PixtralPlugin requires transformers>=4.46.0.dev0. Please install it first.")
|
||||||
if plugin_class is None:
|
if plugin_class is None:
|
||||||
raise ValueError("Multimodal plugin `{}` not found.".format(name))
|
raise ValueError("Multimodal plugin `{}` not found.".format(name))
|
||||||
|
|
||||||
|
@ -178,7 +178,18 @@ def test_paligemma_plugin():
|
|||||||
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
|
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
|
||||||
_check_plugin(**check_inputs)
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
|
def test_pixtral_plugin():
|
||||||
|
tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
|
||||||
|
pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
|
||||||
|
image_slice_heigt, image_slice_width = 2, 2
|
||||||
|
check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||||
|
check_inputs["expected_mm_messages"] = [
|
||||||
|
{key: value.replace("<image>", "{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_heigt).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]"
|
||||||
|
for key, value in message.items()} for message in MM_MESSAGES
|
||||||
|
]
|
||||||
|
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||||
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
def test_qwen2_vl_plugin():
|
def test_qwen2_vl_plugin():
|
||||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
|
tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
|
||||||
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
|
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
|
||||||
@ -206,3 +217,6 @@ def test_video_llava_plugin():
|
|||||||
]
|
]
|
||||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||||
_check_plugin(**check_inputs)
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_pixtral_plugin()
|
Loading…
x
Reference in New Issue
Block a user