Update hf_engine.py

Former-commit-id: 6e212fdab5f48c955db250ecfc197b89f8856e4b
This commit is contained in:
hoshi-hiyouga 2024-10-29 22:00:59 +08:00 committed by GitHub
parent eca50b89a2
commit 90cd3538de

View File

@ -165,17 +165,12 @@ class HuggingfaceEngine(BaseEngine):
) )
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor) mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
for key, value in mm_inputs.items(): for key, value in mm_inputs.items():
value = ( if isinstance(value, list) and all(isinstance(v, torch.Tensor for v in value)): # for pixtral inputs
value value = torch.stack(value) # assume they have same sizes
if isinstance(value, torch.Tensor) elif not isinstance(value, torch.Tensor):
else ( value = torch.tensor(value)
torch.stack(value)
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value)
else torch.tensor(value)
)
)
gen_kwargs[key] = value.to(model.device) gen_kwargs[key] = value.to(model.device)
return gen_kwargs, prompt_length return gen_kwargs, prompt_length