mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 19:52:50 +08:00
tiny fix
Former-commit-id: c122b9f8657d1ca3032b1b6a6cf9cc61f11aaa82
This commit is contained in:
parent
9df7a26e6b
commit
ac33d2f4da
@ -88,11 +88,11 @@ def _check_plugin(
|
|||||||
plugin: "BasePlugin",
|
plugin: "BasePlugin",
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
processor: "ProcessorMixin",
|
processor: "ProcessorMixin",
|
||||||
expected_mm_messages: Sequence[Dict[str, str]],
|
expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES,
|
||||||
expected_input_ids: List[int],
|
expected_input_ids: List[int] = INPUT_IDS,
|
||||||
expected_labels: List[int],
|
expected_labels: List[int] = LABELS,
|
||||||
expected_mm_inputs: Dict[str, Any],
|
expected_mm_inputs: Dict[str, Any] = {},
|
||||||
expected_no_mm_inputs: Dict[str, Any],
|
expected_no_mm_inputs: Dict[str, Any] = {},
|
||||||
) -> None:
|
) -> None:
|
||||||
# test mm_messages
|
# test mm_messages
|
||||||
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor) == expected_mm_messages
|
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor) == expected_mm_messages
|
||||||
@ -120,11 +120,6 @@ def test_base_plugin():
|
|||||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
|
tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
|
||||||
base_plugin = get_mm_plugin(name="base", image_token="<image>")
|
base_plugin = get_mm_plugin(name="base", image_token="<image>")
|
||||||
check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
|
check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||||
check_inputs["expected_mm_messages"] = MM_MESSAGES
|
|
||||||
check_inputs["expected_input_ids"] = INPUT_IDS
|
|
||||||
check_inputs["expected_labels"] = LABELS
|
|
||||||
check_inputs["expected_mm_inputs"] = {}
|
|
||||||
check_inputs["expected_no_mm_inputs"] = {}
|
|
||||||
_check_plugin(**check_inputs)
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
|
|
||||||
@ -137,10 +132,7 @@ def test_llava_plugin():
|
|||||||
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
|
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
|
||||||
for message in MM_MESSAGES
|
for message in MM_MESSAGES
|
||||||
]
|
]
|
||||||
check_inputs["expected_input_ids"] = INPUT_IDS
|
|
||||||
check_inputs["expected_labels"] = LABELS
|
|
||||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||||
check_inputs["expected_no_mm_inputs"] = {}
|
|
||||||
_check_plugin(**check_inputs)
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
|
|
||||||
@ -173,8 +165,5 @@ def test_qwen2_vl_plugin():
|
|||||||
}
|
}
|
||||||
for message in MM_MESSAGES
|
for message in MM_MESSAGES
|
||||||
]
|
]
|
||||||
check_inputs["expected_input_ids"] = INPUT_IDS
|
|
||||||
check_inputs["expected_labels"] = LABELS
|
|
||||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||||
check_inputs["expected_no_mm_inputs"] = {}
|
|
||||||
_check_plugin(**check_inputs)
|
_check_plugin(**check_inputs)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user