mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
fix bug for webui infer
Former-commit-id: 7ea29bbfe03550ac59ff9cb01a4bc41c95ac3adf
This commit is contained in:
parent
bcb40fddc0
commit
9c4941a1ea
@ -513,6 +513,12 @@ class PixtralPlugin(BasePlugin):
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||
# hack for hf engine
|
||||
if mm_inputs.get("pixel_values") and len(mm_inputs.get("pixel_values")[0]) == 1:
|
||||
mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0][0].unsqueeze(0)
|
||||
|
||||
if mm_inputs.get("image_sizes"):
|
||||
del mm_inputs["image_sizes"]
|
||||
|
||||
return mm_inputs
|
||||
|
||||
|
@ -110,8 +110,6 @@ def _is_close(batch_a: Dict[str, Any], batch_b: Dict[str, Any]) -> None:
|
||||
for key in batch_a.keys():
|
||||
if isinstance(batch_a[key], torch.Tensor):
|
||||
assert torch.allclose(batch_a[key], batch_b[key], rtol=1e-4, atol=1e-5)
|
||||
elif _is_nested_tensor_list(batch_a[key]) and _is_nested_tensor_list(batch_b[key]):
|
||||
assert _equal_nested_tensor_list(batch_a[key], batch_b[key])
|
||||
else:
|
||||
assert batch_a[key] == batch_b[key]
|
||||
|
||||
@ -227,6 +225,9 @@ def test_pixtral_plugin():
|
||||
for key, value in message.items()} for message in MM_MESSAGES
|
||||
]
|
||||
check_inputs["expected_mm_inputs"] = _get_mm_inputs(processor)
|
||||
# TODO works needed for pixtral plugin test & hack hf engine input below for now
|
||||
check_inputs["expected_mm_inputs"].pop("image_sizes")
|
||||
check_inputs["expected_mm_inputs"]["pixel_values"] = check_inputs["expected_mm_inputs"]["pixel_values"][0][0].unsqueeze(0)
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user