1. Rename the misspelled is_fastapi_availble function to is_fastapi_available

2. Log incoming requests when deploying with vLLM
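With change 2, each incoming chat completion request is printed to the server log before it is processed. Illustrative output (the payload below is a made-up example, not taken from this diff):

    ==== request ====
    {'model': 'llama3', 'messages': [{'role': 'user', 'content': 'Hello'}], 'stream': False}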


Former-commit-id: fd2e6dec589f4ebe55d4c203991c47bf5b728ef8
Tendo33 2024-05-09 14:28:01 +08:00
parent 5ff89a0f32
commit dd42439b03
4 changed files with 64 additions and 33 deletions

View File

@@ -4,7 +4,7 @@ from typing import Annotated, Optional

 from ..chat import ChatModel
 from ..extras.misc import torch_gc
-from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
+from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
 from .chat import (
     create_chat_completion_response,
     create_score_evaluation_response,
@@ -20,7 +20,7 @@ from .protocol import (
 )

-if is_fastapi_availble():
+if is_fastapi_available():
     from fastapi import Depends, FastAPI, HTTPException, status
     from fastapi.middleware.cors import CORSMiddleware
     from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
@@ -54,7 +54,8 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
     async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
         if api_key and (auth is None or auth.credentials != api_key):
-            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")

     @app.get(
         "/v1/models",
@@ -74,10 +75,12 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
     )
     async def create_chat_completion(request: ChatCompletionRequest):
         if not chat_model.engine.can_generate:
-            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+            raise HTTPException(
+                status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")

         if request.stream:
-            generate = create_stream_chat_completion_response(request, chat_model)
+            generate = create_stream_chat_completion_response(
+                request, chat_model)
             return EventSourceResponse(generate, media_type="text/event-stream")
         else:
             return await create_chat_completion_response(request, chat_model)
@@ -90,7 +93,8 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
     )
     async def create_score_evaluation(request: ScoreEvaluationRequest):
         if chat_model.engine.can_generate:
-            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+            raise HTTPException(
+                status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")

         return await create_score_evaluation_response(request, chat_model)
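Aside from the rename, the changes in this file are line-wrapping only. For context, the verify_api_key dependency above is what gates every endpoint; a minimal self-contained sketch of the same bearer-token pattern (API_KEY and the /ping route are illustrative stand-ins, not part of this commit):

    from typing import Annotated, Optional

    from fastapi import Depends, FastAPI, HTTPException, status
    from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer

    API_KEY = "sk-test"  # illustrative; the real app reads its key from the environment
    app = FastAPI()
    security = HTTPBearer(auto_error=False)


    async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
        # Reject the request unless the bearer token matches the configured key.
        if API_KEY and (auth is None or auth.credentials != API_KEY):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")


    @app.get("/ping", dependencies=[Depends(verify_api_key)])
    async def ping():
        return {"status": "ok"}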

View File

@@ -3,7 +3,8 @@ import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple

 from ..data import Role as DataRole
-from ..extras.packages import is_fastapi_availble
+from ..extras.packages import is_fastapi_available
+from ..extras.logging import get_logger
 from .common import dictify, jsonify
 from .protocol import (
     ChatCompletionMessage,
@@ -19,8 +20,9 @@ from .protocol import (
     ScoreEvaluationResponse,
 )

+logger = get_logger(__name__)

-if is_fastapi_availble():
+if is_fastapi_available():
     from fastapi import HTTPException, status
@@ -39,8 +41,13 @@ ROLE_MAPPING = {

 def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
+    params = dictify(request)
+    logger.info(f"==== request ====\n{params}")
+
     if len(request.messages) == 0:
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")

     if request.messages[0].role == Role.SYSTEM:
         system = request.messages.pop(0).content
@@ -48,29 +55,37 @@ def _process_request(request: "ChatCompletionRequest") -> Tuple[List[Dict[str, str]], str, str]:
         system = ""

     if len(request.messages) % 2 == 0:
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
+                            detail="Only supports u/a/u/a/u...")

     input_messages = []
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
-            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
         elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
-            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

         if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
             name = message.tool_calls[0].function.name
             arguments = message.tool_calls[0].function.arguments
-            content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
-            input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
+            content = json.dumps(
+                {"name": name, "argument": arguments}, ensure_ascii=False)
+            input_messages.append(
+                {"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
         else:
-            input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
+            input_messages.append(
+                {"role": ROLE_MAPPING[message.role], "content": message.content})

     tool_list = request.tools
     if isinstance(tool_list, list) and len(tool_list):
         try:
-            tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
+            tools = json.dumps([dictify(tool.function)
+                                for tool in tool_list], ensure_ascii=False)
         except Exception:
-            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
     else:
         tools = ""
@@ -84,8 +99,10 @@ def _create_stream_chat_completion_chunk(
     index: Optional[int] = 0,
     finish_reason: Optional["Finish"] = None,
 ) -> str:
-    choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
-    chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
+    choice_data = ChatCompletionStreamResponseChoice(
+        index=index, delta=delta, finish_reason=finish_reason)
+    chunk = ChatCompletionStreamResponse(
+        id=completion_id, model=model, choices=[choice_data])
     return jsonify(chunk)
@@ -110,21 +127,26 @@ async def create_chat_completion_response(
     choices = []
     for i, response in enumerate(responses):
         if tools:
-            result = chat_model.engine.template.format_tools.extract(response.response_text)
+            result = chat_model.engine.template.format_tools.extract(
+                response.response_text)
         else:
             result = response.response_text

         if isinstance(result, tuple):
             name, arguments = result
             function = Function(name=name, arguments=arguments)
-            tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
-            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
+            tool_call = FunctionCall(id="call_{}".format(
+                uuid.uuid4().hex), function=function)
+            response_message = ChatCompletionMessage(
+                role=Role.ASSISTANT, tool_calls=[tool_call])
             finish_reason = Finish.TOOL
         else:
-            response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
+            response_message = ChatCompletionMessage(
+                role=Role.ASSISTANT, content=result)
             finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH

-        choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
+        choices.append(ChatCompletionResponseChoice(
+            index=i, message=response_message, finish_reason=finish_reason))
         prompt_length = response.prompt_length
         response_length += response.response_length
@@ -143,13 +165,16 @@ async def create_stream_chat_completion_response(
     completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
     input_messages, system, tools = _process_request(request)
     if tools:
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
+                            detail="Cannot stream function calls.")

     if request.n > 1:
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
+                            detail="Cannot stream multiple responses.")

     yield _create_stream_chat_completion_chunk(
-        completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
+        completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(
+            role=Role.ASSISTANT, content="")
     )
     async for new_token in chat_model.astream_chat(
         input_messages,
@@ -163,7 +188,8 @@ async def create_stream_chat_completion_response(
     ):
         if len(new_token) != 0:
             yield _create_stream_chat_completion_chunk(
-                completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
+                completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(
+                    content=new_token)
             )

     yield _create_stream_chat_completion_chunk(
@@ -176,7 +202,8 @@ async def create_score_evaluation_response(
     request: "ScoreEvaluationRequest", chat_model: "ChatModel"
 ) -> "ScoreEvaluationResponse":
     if len(request.messages) == 0:
-        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

     scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
     return ScoreEvaluationResponse(model=request.model, scores=scores)
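The functional change in this file is the request log at the top of _process_request: the full request is dictified and written to the log before any validation runs. A minimal sketch of the behavior, assuming a plain logging.Logger in place of the project's get_logger and a made-up request payload:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Stand-in for dictify(request) on a parsed ChatCompletionRequest;
    # the field names mirror the OpenAI-style protocol, the values are illustrative.
    params = {
        "model": "llama3",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
    }
    logger.info(f"==== request ====\n{params}")

Note that because dictify is now called without exclude_unset=True (see the next file), the logged dict includes default values as well as user-supplied fields.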

View File

@@ -6,11 +6,11 @@ if TYPE_CHECKING:
     from pydantic import BaseModel


-def dictify(data: "BaseModel") -> Dict[str, Any]:
+def dictify(data: "BaseModel", **kwargs) -> Dict[str, Any]:
     try:  # pydantic v2
-        return data.model_dump(exclude_unset=True)
+        return data.model_dump(**kwargs)
     except AttributeError:  # pydantic v1
-        return data.dict(exclude_unset=True)
+        return data.dict(**kwargs)


 def jsonify(data: "BaseModel") -> str:
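This change is what the new request log relies on: dictify previously hard-coded exclude_unset=True, so unset fields were dropped; forwarding **kwargs makes full dumps the default while still letting callers opt back in. A small usage sketch with a toy pydantic model (Req is illustrative, not from the codebase):

    from typing import Any, Dict

    from pydantic import BaseModel


    def dictify(data: "BaseModel", **kwargs) -> Dict[str, Any]:
        try:  # pydantic v2
            return data.model_dump(**kwargs)
        except AttributeError:  # pydantic v1
            return data.dict(**kwargs)


    class Req(BaseModel):  # toy model for illustration
        model: str
        stream: bool = False


    req = Req(model="llama3")
    print(dictify(req))                      # {'model': 'llama3', 'stream': False}
    print(dictify(req, exclude_unset=True))  # {'model': 'llama3'}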

View File

@@ -20,7 +20,7 @@ def _get_package_version(name: str) -> "Version":
         return version.parse("0.0.0")


-def is_fastapi_availble():
+def is_fastapi_available():
     return _is_package_available("fastapi")
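Only the misspelled name changes here; the body is untouched. For reference, a common way such an availability check is implemented (this _is_package_available body is an assumption based on the usual importlib pattern, not shown in this diff):

    from importlib.util import find_spec


    def _is_package_available(name: str) -> bool:
        # The package is importable if Python can locate a module spec for it.
        return find_spec(name) is not None


    def is_fastapi_available() -> bool:
        return _is_package_available("fastapi")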