Merge pull request #3527 from zhaonx/dev

"add support for vllm api stop parameter"

Former-commit-id: bcf7ec5ceb13920786831166861f18edd2506bb1
hoshi-hiyouga 2024-05-07 00:37:49 +08:00 committed by GitHub
commit c198db4db2
3 changed files with 6 additions and 1 deletion


@@ -103,6 +103,7 @@ async def create_chat_completion_response(
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
         num_return_sequences=request.n,
+        stop=request.stop
     )
     prompt_length, response_length = 0, 0
@@ -155,6 +156,7 @@ async def create_stream_chat_completion_response(
         temperature=request.temperature,
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
+        stop=request.stop
     ):
         if len(new_token) != 0:
             yield _create_stream_chat_completion_chunk(
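
As a usage sketch (not part of the diff), a client could now pass "stop" through the OpenAI-compatible chat endpoint; the URL, port, and model name below are placeholders, not values taken from this change:

import requests

# Hypothetical endpoint and model name; adjust to the actual deployment.
payload = {
    "model": "my-model",
    "messages": [{"role": "user", "content": "Count from 1 to 10."}],
    "stop": ["\n\n", "10"],  # generation halts at the first matching string
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])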


@@ -1,6 +1,6 @@
 import time
 from enum import Enum, unique
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -78,6 +78,7 @@ class ChatCompletionRequest(BaseModel):
     n: int = 1
     max_tokens: Optional[int] = None
     stream: bool = False
+    stop: Union[Optional[str], List[str]] = None


 class ChatCompletionResponseChoice(BaseModel):
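
A minimal check of the new field, using a trimmed stand-in for the ChatCompletionRequest model above; pydantic accepts either a single string or a list of strings:

from typing import List, Optional, Union
from pydantic import BaseModel

class ChatCompletionRequest(BaseModel):  # only the fields relevant to this change
    n: int = 1
    max_tokens: Optional[int] = None
    stream: bool = False
    stop: Union[Optional[str], List[str]] = None

print(ChatCompletionRequest(stop="###").stop)             # '###'
print(ChatCompletionRequest(stop=["###", "\n\n"]).stop)   # ['###', '\n\n']
print(ChatCompletionRequest().stop)                       # None (default)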


@@ -96,6 +96,7 @@ class VllmEngine(BaseEngine):
         repetition_penalty = input_kwargs.pop("repetition_penalty", None)
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
+        stop = input_kwargs.pop("stop", None)

         generating_args = self.generating_args.copy()
         generating_args.update(
@@ -122,6 +123,7 @@ class VllmEngine(BaseEngine):
             top_k=generating_args["top_k"],
             use_beam_search=generating_args["num_beams"] > 1,
             length_penalty=generating_args["length_penalty"],
+            stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=generating_args["max_new_tokens"],
             skip_special_tokens=True,
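
Downstream, vLLM accepts stop strings directly on SamplingParams, which is what the forwarded value above relies on. A standalone sketch with illustrative model and decoding values (not taken from the diff):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any vLLM-supported model works here
params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=64,
    stop=["\n\n"],  # output is cut off once this string would be generated
)
outputs = llm.generate(["Write a haiku about spring."], params)
print(outputs[0].outputs[0].text)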