diff --git a/src/llmtuner/api/chat.py b/src/llmtuner/api/chat.py
index fa2f0d03..972ee906 100644
--- a/src/llmtuner/api/chat.py
+++ b/src/llmtuner/api/chat.py
@@ -103,6 +103,7 @@ async def create_chat_completion_response(
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
         num_return_sequences=request.n,
+        stop=request.stop
     )
 
     prompt_length, response_length = 0, 0
@@ -155,6 +156,7 @@ async def create_stream_chat_completion_response(
         temperature=request.temperature,
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
+        stop=request.stop
     ):
         if len(new_token) != 0:
             yield _create_stream_chat_completion_chunk(
diff --git a/src/llmtuner/api/protocol.py b/src/llmtuner/api/protocol.py
index ae6e2e9b..f526c813 100644
--- a/src/llmtuner/api/protocol.py
+++ b/src/llmtuner/api/protocol.py
@@ -1,6 +1,6 @@
 import time
 from enum import Enum, unique
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -78,6 +78,7 @@ class ChatCompletionRequest(BaseModel):
     n: int = 1
     max_tokens: Optional[int] = None
     stream: bool = False
+    stop: Union[Optional[str], List[str]] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
diff --git a/src/llmtuner/chat/vllm_engine.py b/src/llmtuner/chat/vllm_engine.py
index 0f0dc366..d50e41aa 100644
--- a/src/llmtuner/chat/vllm_engine.py
+++ b/src/llmtuner/chat/vllm_engine.py
@@ -96,6 +96,7 @@ class VllmEngine(BaseEngine):
         repetition_penalty = input_kwargs.pop("repetition_penalty", None)
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
+        stop = input_kwargs.pop("stop", None)
 
         generating_args = self.generating_args.copy()
         generating_args.update(
@@ -122,6 +123,7 @@ class VllmEngine(BaseEngine):
             top_k=generating_args["top_k"],
             use_beam_search=generating_args["num_beams"] > 1,
             length_penalty=generating_args["length_penalty"],
+            stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=generating_args["max_new_tokens"],
             skip_special_tokens=True,