update readme

Author: hiyouga
Date:   2023-06-23 00:17:05 +08:00
commit cf29a9af35
parent 0c7eb90f6b
Former-commit-id: 0697643358

3 changed files with 22 additions and 27 deletions


@@ -1,5 +1,5 @@
# coding=utf-8
# Implements API for fine-tuned models.
# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
# Visit http://localhost:8000/docs for document.
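
For reference, the OpenAI-compatible endpoint this file exposes can be exercised with any HTTP client. The sketch below is illustrative only: it assumes the demo is serving on http://localhost:8000 as the usage line above suggests, and the response field path (choices[0].message.content) is taken from OpenAI's chat-completions schema that the new header comment links to, since those fields are not shown in this diff.

import requests  # minimal client sketch, not part of this commit

payload = {
    "model": "my-finetuned-model",  # placeholder model id
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ],
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
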
@@ -7,11 +7,10 @@
import time
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from threading import Thread
from contextlib import asynccontextmanager
from pydantic import BaseModel, Field
from fastapi import FastAPI, HTTPException
from contextlib import asynccontextmanager
from transformers import TextIteratorStreamer
from starlette.responses import StreamingResponse
from typing import Any, Dict, List, Literal, Optional, Union
@@ -68,14 +67,14 @@ class ChatCompletionResponseStreamChoice(BaseModel):
class ChatCompletionResponse(BaseModel):
model: str
object: str
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
global model, tokenizer, source_prefix
global model, tokenizer, source_prefix, generating_args
if request.messages[-1].role != "user":
raise HTTPException(status_code=400, detail="Invalid request")
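
In the hunk above, the object field is narrowed from a plain str to a Literal, so one response model covers both full completions and stream chunks while anything else is rejected at validation time. A standalone pydantic sketch of that pattern (illustrative, not code from this repository):

import time
from typing import Literal, Optional
from pydantic import BaseModel, Field, ValidationError

class Response(BaseModel):  # mirrors the Literal-typed field shown above
    object: Literal["chat.completion", "chat.completion.chunk"]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))

print(Response(object="chat.completion.chunk").json())  # accepted
try:
    Response(object="chat.completions")  # typo: rejected by the Literal constraint
except ValidationError as err:
    print("invalid object value:", err.errors()[0]["loc"])
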
@@ -83,7 +82,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == "system":
source_prefix = prev_messages.pop(0).content
prefix = prev_messages.pop(0).content
else:
prefix = source_prefix
history = []
if len(prev_messages) % 2 == 0:
@@ -91,7 +92,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
history.append([prev_messages[i].content, prev_messages[i+1].content])
inputs = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")
inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
inputs = inputs.to(model.device)
gen_kwargs = generating_args.to_dict()
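
The two hunks above fix how the prompt is assembled from an OpenAI-style messages list: a leading system message now populates a local prefix instead of overwriting the module-level source_prefix (which would have leaked into later requests), the remaining messages are paired into (user, assistant) history turns, and the last user message becomes the query passed to the prompt template. A standalone sketch of that splitting logic, using plain dicts in place of the pydantic message objects:

def split_messages(messages, default_prefix=""):
    # Illustrative re-statement of the logic above, not repo code.
    query = messages[-1]["content"]            # last message must come from the user
    prev = messages[:-1]
    if prev and prev[0]["role"] == "system":   # leading system message becomes the prefix
        prefix = prev.pop(0)["content"]
    else:
        prefix = default_prefix                # fall back to the global source_prefix
    history = []
    if len(prev) % 2 == 0:                     # remaining messages pair up as (user, assistant) turns
        for i in range(0, len(prev), 2):
            if prev[i]["role"] == "user" and prev[i + 1]["role"] == "assistant":
                history.append([prev[i]["content"], prev[i + 1]["content"]])
    return prefix, history, query
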
@@ -134,7 +135,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
delta=DeltaMessage(role="assistant"),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
for new_text in streamer:
@@ -146,7 +147,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
delta=DeltaMessage(content=new_text),
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
choice_data = ChatCompletionResponseStreamChoice(
@@ -154,7 +155,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
delta=DeltaMessage(),
finish_reason="stop"
)
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))