From ee1b580328593c6f7049b34a07245e425bb18d42 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Sat, 1 Mar 2025 05:22:49 +0800 Subject: [PATCH] [inference] fix hf_engine (#7120) Former-commit-id: 1036311826a61fed2346a261c8a060c355778318 --- src/llamafactory/chat/hf_engine.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 260b1dc5..897017e0 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -178,9 +178,11 @@ class HuggingfaceEngine(BaseEngine): mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor) for key, value in mm_inputs.items(): - if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs + if isinstance(value, list) and isinstance(value[0], torch.Tensor): # for pixtral inputs value = torch.stack(value) # assume they have same sizes - elif isinstance(value, list) and all(isinstance(v, list) for v in value): # for minicpmv inputs + elif ( + isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor) + ): # for minicpmv inputs value = torch.stack([torch.stack(v) for v in value]) elif not isinstance(value, torch.Tensor): value = torch.tensor(value)