Merge pull request #3527 from zhaonx/dev

"add support for vllm api stop parameter"

Former-commit-id: e7d436403af6ac4c6a33cf36411098a0b0fefce2
hoshi-hiyouga 2024-05-07 00:37:49 +08:00 committed by GitHub
commit b9e167e6ca
3 changed files with 6 additions and 1 deletion


@@ -103,6 +103,7 @@ async def create_chat_completion_response(
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
         num_return_sequences=request.n,
+        stop=request.stop
     )
     prompt_length, response_length = 0, 0
@@ -155,6 +156,7 @@ async def create_stream_chat_completion_response(
         temperature=request.temperature,
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
+        stop=request.stop
     ):
         if len(new_token) != 0:
             yield _create_stream_chat_completion_chunk(
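
With the request's stop value threaded through both response paths above, an OpenAI-style client can set it per request. A minimal sketch of such a call, assuming the API server runs locally on port 8000 with the usual /v1/chat/completions route (URL, port, and model name are placeholders, not part of this change):

```python
# Hypothetical client call against the OpenAI-compatible chat endpoint.
# Host, port, and model name are assumptions for illustration only.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",
        "messages": [{"role": "user", "content": "Count from 1 to 10."}],
        "stop": ["6"],      # a single string is also accepted, per the schema change below
        "max_tokens": 64,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])
```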


@@ -1,6 +1,6 @@
 import time
 from enum import Enum, unique
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -78,6 +78,7 @@ class ChatCompletionRequest(BaseModel):
     n: int = 1
     max_tokens: Optional[int] = None
     stream: bool = False
+    stop: Union[Optional[str], List[str]] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
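
The new schema field follows the OpenAI convention: stop may be omitted, a single string, or a list of strings. A trimmed pydantic sketch of just that field (not the project's full request model) shows how the Union type validates each form:

```python
from typing import List, Optional, Union

from pydantic import BaseModel


class ChatCompletionRequest(BaseModel):
    stream: bool = False
    stop: Union[Optional[str], List[str]] = None


# All three forms validate:
print(ChatCompletionRequest(stop="</s>").stop)                    # '</s>'
print(ChatCompletionRequest(stop=["</s>", "Observation:"]).stop)  # ['</s>', 'Observation:']
print(ChatCompletionRequest().stop)                               # None
```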


@@ -96,6 +96,7 @@ class VllmEngine(BaseEngine):
         repetition_penalty = input_kwargs.pop("repetition_penalty", None)
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
+        stop = input_kwargs.pop("stop", None)
 
         generating_args = self.generating_args.copy()
         generating_args.update(
@@ -122,6 +123,7 @@ class VllmEngine(BaseEngine):
             top_k=generating_args["top_k"],
             use_beam_search=generating_args["num_beams"] > 1,
             length_penalty=generating_args["length_penalty"],
+            stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=generating_args["max_new_tokens"],
             skip_special_tokens=True,