mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
[inference] fix hf_engine (#7120)
Former-commit-id: 1036311826a61fed2346a261c8a060c355778318
This commit is contained in:
parent
54a090079c
commit
ee1b580328
@ -178,9 +178,11 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
|
|
||||||
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
|
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
|
||||||
for key, value in mm_inputs.items():
|
for key, value in mm_inputs.items():
|
||||||
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
|
if isinstance(value, list) and isinstance(value[0], torch.Tensor): # for pixtral inputs
|
||||||
value = torch.stack(value) # assume they have same sizes
|
value = torch.stack(value) # assume they have same sizes
|
||||||
elif isinstance(value, list) and all(isinstance(v, list) for v in value): # for minicpmv inputs
|
elif (
|
||||||
|
isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor)
|
||||||
|
): # for minicpmv inputs
|
||||||
value = torch.stack([torch.stack(v) for v in value])
|
value = torch.stack([torch.stack(v) for v in value])
|
||||||
elif not isinstance(value, torch.Tensor):
|
elif not isinstance(value, torch.Tensor):
|
||||||
value = torch.tensor(value)
|
value = torch.tensor(value)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user