"add support for vllm api stop parameter"

Former-commit-id: 42edc81585bc7170ab4e4871ad12094079e89bc9
zhaonx 2024-04-30 17:17:09 +08:00
parent 35917001b1
commit 4a0aab86f1
4 changed files with 12 additions and 3 deletions

src/llmtuner/api/app.py

@@ -141,6 +141,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
         num_return_sequences=request.n,
+        stop=request.stop
     )

     prompt_length, response_length = 0, 0
@@ -193,6 +194,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         temperature=request.temperature,
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
+        stop=request.stop
     ):
         if len(new_token) == 0:
             continue
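
With the two hunks above, both the blocking and the streaming chat endpoints forward request.stop into the generation call. A minimal client-side sketch, not part of the commit; host, port, and model name are assumptions:

import requests

# Assumed host/port/model; adjust to the running API server.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "test",
        "messages": [{"role": "user", "content": "Count from 1 to 10."}],
        "stop": ["5"],  # a single string or a list of strings, per the new field
    },
)
print(resp.json()["choices"][0]["message"]["content"])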

src/llmtuner/api/protocol.py

@@ -1,6 +1,6 @@
 import time
 from enum import Enum, unique
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -78,6 +78,7 @@ class ChatCompletionRequest(BaseModel):
     n: int = 1
     max_tokens: Optional[int] = None
     stream: bool = False
+    stop: Union[Optional[str], List[str]] = None


 class ChatCompletionResponseChoice(BaseModel):
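
The annotation Union[Optional[str], List[str]] flattens to Optional[Union[str, List[str]]], so the field accepts a single string, a list of strings, or None. A self-contained sketch mirroring the annotation (the model below is an illustrative stand-in, not the real ChatCompletionRequest):

from typing import List, Optional, Union
from pydantic import BaseModel

class StopDemo(BaseModel):
    # Mirrors the new "stop" annotation on ChatCompletionRequest.
    stop: Union[Optional[str], List[str]] = None

print(StopDemo().stop)                      # None (default)
print(StopDemo(stop="###").stop)            # '###'
print(StopDemo(stop=["###", "\n\n"]).stop)  # ['###', '\n\n']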

src/llmtuner/chat/vllm_engine.py

@@ -96,6 +96,7 @@ class VllmEngine(BaseEngine):
         repetition_penalty = input_kwargs.pop("repetition_penalty", None)
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
+        stop = input_kwargs.pop("stop", None)

         generating_args = self.generating_args.copy()
         generating_args.update(
@@ -105,6 +106,7 @@ class VllmEngine(BaseEngine):
                 top_k=top_k or generating_args["top_k"],
                 num_return_sequences=num_return_sequences or 1,
                 repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
+                stop=stop or generating_args["stop"]
             )
         )
@@ -125,6 +127,7 @@ class VllmEngine(BaseEngine):
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=generating_args["max_new_tokens"],
             skip_special_tokens=True,
+            stop=generating_args["stop"],
         )

         if self.processor is not None and image is not None:
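
Downstream, the value lands in vLLM's SamplingParams, whose stop parameter takes a string or a list of strings and truncates the output before the first match. A sketch of that call with illustrative values, assuming vllm is installed:

from vllm import SamplingParams

# Plays the same role as generating_args["stop"] in the hunk above;
# the concrete stop strings here are examples only.
params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=128,
    stop=["</s>", "\n\nUser:"],
)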

src/llmtuner/hparams/generating_args.py

@@ -1,5 +1,5 @@
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict
+from typing import Any, Dict, Union, Optional, List


 @dataclass
@@ -46,7 +46,10 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
+    stop: Union[Optional[str], List[str]] = field(
+        default=None,
+        metadata={"help": "List of strings or string that stop the generation when they are generated. The returned output will not contain the stop strings."},
+    )

     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)
         if args.get("max_new_tokens", -1) > 0:
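
Because stop is a regular dataclass field, it flows through asdict() in to_dict(), so the generating_args["stop"] lookups in VllmEngine resolve (to None unless configured). A trimmed stand-in demonstrating the default:

from dataclasses import asdict, dataclass, field
from typing import List, Optional, Union

@dataclass
class GeneratingArgsDemo:
    # Illustrative stand-in for GeneratingArguments, keeping only the new field.
    stop: Union[Optional[str], List[str]] = field(default=None)

print(asdict(GeneratingArgsDemo()))              # {'stop': None}
print(asdict(GeneratingArgsDemo(stop=["###"])))  # {'stop': ['###']}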