fix api server

hiyouga
2024-01-07 17:14:42 +08:00
parent d2a676c8ba
commit 08464183b9
2 changed files with 18 additions and 5 deletions


@@ -1,4 +1,6 @@
+import os
 import json
+import asyncio
 from typing import List, Tuple
 from pydantic import BaseModel
 from contextlib import asynccontextmanager
@@ -63,6 +65,8 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_headers=["*"],
     )
 
+    semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
+
     @app.get("/v1/models", response_model=ModelList)
     async def list_models():
         model_card = ModelCard(id="gpt-3.5-turbo")
@@ -93,8 +97,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         else:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
+        async with semaphore:
+            loop = asyncio.get_running_loop()
+            return await loop.run_in_executor(None, chat_completion, query, history, system, request)
+
+    def chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         if request.stream:
-            generate = predict(query, history, system, request)
+            generate = stream_chat_completion(query, history, system, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
 
         responses = chat_model.chat(
@@ -125,7 +134,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
 
-    async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+    def stream_chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(role=Role.ASSISTANT, content=""),
@@ -169,6 +178,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if len(request.messages) == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
+        async with semaphore:
+            loop = asyncio.get_running_loop()
+            return await loop.run_in_executor(None, get_score, request)
+
+    def get_score(request: ScoreEvaluationRequest):
         scores = chat_model.get_scores(request.messages, max_length=request.max_length)
         return ScoreEvaluationResponse(model=request.model, scores=scores)
@@ -178,4 +192,4 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 if __name__ == "__main__":
     chat_model = ChatModel()
     app = create_app(chat_model)
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
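
Both new handler bodies in this file use the same pattern: the blocking model call is moved off the event loop with run_in_executor, and the module-level asyncio.Semaphore (sized by the MAX_CONCURRENT environment variable, default 1) caps how many of those calls run at once, so the event loop stays free to accept new connections. Below is a minimal, self-contained sketch of that pattern; slow_generate and the timings are illustrative stand-ins, not code from this commit.

```python
import asyncio
import os
import time

# Same default as the diff: at most one generation at a time.
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))

def slow_generate(prompt: str) -> str:
    # Stand-in for a blocking call such as chat_model.chat(...).
    time.sleep(1.0)
    return f"echo: {prompt}"

async def handle_request(prompt: str) -> str:
    async with semaphore:  # excess requests queue here instead of all running
        loop = asyncio.get_running_loop()
        # Run the blocking function in the default thread pool executor.
        return await loop.run_in_executor(None, slow_generate, prompt)

async def main() -> None:
    # With MAX_CONCURRENT=1 these four requests complete one after another.
    print(await asyncio.gather(*(handle_request(f"req {i}") for i in range(4))))

if __name__ == "__main__":
    asyncio.run(main())
```

The API_PORT change in the last hunk applies the same idea to configuration: the listen port now comes from the environment, falling back to 8000.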

Second changed file (the ChatModel class):

@@ -152,7 +152,6 @@ class ChatModel:
             padding=True,
             truncation=True,
             max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
-            pad_to_multiple_of=8,
             return_tensors="pt",
             **kwargs
         ).to(device)
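
In the second file, dropping pad_to_multiple_of=8 means the batch produced by ChatModel's tokenizer call is padded only to its longest sequence instead of up to the next multiple of eight. A small sketch of the difference, using a Hugging Face tokenizer (gpt2 is an arbitrary example model, not taken from this commit):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # arbitrary example model
tokenizer.pad_token = tokenizer.eos_token  # gpt2 defines no pad token

texts = ["one two three", "one"]
# After this commit: pad only to the longest sequence in the batch.
plain = tokenizer(texts, padding=True, return_tensors="pt")
# Before this commit: pad the batch out to a multiple of 8 tokens.
multiple = tokenizer(texts, padding=True, pad_to_multiple_of=8, return_tensors="pt")
print(plain["input_ids"].shape)     # torch.Size([2, 3])
print(multiple["input_ids"].shape)  # torch.Size([2, 8])
```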