From 88745c9bb5a3b553c58e8d640f22695c97db8c50 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Tue, 4 Jun 2024 00:21:50 +0800 Subject: [PATCH] fix #3873 Former-commit-id: 91611d68c47dada2b74a141a5842dd289e46d356 --- src/llamafactory/chat/hf_engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index ad0e90fe..28e6a409 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -83,6 +83,7 @@ class HuggingfaceEngine(BaseEngine): prompt_length = len(prompt_ids) inputs = torch.tensor([prompt_ids], device=model.device) + attention_mask = torch.ones_like(inputs, dtype=torch.bool) do_sample: Optional[bool] = input_kwargs.pop("do_sample", None) temperature: Optional[float] = input_kwargs.pop("temperature", None) @@ -136,6 +137,7 @@ class HuggingfaceEngine(BaseEngine): gen_kwargs = dict( inputs=inputs, + attention_mask=attention_mask, generation_config=GenerationConfig(**generating_args), logits_processor=get_logits_processor(), )