Former-commit-id: 91c672f0147bb6eb998871a42f8a89992af88528
This commit is contained in:
hiyouga 2024-11-23 19:13:32 +00:00
parent 3b91839a55
commit 9efd1fec90

View File

@ -61,7 +61,7 @@ INPUT_IDS = [0, 1, 2, 3, 4]
LABELS = [0, 1, 2, 3, 4] LABELS = [0, 1, 2, 3, 4]
SEQLENS = [1024] BATCH_IDS = [[1] * 1024]
def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]: def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
@ -105,7 +105,7 @@ def _check_plugin(
expected_labels, expected_labels,
) )
_is_close( _is_close(
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, IMGLENS, NO_VIDLENS, SEQLENS, processor), plugin.get_mm_inputs(IMAGES, NO_VIDEOS, IMGLENS, NO_VIDLENS, BATCH_IDS, processor),
expected_mm_inputs, expected_mm_inputs,
) )
# test text_messages # test text_messages
@ -115,7 +115,7 @@ def _check_plugin(
LABELS, LABELS,
) )
_is_close( _is_close(
plugin.get_mm_inputs(NO_IMAGES, NO_VIDEOS, NO_IMGLENS, NO_VIDLENS, SEQLENS, processor), plugin.get_mm_inputs(NO_IMAGES, NO_VIDEOS, NO_IMGLENS, NO_VIDLENS, BATCH_IDS, processor),
expected_no_mm_inputs, expected_no_mm_inputs,
) )