fix api server

Former-commit-id: 08464183b9b034abdbf179d7043705a0754837e5
hiyouga 2024-01-07 17:14:42 +08:00
parent 6bbcf5ad16
commit 5b61063048
2 changed files with 18 additions and 5 deletions

View File

@@ -1,4 +1,6 @@
+import os
 import json
+import asyncio
 from typing import List, Tuple
 from pydantic import BaseModel
 from contextlib import asynccontextmanager
@@ -63,6 +65,8 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_headers=["*"],
     )
 
+    semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
+
     @app.get("/v1/models", response_model=ModelList)
     async def list_models():
         model_card = ModelCard(id="gpt-3.5-turbo")
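The new semaphore above bounds how many requests may touch the model at once, sized by the MAX_CONCURRENT environment variable. A minimal sketch of the idea outside FastAPI (the guarded coroutine and job ids are illustrative, not from the repo):

import asyncio
import os

semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))

async def guarded(job_id: int) -> int:
    # Only MAX_CONCURRENT coroutines enter this block at a time;
    # the rest queue on the semaphore instead of hitting the model together.
    async with semaphore:
        await asyncio.sleep(0.1)  # stand-in for model work
        return job_id

async def main():
    results = await asyncio.gather(*(guarded(i) for i in range(4)))
    print(results)  # [0, 1, 2, 3], run at most MAX_CONCURRENT at a time

asyncio.run(main())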
@@ -93,8 +97,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         else:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
+        async with semaphore:
+            loop = asyncio.get_running_loop()
+            return await loop.run_in_executor(None, chat_completion, query, history, system, request)
+
+    def chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         if request.stream:
-            generate = predict(query, history, system, request)
+            generate = stream_chat_completion(query, history, system, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
 
         responses = chat_model.chat(
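The hunk above is the core of the fix: chat_model.chat is synchronous, so the route hands it to a thread-pool worker via run_in_executor and only awaits the result, keeping the event loop free for other requests. A self-contained sketch of that semaphore-plus-executor pattern (blocking_inference is a hypothetical stand-in for the model call):

import asyncio

semaphore = asyncio.Semaphore(1)

def blocking_inference(prompt: str) -> str:
    # Synchronous, CPU/GPU-bound work that would otherwise stall the event loop.
    return prompt.upper()

async def handle(prompt: str) -> str:
    async with semaphore:  # serialize access to the single model instance
        loop = asyncio.get_running_loop()
        # None selects the loop's default ThreadPoolExecutor.
        return await loop.run_in_executor(None, blocking_inference, prompt)

print(asyncio.run(handle("hello")))  # HELLO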
@@ -125,7 +134,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
 
-    async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+    def stream_chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(role=Role.ASSISTANT, content=""),
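The renamed stream_chat_completion is now a plain generator; EventSourceResponse from sse_starlette (already used by this file, per the call in the previous hunk) turns each yielded string into one server-sent event. A rough, runnable sketch with made-up chunks:

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

@app.get("/stream")
async def stream():
    def gen():
        # Each yield becomes the data: field of one SSE message.
        for chunk in ("Hel", "lo", "!"):
            yield chunk
        yield "[DONE]"  # OpenAI-style end-of-stream sentinel

    return EventSourceResponse(gen(), media_type="text/event-stream")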
@@ -168,7 +177,12 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if len(request.messages) == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
+        async with semaphore:
+            loop = asyncio.get_running_loop()
+            return await loop.run_in_executor(None, get_score, request)
+
+    def get_score(request: ScoreEvaluationRequest):
         scores = chat_model.get_scores(request.messages, max_length=request.max_length)
         return ScoreEvaluationResponse(model=request.model, scores=scores)
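With both endpoints offloaded the same way, the server keeps speaking the OpenAI wire format. An illustrative client call against a local instance (URL, port, and payload are assumptions, not part of the diff):

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "gpt-3.5-turbo",  # matches the ModelCard id registered above
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
print(resp.json()["choices"][0]["message"]["content"])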
@@ -178,4 +192,4 @@
 if __name__ == "__main__":
     chat_model = ChatModel()
     app = create_app(chat_model)
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)

View File

@@ -152,7 +152,6 @@ class ChatModel:
             padding=True,
             truncation=True,
             max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
-            pad_to_multiple_of=8,
             return_tensors="pt",
             **kwargs
         ).to(device)
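Dropping pad_to_multiple_of=8 means batches are padded only up to the longest sequence, capped by max_length; with the multiple-of-8 rounding, the padded length could exceed max_length and overrun the model's position embeddings, which is presumably what this deletion fixes. A small sketch of the remaining call, assuming a standard Hugging Face tokenizer:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative model
batch = tokenizer(
    ["a short prompt", "a somewhat longer prompt than the first one"],
    padding=True,      # pad only to the longest sequence in the batch
    truncation=True,
    max_length=1024,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # (2, length_of_longest_tokenized_sequence)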