From 97f4451912b05a73b0c30ce8a4112196524305f2 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Fri, 8 Nov 2024 23:49:16 +0800
Subject: [PATCH] fix #5966

Former-commit-id: 8f3a32286ebcfb3234e3981db2292dd165b1568d
---
 src/llamafactory/chat/vllm_engine.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 5f6612be..21f09a58 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -115,7 +115,6 @@ class VllmEngine(BaseEngine):
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
 
-        use_beam_search: bool = self.generating_args["num_beams"] > 1
         temperature: Optional[float] = input_kwargs.pop("temperature", None)
         top_p: Optional[float] = input_kwargs.pop("top_p", None)
         top_k: Optional[float] = input_kwargs.pop("top_k", None)
@@ -126,6 +125,9 @@ class VllmEngine(BaseEngine):
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
 
+        if length_penalty is not None:
+            logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
+
         if "max_new_tokens" in self.generating_args:
             max_tokens = self.generating_args["max_new_tokens"]
         elif "max_length" in self.generating_args:
@@ -149,8 +151,6 @@ class VllmEngine(BaseEngine):
             temperature=temperature if temperature is not None else self.generating_args["temperature"],
             top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
             top_k=top_k if top_k is not None else self.generating_args["top_k"],
-            use_beam_search=use_beam_search,
-            length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
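
Note: the patch drops `use_beam_search` and `length_penalty` from the `SamplingParams` call, since newer vLLM releases no longer accept them there, and instead emits a rank-0 warning when a length penalty is requested. Below is a minimal, self-contained sketch of the resulting behavior; `build_sampling_params`, its `generating_args`/`input_kwargs` dictionaries, and the plain `warnings.warn` call are illustrative stand-ins for the engine's internal state and its rank-0 logger, not the project's actual API.

# Illustrative sketch only: names here are hypothetical stand-ins, not LLaMA-Factory API.
import warnings
from typing import Any, Dict, Optional

from vllm import SamplingParams  # assumes a vLLM release without beam-search args in SamplingParams


def build_sampling_params(generating_args: Dict[str, Any], input_kwargs: Dict[str, Any]) -> SamplingParams:
    temperature: Optional[float] = input_kwargs.pop("temperature", None)
    top_p: Optional[float] = input_kwargs.pop("top_p", None)
    top_k: Optional[int] = input_kwargs.pop("top_k", None)
    length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
    max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)

    # Mirrors the patch: warn instead of forwarding an argument vLLM no longer supports.
    if length_penalty is not None:
        warnings.warn("Length penalty is not supported by the vllm engine yet.")

    max_tokens = max_new_tokens or generating_args.get("max_new_tokens") or generating_args.get("max_length", 1024)

    # `use_beam_search` and `length_penalty` are intentionally not passed here.
    return SamplingParams(
        temperature=temperature if temperature is not None else generating_args["temperature"],
        top_p=(top_p if top_p is not None else generating_args["top_p"]) or 1.0,  # top_p must be > 0
        top_k=top_k if top_k is not None else generating_args["top_k"],
        stop=input_kwargs.pop("stop", None),
        max_tokens=max_tokens,
    )


# Example usage with hypothetical defaults: the length_penalty entry triggers the
# warning and is otherwise ignored, matching the patched engine's behavior.
params = build_sampling_params(
    {"temperature": 0.7, "top_p": 0.9, "top_k": 50, "max_new_tokens": 256},
    {"length_penalty": 1.2},
)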