mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 13:42:51 +08:00
parent
8ecf606230
commit
88745c9bb5
@ -83,6 +83,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
|
|
||||||
prompt_length = len(prompt_ids)
|
prompt_length = len(prompt_ids)
|
||||||
inputs = torch.tensor([prompt_ids], device=model.device)
|
inputs = torch.tensor([prompt_ids], device=model.device)
|
||||||
|
attention_mask = torch.ones_like(inputs, dtype=torch.bool)
|
||||||
|
|
||||||
do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
|
do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
|
||||||
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
temperature: Optional[float] = input_kwargs.pop("temperature", None)
|
||||||
@ -136,6 +137,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
|
|
||||||
gen_kwargs = dict(
|
gen_kwargs = dict(
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
generation_config=GenerationConfig(**generating_args),
|
generation_config=GenerationConfig(**generating_args),
|
||||||
logits_processor=get_logits_processor(),
|
logits_processor=get_logits_processor(),
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user