From 6672ad7a83a27d72e1684820a09ea33bca98533c Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sat, 4 May 2024 16:11:18 +0800
Subject: [PATCH] fix async stream api response

Former-commit-id: 941924fdbd69c2529677564af61f9019086ef21f
---
 src/llmtuner/api/chat.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/llmtuner/api/chat.py b/src/llmtuner/api/chat.py
index c9c00f16..716dec56 100644
--- a/src/llmtuner/api/chat.py
+++ b/src/llmtuner/api/chat.py
@@ -38,7 +38,7 @@ ROLE_MAPPING = {
 }
 
 
-async def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
+def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
 
@@ -77,11 +77,23 @@ async def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[
     return input_messages, system, tools
 
 
+def _create_stream_chat_completion_chunk(
+    completion_id: str,
+    model: str,
+    delta: "ChatCompletionMessage",
+    index: Optional[int] = 0,
+    finish_reason: Optional["Finish"] = None,
+) -> str:
+    choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
+    chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
+    return jsonify(chunk)
+
+
 async def create_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
     completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools = await _process_request(request)
+    input_messages, system, tools = _process_request(request)
     responses = await chat_model.achat(
         input_messages,
         system,
@@ -124,23 +136,11 @@ async def create_chat_completion_response(
     return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)
 
 
-async def _create_stream_chat_completion_chunk(
-    completion_id: str,
-    model: str,
-    delta: "ChatCompletionMessage",
-    index: Optional[int] = 0,
-    finish_reason: Optional["Finish"] = None,
-) -> str:
-    choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
-    chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
-    return jsonify(chunk)
-
-
 async def create_stream_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
     completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools = await _process_request(request)
+    input_messages, system, tools = _process_request(request)
 
     if tools:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
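
Note (not part of the patch): the change demotes two helpers from "async def"
to plain "def". Neither helper awaits anything, so making them synchronous is
safe, and it matters in the streaming path: if a helper called from inside an
async generator is declared "async def" and its result is yielded without an
await, the generator yields a coroutine object instead of the SSE string,
corrupting the stream API response. Below is a minimal runnable sketch of the
correct (synchronous-helper) pattern; make_chunk and stream_chunks are
hypothetical stand-ins for _create_stream_chat_completion_chunk and
create_stream_chat_completion_response, and the SSE payload format is assumed
for illustration only.

    import asyncio


    def make_chunk(text: str) -> str:
        # Synchronous helper, mirroring the patched version: it returns the
        # serialized chunk string directly, so callers need no await.
        return 'data: {"delta": "' + text + '"}\n\n'


    async def stream_chunks():
        # Async generator, like the streaming endpoint. Because make_chunk is
        # a plain function, each yield produces a str. If make_chunk were
        # "async def", this would yield coroutine objects instead (the bug
        # this patch fixes), unless every call were explicitly awaited.
        for token in ("Hello", " world"):
            yield make_chunk(token)


    async def main():
        async for chunk in stream_chunks():
            print(chunk, end="")


    asyncio.run(main())

The same reasoning applies to _process_request: it only validates and reshapes
the request, so the "await" at its call sites was unnecessary and is dropped
along with the "async" keyword.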