mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)

add extra test for pixtral mm_input

Former-commit-id: 0fc949783dec2d038dc3d1bf52051c256b69ac20

This commit is contained in:
parent 9a9716c228
commit ae869639dd
@@ -513,8 +513,6 @@ class PixtralPlugin(BasePlugin):
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        self._validate_input(images, videos)
        mm_inputs = self._get_mm_inputs(images, videos, processor)
        if mm_inputs.get("image_sizes"):
            mm_inputs.pop("image_sizes")

        return mm_inputs
@@ -13,7 +13,7 @@
 # limitations under the License.

 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, Union

 import pytest
 import torch
@@ -68,12 +68,50 @@ def _get_mm_inputs(processor: "ProcessorMixin") -> Dict[str, "torch.Tensor"]:
    image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
    return image_processor(images=IMAGES, return_tensors="pt")


def _is_nested_tensor_list(element):
    if not isinstance(element, list):
        return False

    for item in element:
        if isinstance(item, list):
            if not _is_nested_tensor_list(item):
                return False
        elif not isinstance(item, torch.Tensor):
            return False

    return True


def _equal_nested_tensor_list(a: List[List[torch.Tensor]], b: List[List[torch.Tensor]]) -> bool:
    if type(a) != type(b):
        return False

    if isinstance(a, list) and isinstance(b, list):
        if len(a) != len(b):
            return False

        for sub_a, sub_b in zip(a, b):
            if isinstance(sub_a, list) and isinstance(sub_b, list):
                if not _equal_nested_tensor_list(sub_a, sub_b):
                    return False
            elif isinstance(sub_a, torch.Tensor) and isinstance(sub_b, torch.Tensor):
                if not torch.equal(sub_a, sub_b):
                    return False
            else:
                return False

        return True

    return False


def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
    assert batch_a.keys() == batch_b.keys()
    for key in batch_a.keys():
        if isinstance(batch_a[key], torch.Tensor):
            assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
        elif _is_nested_tensor_list(batch_a[key]) and _is_nested_tensor_list(batch_b[key]):
            assert _equal_nested_tensor_list(batch_a[key], batch_b[key])
        else:
            assert batch_a[key] == batch_b[key]
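None of the following is part of the commit; it is a small sketch of how the new helpers behave on toy data, assuming _is_nested_tensor_list, _equal_nested_tensor_list and _is_close from the hunk above are in scope (key names are illustrative):

import torch

# Two batches with a nested-tensor-list value, a plain tensor value and a plain list.
batch_a = {
    "pixel_values": [[torch.ones(3, 2, 2), torch.zeros(3, 1, 4)]],  # list of lists of tensors
    "image_sizes": torch.tensor([[2.0, 2.0], [1.0, 4.0]]),
    "input_ids": [1, 2, 3],
}
batch_b = {
    "pixel_values": [[torch.ones(3, 2, 2), torch.zeros(3, 1, 4)]],
    "image_sizes": torch.tensor([[2.0, 2.0], [1.0, 4.0]]),
    "input_ids": [1, 2, 3],
}

assert _is_nested_tensor_list(batch_a["pixel_values"])     # recognizes lists of (lists of) tensors
assert not _is_nested_tensor_list(batch_a["image_sizes"])  # a bare tensor is not a nested list
assert _equal_nested_tensor_list(batch_a["pixel_values"], batch_b["pixel_values"])

# _is_close dispatches per key: torch.allclose for tensor values, the recursive
# torch.equal comparison for nested tensor lists, and plain == for everything else.
_is_close(batch_a, batch_b)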
@@ -185,7 +223,7 @@ def test_pixtral_plugin():
    image_slice_height, image_slice_width = 2, 2
    check_inputs = {"plugin": pixtral_plugin, "tokenizer": tokenizer, "processor": processor}
    check_inputs["expected_mm_messages"] = [
-        {key: value.replace("<image>", "{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]"
+        {key: value.replace("<image>", ("{}[IMG_BREAK]".format("[IMG]" * image_slice_width) * image_slice_height).rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]")
        for key, value in message.items()} for message in MM_MESSAGES
    ]
    check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
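The only change in this hunk is parenthesization: the new (+) line builds the full image-token string first and hands it to str.replace, while the old (-) line applied rsplit and the [IMG_END] suffix to the already-replaced message, silently dropping any text that followed the <image> placeholder. Worked out by hand for the 2 x 2 slice grid used in the test (not part of the diff):

image_slice_height, image_slice_width = 2, 2

row = "{}[IMG_BREAK]".format("[IMG]" * image_slice_width)  # "[IMG][IMG][IMG_BREAK]"
grid = row * image_slice_height                            # two rows, each ending in [IMG_BREAK]
image_tokens = grid.rsplit("[IMG_BREAK]", 1)[0] + "[IMG_END]"

# The trailing [IMG_BREAK] of the last row is swapped for [IMG_END]:
assert image_tokens == "[IMG][IMG][IMG_BREAK][IMG][IMG][IMG_END]"

# With the new expression, text after the placeholder survives the replacement:
assert "<image>What is this?".replace("<image>", image_tokens) == image_tokens + "What is this?"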