[data] fix mllama (#7053)

* fix mllama

* fix test

Former-commit-id: 76314e6ad1
This commit is contained in:
hoshi-hiyouga
2025-02-24 22:05:38 +08:00
committed by GitHub
parent ca78ba964d
commit dca5fe14c2
2 changed files with 85 additions and 66 deletions

View File

@@ -103,15 +103,17 @@ def _check_plugin(
expected_no_mm_inputs: Dict[str, Any] = {},
) -> None:
# test mm_messages
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
expected_input_ids,
expected_labels,
)
_is_close(
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
expected_mm_inputs,
)
if plugin.__class__.__name__ != "BasePlugin":
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
expected_input_ids,
expected_labels,
)
_is_close(
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
expected_mm_inputs,
)
# test text_messages
assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == TEXT_MESSAGES
assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
@@ -128,7 +130,7 @@ def _check_plugin(
def test_base_plugin():
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
base_plugin = get_mm_plugin(name="base", image_token="<image>")
base_plugin = get_mm_plugin(name="base")
check_inputs = {"plugin": base_plugin, **tokenizer_module}
_check_plugin(**check_inputs)