From 90cd3538decdd35d30bcd1c7131f2cc65ec90ad7 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Tue, 29 Oct 2024 22:00:59 +0800 Subject: [PATCH] Update hf_engine.py Former-commit-id: 6e212fdab5f48c955db250ecfc197b89f8856e4b --- src/llamafactory/chat/hf_engine.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 53fb666a..87d9c451 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -165,17 +165,12 @@ class HuggingfaceEngine(BaseEngine): ) mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor) - for key, value in mm_inputs.items(): - value = ( - value - if isinstance(value, torch.Tensor) - else ( - torch.stack(value) - if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value) - else torch.tensor(value) - ) - ) + if isinstance(value, list) and all(isinstance(v, torch.Tensor for v in value)): # for pixtral inputs + value = torch.stack(value) # assume they have same sizes + elif not isinstance(value, torch.Tensor): + value = torch.tensor(value) + gen_kwargs[key] = value.to(model.device) return gen_kwargs, prompt_length