mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 14:22:51 +08:00
add cpm_o test
Former-commit-id: c506f763dff1c1d2c85ac8fe6beb9f40ca4fcde9
This commit is contained in:
parent
e7f928adc4
commit
63bb2b7235
@ -127,6 +127,27 @@ def test_base_plugin():
|
|||||||
_check_plugin(**check_inputs)
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cpm_o_plugin():
|
||||||
|
tokenizer_module = _load_tokenizer_module(model_name_or_path="openbmb/MiniCPM-V-2_6")
|
||||||
|
cpm_o_plugin = get_mm_plugin(name="cpm_o", image_token="<image>")
|
||||||
|
check_inputs = {"plugin": cpm_o_plugin, **tokenizer_module}
|
||||||
|
image_seqlen = 64
|
||||||
|
check_inputs["expected_mm_messages"] = [
|
||||||
|
{key: value.replace("<image>", f"<image_id>0</image_id><image>{'<unk>' * image_seqlen}</image>") for key, value in message.items()}
|
||||||
|
for message in MM_MESSAGES
|
||||||
|
]
|
||||||
|
check_inputs["expected_mm_inputs"] = {
|
||||||
|
"pixel_values": [[]],
|
||||||
|
"image_sizes": [[]],
|
||||||
|
"tgt_sizes": [[]],
|
||||||
|
"image_bound": [torch.tensor([], dtype=torch.int64).reshape(0,2)]
|
||||||
|
}
|
||||||
|
check_inputs["expected_no_mm_inputs"] = {
|
||||||
|
"image_bound": [torch.tensor([], dtype=torch.int64).reshape(0,2)]
|
||||||
|
}
|
||||||
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
|
|
||||||
def test_llava_plugin():
|
def test_llava_plugin():
|
||||||
image_seqlen = 576
|
image_seqlen = 576
|
||||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
|
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user