mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 13:42:51 +08:00
parent
8ecf606230
commit
88745c9bb5
@ -83,6 +83,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
|
||||
prompt_length = len(prompt_ids)
|
||||
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
|
||||
|
||||
do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
|
||||
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
||||
@ -136,6 +137,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
|
||||
gen_kwargs = dict(
|
||||
inputs=inputs,
|
||||
attention_mask=attention_mask,
|
||||
generation_config=GenerationConfig(**generating_args),
|
||||
logits_processor=get_logits_processor(),
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user