Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-23 14:22:51 +08:00
Merge pull request #3527 from zhaonx/dev: "add support for vllm api stop parameter"
Former-commit-id: bcf7ec5ceb13920786831166861f18edd2506bb1
Commit: c198db4db2
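This change lets callers of the OpenAI-compatible chat endpoint pass stop strings through to the backend. A minimal sketch of client-side usage, assuming the API server listens on http://localhost:8000 and exposes /v1/chat/completions; the model name and request values are illustrative, not taken from this commit:

import requests

# Illustrative request: "stop" may be a single string or a list of strings.
payload = {
    "model": "llama3",  # hypothetical model name
    "messages": [{"role": "user", "content": "List three colors."}],
    "max_tokens": 64,
    "stop": ["\n\n", "4."],  # generation halts when either string is produced
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])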
@@ -103,6 +103,7 @@ async def create_chat_completion_response(
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
         num_return_sequences=request.n,
+        stop=request.stop
     )
 
     prompt_length, response_length = 0, 0
@@ -155,6 +156,7 @@ async def create_stream_chat_completion_response(
         temperature=request.temperature,
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
+        stop=request.stop
     ):
         if len(new_token) != 0:
             yield _create_stream_chat_completion_chunk(
@@ -1,6 +1,6 @@
 import time
 from enum import Enum, unique
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -78,6 +78,7 @@ class ChatCompletionRequest(BaseModel):
     n: int = 1
     max_tokens: Optional[int] = None
     stream: bool = False
+    stop: Union[Optional[str], List[str]] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
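Because the field is annotated with Union, the request model accepts either one stop string or a list of them. A quick sketch (not part of the diff) of how pydantic handles that annotation; the class below is a stripped-down stand-in for the real ChatCompletionRequest:

from typing import List, Optional, Union

from pydantic import BaseModel


class StopOnlyRequest(BaseModel):  # illustrative stand-in, not the real model
    stop: Union[Optional[str], List[str]] = None  # same annotation as in the diff


print(StopOnlyRequest(stop="###").stop)            # '###'
print(StopOnlyRequest(stop=["###", "\n\n"]).stop)  # ['###', '\n\n']
print(StopOnlyRequest().stop)                      # None (field is optional)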
@@ -96,6 +96,7 @@ class VllmEngine(BaseEngine):
         repetition_penalty = input_kwargs.pop("repetition_penalty", None)
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
+        stop = input_kwargs.pop("stop", None)
 
         generating_args = self.generating_args.copy()
         generating_args.update(
@@ -122,6 +123,7 @@ class VllmEngine(BaseEngine):
             top_k=generating_args["top_k"],
             use_beam_search=generating_args["num_beams"] > 1,
             length_penalty=generating_args["length_penalty"],
+            stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=generating_args["max_new_tokens"],
             skip_special_tokens=True,
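Inside the engine, the popped value is forwarded to vLLM's SamplingParams, which accepts stop as a string or a list of strings and ends decoding once any of them appears in the output. A minimal sketch of that call, assuming vLLM is installed; the concrete numbers and stop strings are illustrative:

from vllm import SamplingParams

stop = ["###", "\n\nUser:"]  # value popped from the request kwargs above

# Only the parameters relevant to stopping behavior are shown here.
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=128,
    stop=stop,            # stop strings passed through from the API request
    stop_token_ids=[2],   # illustrative EOS token id
    skip_special_tokens=True,
)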