Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 22:32:54 +08:00)
Merge pull request #5970 from hiyouga/hiyouga/fix_beam

[generation] fix vllm v0.6.3

Former-commit-id: 39e330196d8e2774ac43d6f37ccabc0a07efd970
This commit is contained in: commit 162f7028fc
@@ -115,7 +115,6 @@ class VllmEngine(BaseEngine):
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
 
-        use_beam_search: bool = self.generating_args["num_beams"] > 1
         temperature: Optional[float] = input_kwargs.pop("temperature", None)
         top_p: Optional[float] = input_kwargs.pop("top_p", None)
         top_k: Optional[float] = input_kwargs.pop("top_k", None)
@@ -126,6 +125,9 @@ class VllmEngine(BaseEngine):
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
 
+        if length_penalty is not None:
+            logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
+
         if "max_new_tokens" in self.generating_args:
             max_tokens = self.generating_args["max_new_tokens"]
         elif "max_length" in self.generating_args:
@@ -149,8 +151,6 @@ class VllmEngine(BaseEngine):
             temperature=temperature if temperature is not None else self.generating_args["temperature"],
             top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
             top_k=top_k if top_k is not None else self.generating_args["top_k"],
-            use_beam_search=use_beam_search,
-            length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
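
For context: vLLM v0.6.3 no longer accepts use_beam_search or length_penalty in SamplingParams, so the commit drops both keyword arguments and only emits a warning when a length penalty is requested. The following is a minimal, self-contained sketch of building SamplingParams against that newer API; build_sampling_params and the generating_args dict are illustrative assumptions here, not LLaMA-Factory code.

# Sketch only, assuming vllm >= 0.6.3; names below are illustrative, not the repo's API.
from typing import Any, Dict, List, Optional, Union

from vllm import SamplingParams


def build_sampling_params(generating_args: Dict[str, Any], **input_kwargs) -> SamplingParams:
    # Pop per-request overrides, falling back to the engine-level defaults.
    temperature: Optional[float] = input_kwargs.pop("temperature", None)
    top_p: Optional[float] = input_kwargs.pop("top_p", None)
    top_k: Optional[int] = input_kwargs.pop("top_k", None)
    max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
    stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
    length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)

    if length_penalty is not None:
        # SamplingParams no longer takes length_penalty, so warn and ignore it.
        print("Length penalty is not supported by the vllm engine yet.")

    max_tokens = max_new_tokens or generating_args.get("max_new_tokens", 512)
    return SamplingParams(
        temperature=temperature if temperature is not None else generating_args["temperature"],
        top_p=(top_p if top_p is not None else generating_args["top_p"]) or 1.0,  # top_p must be > 0
        top_k=top_k if top_k is not None else generating_args["top_k"],
        stop=stop,
        max_tokens=max_tokens,
    )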