mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 11:42:49 +08:00
tiny fix
Former-commit-id: c122b9f8657d1ca3032b1b6a6cf9cc61f11aaa82
This commit is contained in:
parent
9df7a26e6b
commit
ac33d2f4da
@ -88,11 +88,11 @@ def _check_plugin(
|
||||
plugin: "BasePlugin",
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: "ProcessorMixin",
|
||||
expected_mm_messages: Sequence[Dict[str, str]],
|
||||
expected_input_ids: List[int],
|
||||
expected_labels: List[int],
|
||||
expected_mm_inputs: Dict[str, Any],
|
||||
expected_no_mm_inputs: Dict[str, Any],
|
||||
expected_mm_messages: Sequence[Dict[str, str]] = MM_MESSAGES,
|
||||
expected_input_ids: List[int] = INPUT_IDS,
|
||||
expected_labels: List[int] = LABELS,
|
||||
expected_mm_inputs: Dict[str, Any] = {},
|
||||
expected_no_mm_inputs: Dict[str, Any] = {},
|
||||
) -> None:
|
||||
# test mm_messages
|
||||
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, processor) == expected_mm_messages
|
||||
@ -120,11 +120,6 @@ def test_base_plugin():
|
||||
tokenizer, processor = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
|
||||
base_plugin = get_mm_plugin(name="base", image_token="<image>")
|
||||
check_inputs = {"plugin": base_plugin, "tokenizer": tokenizer, "processor": processor}
|
||||
check_inputs["expected_mm_messages"] = MM_MESSAGES
|
||||
check_inputs["expected_input_ids"] = INPUT_IDS
|
||||
check_inputs["expected_labels"] = LABELS
|
||||
check_inputs["expected_mm_inputs"] = {}
|
||||
check_inputs["expected_no_mm_inputs"] = {}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@ -137,10 +132,7 @@ def test_llava_plugin():
|
||||
{key: value.replace("<image>", "<image>" * image_seqlen) for key, value in message.items()}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_input_ids"] = INPUT_IDS
|
||||
check_inputs["expected_labels"] = LABELS
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||
check_inputs["expected_no_mm_inputs"] = {}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
@ -173,8 +165,5 @@ def test_qwen2_vl_plugin():
|
||||
}
|
||||
for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_input_ids"] = INPUT_IDS
|
||||
check_inputs["expected_labels"] = LABELS
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||
check_inputs["expected_no_mm_inputs"] = {}
|
||||
_check_plugin(**check_inputs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user