Former-commit-id: 91611d68c47dada2b74a141a5842dd289e46d356
This commit is contained in:
hiyouga 2024-06-04 00:21:50 +08:00
parent 8ecf606230
commit 88745c9bb5

View File

@ -83,6 +83,7 @@ class HuggingfaceEngine(BaseEngine):
prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
temperature: Optional[float] = input_kwargs.pop("temperature", None)
@ -136,6 +137,7 @@ class HuggingfaceEngine(BaseEngine):
gen_kwargs = dict(
inputs=inputs,
attention_mask=attention_mask,
generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor(),
)