Merge pull request #5970 from hiyouga/hiyouga/fix_beam

[generation] fix vllm v0.6.3

Former-commit-id: 39e330196d8e2774ac43d6f37ccabc0a07efd970
hoshi-hiyouga 2024-11-08 23:58:15 +08:00 committed by GitHub
commit 162f7028fc


@@ -115,7 +115,6 @@ class VllmEngine(BaseEngine):
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
 
-        use_beam_search: bool = self.generating_args["num_beams"] > 1
         temperature: Optional[float] = input_kwargs.pop("temperature", None)
         top_p: Optional[float] = input_kwargs.pop("top_p", None)
         top_k: Optional[float] = input_kwargs.pop("top_k", None)
@@ -126,6 +125,9 @@ class VllmEngine(BaseEngine):
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
 
+        if length_penalty is not None:
+            logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
+
         if "max_new_tokens" in self.generating_args:
             max_tokens = self.generating_args["max_new_tokens"]
         elif "max_length" in self.generating_args:
@@ -149,8 +151,6 @@ class VllmEngine(BaseEngine):
             temperature=temperature if temperature is not None else self.generating_args["temperature"],
             top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
             top_k=top_k if top_k is not None else self.generating_args["top_k"],
-            use_beam_search=use_beam_search,
-            length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
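
The net effect of the diff: vLLM v0.6.3 dropped the use_beam_search and length_penalty arguments from SamplingParams, so the engine stops forwarding them and instead logs a warning when a length penalty is requested. Below is a minimal sketch, not the repository's exact code, of assembling sampling parameters in a way that is compatible with vLLM >= 0.6.3; the helper name build_sampling_params, its signature, and the fallback defaults are illustrative assumptions.

# Illustrative sketch only: builds vLLM SamplingParams without the removed
# use_beam_search / length_penalty fields. Helper name, signature, and the
# fallback defaults are assumptions, not the project's exact implementation.
from typing import Any, Dict, List, Optional, Union

from vllm import SamplingParams


def build_sampling_params(
    generating_args: Dict[str, Any],
    eos_token_id: int,
    extra_stop_token_ids: Optional[List[int]] = None,
    **input_kwargs,
) -> SamplingParams:
    """Merge per-request overrides with default generating arguments."""
    temperature: Optional[float] = input_kwargs.pop("temperature", None)
    top_p: Optional[float] = input_kwargs.pop("top_p", None)
    top_k: Optional[int] = input_kwargs.pop("top_k", None)
    max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
    stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

    return SamplingParams(
        temperature=temperature if temperature is not None else generating_args.get("temperature", 1.0),
        # top_p must be strictly positive, so fall back to 1.0 when it resolves to 0
        top_p=(top_p if top_p is not None else generating_args.get("top_p", 1.0)) or 1.0,
        top_k=top_k if top_k is not None else generating_args.get("top_k", -1),
        max_tokens=max_new_tokens if max_new_tokens is not None else generating_args.get("max_new_tokens", 512),
        stop=stop,
        stop_token_ids=[eos_token_id] + (extra_stop_token_ids or []),
        skip_special_tokens=True,
    )

Since beam search is no longer reachable through SamplingParams in these vLLM releases, the use_beam_search flag is simply dropped rather than remapped.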