diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index de7e362a..f67737f5 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -513,6 +513,12 @@ class PixtralPlugin(BasePlugin): ) -> Dict[str, Union[List[int], "torch.Tensor"]]: self._validate_input(images, videos) mm_inputs = self._get_mm_inputs(images, videos, processor) + # hack for hf engine + if mm_inputs.get("pixel_values") and len(mm_inputs.get("pixel_values")[0]) == 1: + mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0][0].unsqueeze(0) + + if mm_inputs.get("image_sizes"): + del mm_inputs["image_sizes"] return mm_inputs diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index da64ccc3..32b89f7f 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -110,8 +110,6 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None: for key in batch_a.keys(): if isinstance(batch_a[key], torch.Tensor): assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5) - elif _is_nested_tensor_list(batch_a[key]) and _is_nested_tensor_list(batch_b[key]): - assert _equal_nested_tensor_list(batch_a[key], batch_b[key]) else: assert batch_a[key] == batch_b[key] @@ -227,6 +225,9 @@ def test_pixtral_plugin(): for key, value in message.items()} for message in MM_MESSAGES ] check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor) + # TODO works needed for pixtral plugin test & hack hf engine input below for now + check_inputs["expected_mm_inputs"].pop("image_sizes") + check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0][0].unsqueeze(0) _check_plugin(**check_inputs)