From bf0286e1e3d97c14c1062eab9260f202c151ee46 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Tue, 6 May 2025 15:39:13 +0200 Subject: [PATCH] [misc] fix qwen2 omni (#7962) --- src/llamafactory/chat/hf_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index b2c03c60..5ed47886 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine): ) prompt_length = len(prompt_ids) inputs = torch.tensor([prompt_ids], device=model.device) - attention_mask = torch.ones_like(inputs, dtype=torch.bool) + attention_mask = torch.ones_like(inputs, dtype=torch.long) do_sample: Optional[bool] = input_kwargs.pop("do_sample", None) temperature: Optional[float] = input_kwargs.pop("temperature", None)