mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 14:22:51 +08:00
add cpm_o test
Former-commit-id: c506f763dff1c1d2c85ac8fe6beb9f40ca4fcde9
This commit is contained in:
parent
e7f928adc4
commit
63bb2b7235
@ -127,6 +127,27 @@ def test_base_plugin():
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_cpm_o_plugin():
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="openbmb/MiniCPM-V-2_6")
|
||||
cpm_o_plugin = get_mm_plugin(name="cpm_o", image_token="<image>")
|
||||
check_inputs = {"plugin": cpm_o_plugin, **tokenizer_module}
|
||||
image_seqlen = 64
|
||||
check_inputs["expected_mm_messages"] = [
|
||||
{key: value.replace("<image>", f"<image_id>0</image_id><image>{'<unk>' * image_seqlen}</image>") for key, value in message.items()}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = {
|
||||
"pixel_values": [[]],
|
||||
"image_sizes": [[]],
|
||||
"tgt_sizes": [[]],
|
||||
"image_bound": [torch.tensor([], dtype=torch.int64).reshape(0,2)]
|
||||
}
|
||||
check_inputs["expected_no_mm_inputs"] = {
|
||||
"image_bound": [torch.tensor([], dtype=torch.int64).reshape(0,2)]
|
||||
}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
def test_llava_plugin():
|
||||
image_seqlen = 576
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
|
||||
|
Loading…
x
Reference in New Issue
Block a user