From 5b61063048e690f95eabc53f20a8f0cda98c7738 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sun, 7 Jan 2024 17:14:42 +0800
Subject: [PATCH] fix api server

Former-commit-id: 08464183b9b034abdbf179d7043705a0754837e5
---
 src/llmtuner/api/app.py         | 22 ++++++++++++++++++----
 src/llmtuner/chat/chat_model.py |  1 -
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index 4ac08608..f130eab6 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -1,4 +1,6 @@
+import os
 import json
+import asyncio
 from typing import List, Tuple
 from pydantic import BaseModel
 from contextlib import asynccontextmanager
@@ -63,6 +65,8 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_headers=["*"],
     )
 
+    semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
+
     @app.get("/v1/models", response_model=ModelList)
     async def list_models():
         model_card = ModelCard(id="gpt-3.5-turbo")
@@ -93,8 +97,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         else:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
+        async with semaphore:
+            loop = asyncio.get_running_loop()
+            return await loop.run_in_executor(None, chat_completion, query, history, system, request)
+
+    def chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         if request.stream:
-            generate = predict(query, history, system, request)
+            generate = stream_chat_completion(query, history, system, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
 
         responses = chat_model.chat(
@@ -125,7 +134,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
 
-    async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+    def stream_chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(role=Role.ASSISTANT, content=""),
@@ -168,7 +177,12 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if len(request.messages) == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
-
+
+        async with semaphore:
+            loop = asyncio.get_running_loop()
+            return await loop.run_in_executor(None, get_score, request)
+
+    def get_score(request: ScoreEvaluationRequest):
         scores = chat_model.get_scores(request.messages, max_length=request.max_length)
         return ScoreEvaluationResponse(model=request.model, scores=scores)
 
     return app
@@ -178,4 +192,4 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 if __name__ == "__main__":
     chat_model = ChatModel()
     app = create_app(chat_model)
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
diff --git a/src/llmtuner/chat/chat_model.py b/src/llmtuner/chat/chat_model.py
index 5c06109c..0c2f9c92 100644
--- a/src/llmtuner/chat/chat_model.py
+++ b/src/llmtuner/chat/chat_model.py
@@ -152,7 +152,6 @@ class ChatModel:
             padding=True,
             truncation=True,
             max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
-            pad_to_multiple_of=8,
             return_tensors="pt",
             **kwargs
         ).to(device)
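
Note (not part of the patch): the pattern this commit introduces is to gate each
request with an asyncio.Semaphore, sized by the MAX_CONCURRENT environment
variable, and to run the blocking model call in the default thread-pool
executor, so the event loop stays responsive while at most N requests reach the
model at once. Below is a minimal, self-contained sketch of that pattern; the
slow_generate function and the /generate route are illustrative stand-ins, not
code from this repository.

    import os
    import time
    import asyncio

    from fastapi import FastAPI

    app = FastAPI()
    # At most MAX_CONCURRENT requests may hold the semaphore (default: 1).
    semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))

    def slow_generate(prompt: str) -> str:
        # Stand-in for a blocking call such as chat_model.chat(...).
        time.sleep(1)
        return prompt.upper()

    @app.post("/generate")
    async def generate(prompt: str):
        async with semaphore:
            # Offload the blocking work to the default thread pool; other
            # coroutines keep running while this request waits for a result.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(None, slow_generate, prompt)
        return {"result": result}

With this commit both knobs are environment-driven, so a launch can look like
the following (the usual model arguments are elided):

    MAX_CONCURRENT=4 API_PORT=8080 python src/llmtuner/api/app.py ...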