Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-06 21:52:50 +08:00)
fix api server
Former-commit-id: 08464183b9b034abdbf179d7043705a0754837e5
This commit is contained in:
parent 6bbcf5ad16
commit 5b61063048
@@ -1,4 +1,6 @@
+import os
 import json
+import asyncio
 from typing import List, Tuple
 from pydantic import BaseModel
 from contextlib import asynccontextmanager
@@ -63,6 +65,8 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_headers=["*"],
     )
 
+    semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
+
     @app.get("/v1/models", response_model=ModelList)
     async def list_models():
         model_card = ModelCard(id="gpt-3.5-turbo")
@@ -93,8 +97,13 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         else:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
+        async with semaphore:
+            loop = asyncio.get_running_loop()
+            return await loop.run_in_executor(None, chat_completion, query, history, system, request)
+
+    def chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         if request.stream:
-            generate = predict(query, history, system, request)
+            generate = stream_chat_completion(query, history, system, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
 
         responses = chat_model.chat(
@@ -125,7 +134,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 
         return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
 
-    async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
+    def stream_chat_completion(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(role=Role.ASSISTANT, content=""),
@@ -169,6 +178,11 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if len(request.messages) == 0:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
+        async with semaphore:
+            loop = asyncio.get_running_loop()
+            return await loop.run_in_executor(None, get_score, request)
+
+    def get_score(request: ScoreEvaluationRequest):
         scores = chat_model.get_scores(request.messages, max_length=request.max_length)
         return ScoreEvaluationResponse(model=request.model, scores=scores)
 
@@ -178,4 +192,4 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 if __name__ == "__main__":
     chat_model = ChatModel()
     app = create_app(chat_model)
-    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)), workers=1)
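Read together, the hunks above do three things: they create an asyncio.Semaphore sized by the MAX_CONCURRENT environment variable, route each request through that semaphore and into a worker thread via run_in_executor (so blocking generation and scoring calls no longer stall the event loop), and make the listen port configurable through API_PORT. The sketch below is a minimal, self-contained illustration of that pattern, not code from the repository; slow_chat and the simplified request shape are stand-ins for the real chat_model calls.

# Minimal sketch of the concurrency pattern introduced above (illustrative names).
import asyncio
import os
import time

from fastapi import FastAPI

app = FastAPI()

# At most MAX_CONCURRENT requests may reach the model at the same time.
semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))


def slow_chat(query: str) -> str:
    # Stand-in for a blocking, GPU-bound generation call such as chat_model.chat(...).
    time.sleep(1)
    return f"echo: {query}"


@app.post("/v1/chat/completions")
async def create_chat_completion(query: str):
    async with semaphore:
        # Run the blocking call in the default thread pool so the event loop stays free.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, slow_chat, query)


if __name__ == "__main__":
    import uvicorn

    # Port comes from API_PORT, mirroring the change to the __main__ block above.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8000)))

The remaining hunk belongs to the ChatModel scoring path rather than the API app.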
@@ -152,7 +152,6 @@ class ChatModel:
             padding=True,
             truncation=True,
             max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
-            pad_to_multiple_of=8,
             return_tensors="pt",
             **kwargs
         ).to(device)
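This batched tokenizer call keeps dynamic padding, truncation, and the model's positional limit, but no longer rounds padded lengths up to a multiple of 8. As a rough, hypothetical illustration of the resulting call (the tokenizer and inputs here are placeholders, not the project's own model):

# Sketch only; bert-base-uncased and the example texts are placeholders.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(
    ["first candidate response", "a noticeably longer second candidate response"],
    padding=True,         # pad to the longest sequence in the batch
    truncation=True,
    max_length=1024,      # stands in for max_position_embeddings
    return_tensors="pt",  # note: no pad_to_multiple_of=8 after this commit
)
print(batch["input_ids"].shape)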