[inference] fix hf_engine (#7120)

Former-commit-id: f8cf5319cb5d6e06a1b0d8b8db2b678627f2271e
This commit is contained in:
hoshi-hiyouga 2025-03-01 05:22:49 +08:00 committed by GitHub
parent e62dae37fe
commit 585c475f71

View File

@@ -178,9 +178,11 @@ class HuggingfaceEngine(BaseEngine):
 mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
 for key, value in mm_inputs.items():
-    if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
+    if isinstance(value, list) and isinstance(value[0], torch.Tensor): # for pixtral inputs
         value = torch.stack(value) # assume they have same sizes
-    elif isinstance(value, list) and all(isinstance(v, list) for v in value): # for minicpmv inputs
+    elif (
+        isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor)
+    ): # for minicpmv inputs
         value = torch.stack([torch.stack(v) for v in value])
     elif not isinstance(value, torch.Tensor):
         value = torch.tensor(value)