mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 19:52:50 +08:00
fix some
Former-commit-id: 25641af04c98e902ff024c8fa7b4c2c36ed797de
This commit is contained in:
parent
5e440a467d
commit
a3f37777c1
@ -178,18 +178,20 @@ def test_paligemma_plugin():
|
|||||||
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
|
check_inputs["expected_no_mm_inputs"] = {"token_type_ids": [[1] * 1024]}
|
||||||
_check_plugin(**check_inputs)
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
|
|
||||||
def test_pixtral_plugin():
|
def test_pixtral_plugin():
|
||||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
|
tokenizer, processor = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
|
||||||
pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
|
pixtral_plugin = get_mm_plugin(name="pixtral", image_token="[IMG]")
|
||||||
image_slice_heigt, image_slice_width = 2, 2
|
image_slice_height, image_slice_width = 2, 2
|
||||||
check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
|
check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||||
check_inputs["expected_mm_messages"] = [
|
check_inputs["expected_mm_messages"] = [
|
||||||
{key: value.replace("<image>", "{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_heigt).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]"
|
{key: value.replace("<image>", "{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]"
|
||||||
for key, value in message.items()} for message in MM_MESSAGES
|
for key, value in message.items()} for message in MM_MESSAGES
|
||||||
]
|
]
|
||||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||||
_check_plugin(**check_inputs)
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
|
|
||||||
def test_qwen2_vl_plugin():
|
def test_qwen2_vl_plugin():
|
||||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
|
tokenizer, processor = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2-VL-7B-Instruct")
|
||||||
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
|
qwen2_vl_plugin = get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>")
|
||||||
@ -217,6 +219,3 @@ def test_video_llava_plugin():
|
|||||||
]
|
]
|
||||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||||
_check_plugin(**check_inputs)
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_pixtral_plugin()
|
|
Loading…
x
Reference in New Issue
Block a user