mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
Update hf_engine.py
Former-commit-id: 6e212fdab5f48c955db250ecfc197b89f8856e4b
This commit is contained in:
parent
eca50b89a2
commit
90cd3538de
@ -165,17 +165,12 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
)
|
)
|
||||||
|
|
||||||
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
|
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
|
||||||
|
|
||||||
for key, value in mm_inputs.items():
|
for key, value in mm_inputs.items():
|
||||||
value = (
|
if isinstance(value, list) and all(isinstance(v, torch.Tensor for v in value)): # for pixtral inputs
|
||||||
value
|
value = torch.stack(value) # assume they have same sizes
|
||||||
if isinstance(value, torch.Tensor)
|
elif not isinstance(value, torch.Tensor):
|
||||||
else (
|
value = torch.tensor(value)
|
||||||
torch.stack(value)
|
|
||||||
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value)
|
|
||||||
else torch.tensor(value)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
gen_kwargs[key] = value.to(model.device)
|
gen_kwargs[key] = value.to(model.device)
|
||||||
|
|
||||||
return gen_kwargs, prompt_length
|
return gen_kwargs, prompt_length
|
||||||
|
Loading…
x
Reference in New Issue
Block a user