diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index ad0e90fe..28e6a409 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -83,6 +83,7 @@ class HuggingfaceEngine(BaseEngine): prompt_length = len(prompt_ids) inputs = torch.tensor([prompt_ids], device=model.device) + attention_mask = torch.ones_like(inputs, dtype=torch.bool) do_sample: Optional[bool] = input_kwargs.pop("do_sample", None) temperature: Optional[float] = input_kwargs.pop("temperature", None) @@ -136,6 +137,7 @@ class HuggingfaceEngine(BaseEngine): gen_kwargs = dict( inputs=inputs, + attention_mask=attention_mask, generation_config=GenerationConfig(**generating_args), logits_processor=get_logits_processor(), )