From 175a7ea951dccdc49851d77fb73da67676f72172 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 7 May 2024 00:41:04 +0800
Subject: [PATCH] fix stop param

Former-commit-id: 09f3ef1de49f97001faa91ef3dc2bd16790f9717
---
 data/dataset_info.json         | 2 +-
 src/llmtuner/api/chat.py       | 7 +++++--
 src/llmtuner/api/protocol.py   | 2 +-
 src/llmtuner/chat/hf_engine.py | 4 ++++
 4 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/data/dataset_info.json b/data/dataset_info.json
index 4848abd5..c5d0c693 100644
--- a/data/dataset_info.json
+++ b/data/dataset_info.json
@@ -459,4 +459,4 @@
     },
     "folder": "python"
   }
-}
+}
\ No newline at end of file
diff --git a/src/llmtuner/api/chat.py b/src/llmtuner/api/chat.py
index 972ee906..2a703877 100644
--- a/src/llmtuner/api/chat.py
+++ b/src/llmtuner/api/chat.py
@@ -103,7 +103,7 @@ async def create_chat_completion_response(
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
         num_return_sequences=request.n,
-        stop=request.stop
+        stop=request.stop,
     )
 
     prompt_length, response_length = 0, 0
@@ -145,6 +145,9 @@ async def create_stream_chat_completion_response(
     if tools:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
 
+    if request.n > 1:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
+
     yield _create_stream_chat_completion_chunk(
         completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
     )
@@ -156,7 +159,7 @@ async def create_stream_chat_completion_response(
         temperature=request.temperature,
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
-        stop=request.stop
+        stop=request.stop,
     ):
         if len(new_token) != 0:
             yield _create_stream_chat_completion_chunk(
diff --git a/src/llmtuner/api/protocol.py b/src/llmtuner/api/protocol.py
index f526c813..525fa6a7 100644
--- a/src/llmtuner/api/protocol.py
+++ b/src/llmtuner/api/protocol.py
@@ -77,8 +77,8 @@ class ChatCompletionRequest(BaseModel):
     top_p: Optional[float] = None
     n: int = 1
     max_tokens: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = None
     stream: bool = False
-    stop: Union[Optional[str], List[str]] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
diff --git a/src/llmtuner/chat/hf_engine.py b/src/llmtuner/chat/hf_engine.py
index e8f06a73..97160d57 100644
--- a/src/llmtuner/chat/hf_engine.py
+++ b/src/llmtuner/chat/hf_engine.py
@@ -73,6 +73,10 @@ class HuggingfaceEngine(BaseEngine):
         repetition_penalty = input_kwargs.pop("repetition_penalty", None)
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
+        stop = input_kwargs.pop("stop", None)
+
+        if stop is not None:
+            raise ValueError("Stop parameter is not supported in Huggingface engine yet.")
 
         generating_args.update(
             dict(
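
For API clients, this patch means: `ChatCompletionRequest.stop` now accepts either a single string or a list of strings, streaming requests with `n > 1` are rejected with HTTP 400, and the HuggingFace engine explicitly raises a ValueError when `stop` is passed, since it does not support it yet. Below is a minimal client sketch against the OpenAI-style chat completions endpoint; the host, port, and model name are assumptions for illustration, not part of the patch.

import requests

# Hypothetical server address and model name -- adjust for your deployment.
API_URL = "http://localhost:8000/v1/chat/completions"

payload = {
    "model": "test",
    "messages": [{"role": "user", "content": "Count from 1 to 10."}],
    "stop": ["\n\n", "10"],  # a bare string such as "10" is also valid
    "n": 1,                  # n > 1 combined with "stream": true now returns HTTP 400
    "stream": False,
}

resp = requests.post(API_URL, json=payload)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

Note that with the HuggingFace backend, a request carrying `stop` fails with "Stop parameter is not supported in Huggingface engine yet.", so as of this patch the parameter is only meaningful for engines that implement it.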