From ae869639dd1fafef8d9c7c738f46c42ff4322ddb Mon Sep 17 00:00:00 2001
From: Kingsley
Date: Tue, 15 Oct 2024 17:09:24 +0800
Subject: [PATCH] add extra test for pixtral mm_input

Former-commit-id: 0fc949783dec2d038dc3d1bf52051c256b69ac20
---
 src/llamafactory/data/mm_plugin.py |  2 --
 tests/data/test_mm_plugin.py       | 42 ++++++++++++++++++++++++++++--
 2 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 5f128706..f3f6433c 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -513,8 +513,6 @@ class PixtralPlugin(BasePlugin):
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
         mm_inputs = self._get_mm_inputs(images, videos, processor)
-        if mm_inputs.get("image_sizes"):
-            mm_inputs.pop("image_sizes")
         return mm_inputs
 
 
diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py
index d3c3f021..da64ccc3 100644
--- a/tests/data/test_mm_plugin.py
+++ b/tests/data/test_mm_plugin.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, Union
 
 import pytest
 import torch
@@ -68,12 +68,50 @@ def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
     image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
     return image_processor(images=IMAGES, return_tensors="pt")
 
 
+def _is_nested_tensor_list(element):
+    if not isinstance(element, list):
+        return False
+
+    for item in element:
+        if isinstance(item, list):
+            if not _is_nested_tensor_list(item):
+                return False
+
+        elif not isinstance(item, torch.Tensor):
+            return False
+
+    return True
+
+
+def _equal_nested_tensor_list(a: List[List[torch.Tensor]], b: List[List[torch.Tensor]]) -> bool:
+    if type(a) != type(b):
+        return False
+
+    if isinstance(a, list) and isinstance(b, list):
+        if len(a) != len(b):
+            return False
+
+        for sub_a, sub_b in zip(a, b):
+            if isinstance(sub_a, list) and isinstance(sub_b, list):
+                if not _equal_nested_tensor_list(sub_a, sub_b):
+                    return False
+            elif isinstance(sub_a, torch.Tensor) and isinstance(sub_b, torch.Tensor):
+                if not torch.equal(sub_a, sub_b):
+                    return False
+            else:
+                return False
+
+        return True
+
+    return False
 
 def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
     assert batch_a.keys() == batch_b.keys()
     for key in batch_a.keys():
         if isinstance(batch_a[key], torch.Tensor):
             assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
+        elif _is_nested_tensor_list(batch_a[key]) and _is_nested_tensor_list(batch_b[key]):
+            assert _equal_nested_tensor_list(batch_a[key], batch_b[key])
         else:
             assert batch_a[key] == batch_b[key]
 
@@ -185,7 +223,7 @@ def test_pixtral_plugin():
     image_slice_height, image_slice_width = 2, 2
     check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
     check_inputs["expected_mm_messages"] = [
-        {key: value.replace("<image>", "{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]"
+        {key: value.replace("<image>", ("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]")
         for key, value in message.items()} for message in MM_MESSAGES
     ]
     check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
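
Context for the new helpers (a minimal, self-contained sketch, not part of the patch; the tensor shapes are made up for illustration): the Pixtral processor output that _is_close compares is not necessarily a single stacked tensor, since pixel_values can come back as a list of per-sample lists of tensors with differing spatial sizes, which torch.allclose cannot consume directly. The flat form of the comparison that _equal_nested_tensor_list performs recursively looks like this:

    import torch

    # Illustrative shapes only: each inner list holds image tensors whose
    # spatial sizes differ, so the batch cannot be stacked into one tensor.
    batch_a = {"pixel_values": [[torch.ones(3, 4, 4), torch.ones(3, 2, 6)]]}
    batch_b = {"pixel_values": [[torch.ones(3, 4, 4), torch.ones(3, 2, 6)]]}

    # Pair up the inner tensors and require exact equality, mirroring what the
    # recursive helper in the test does for arbitrarily nested lists.
    equal = all(
        torch.equal(tensor_a, tensor_b)
        for inner_a, inner_b in zip(batch_a["pixel_values"], batch_b["pixel_values"])
        for tensor_a, tensor_b in zip(inner_a, inner_b)
    )
    print(equal)  # True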