add extra test for pixtral mm_input

Former-commit-id: 0fc949783dec2d038dc3d1bf52051c256b69ac20
This commit is contained in:
Kingsley 2024-10-15 17:09:24 +08:00
parent 9a9716c228
commit ae869639dd
2 changed files with 40 additions and 4 deletions

View File

@ -513,8 +513,6 @@ class PixtralPlugin(BasePlugin):
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if mm_inputs.get("image_sizes"):
mm_inputs.pop("image_sizes")
return mm_inputs

View File

@ -13,7 +13,7 @@
# limitations under the License.
import os
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, Union
import pytest
import torch
@ -68,12 +68,50 @@ def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
return image_processor(images=IMAGES, return_tensors="pt")
def _is_nested_tensor_list(element):
if not isinstance(element, list):
return False
for item in element:
if isinstance(item, list):
if not _is_nested_tensor_list(item):
return False
elif not isinstance(item, torch.Tensor):
return False
return True
def _equal_nested_tensor_list(a: List[List[torch.Tensor]], b: List[List[torch.Tensor]]) -> bool:
if type(a) != type(b):
return False
if isinstance(a, list) and isinstance(b, list):
if len(a) != len(b):
return False
for sub_a, sub_b in zip(a, b):
if isinstance(sub_a, list) and isinstance(sub_b, list):
if not _equal_nested_tensor_list(sub_a, sub_b):
return False
elif isinstance(sub_a, torch.Tensor) and isinstance(sub_b, torch.Tensor):
if not torch.equal(sub_a, sub_b):
return False
else:
return False
return True
return False
def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
assert batch_a.keys() == batch_b.keys()
for key in batch_a.keys():
if isinstance(batch_a[key], torch.Tensor):
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
elif _is_nested_tensor_list(batch_a[key]) and _is_nested_tensor_list(batch_b[key]):
assert _equal_nested_tensor_list(batch_a[key], batch_b[key])
else:
assert batch_a[key] == batch_b[key]
@ -185,7 +223,7 @@ def test_pixtral_plugin():
image_slice_height, image_slice_width = 2, 2
check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
check_inputs["expected_mm_messages"] = [
{key: value.replace("<image>", "{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]"
{key: value.replace("<image>", ("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]")
for key, value in message.items()} for message in MM_MESSAGES
]
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)