match api with OpenAI format

Former-commit-id: 76ecb8c222cec34fa6dbcef71e3907c95f67c22f
hiyouga 2023-06-22 20:27:00 +08:00
parent 993d005242
commit 620cd2eb7e
4 changed files with 144 additions and 135 deletions
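The change swaps the ad-hoc `{"prompt": ..., "history": ...}` payload for OpenAI's chat-completions schema at `/v1/chat/completions`. A minimal client sketch against the new endpoint (host, port and model name are assumptions carried over from the old usage comment, not part of this commit; only the payload fields come from the diff below):

# Hypothetical client for the new OpenAI-style endpoint; host/port follow the
# old usage comment (http://127.0.0.1:8000) and "llama" is an arbitrary model name.
import requests

payload = {
    "model": "llama",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello there!"}
    ],
    "temperature": 0.7,
    "top_p": 0.9,
    "max_new_tokens": 128,
    "stream": False
}
response = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
print(response.json()["choices"][0]["message"]["content"])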

View File

@@ -2,157 +2,157 @@
# Implements API for fine-tuned models.
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
-# Request:
-# curl http://127.0.0.1:8000 --header 'Content-Type: application/json' --data '{"prompt": "Hello there!", "history": []}'
-# Response:
-# {
-#   "response": "'Hi there!'",
-#   "history": "[('Hello there!', 'Hi there!')]",
-#   "status": 200,
-#   "time": "2000-00-00 00:00:00"
-# }

-import json
-import datetime
+import time
import torch
import uvicorn
-from fastapi import FastAPI, Request
from threading import Thread
+from fastapi import FastAPI
+from contextlib import asynccontextmanager
+from pydantic import BaseModel, Field
from transformers import TextIteratorStreamer
from starlette.responses import StreamingResponse
+from typing import Any, Dict, List, Literal, Optional, Union

-from utils import Template, load_pretrained, prepare_infer_args, get_logits_processor
+from utils import (
+    Template,
+    load_pretrained,
+    prepare_infer_args,
+    get_logits_processor
+)
-def torch_gc():
+@asynccontextmanager
+async def lifespan(app: FastAPI): # collects GPU memory
+    yield
    if torch.cuda.is_available():
-        num_gpus = torch.cuda.device_count()
-        for device_id in range(num_gpus):
-            with torch.cuda.device(device_id):
-                torch.cuda.empty_cache()
-                torch.cuda.ipc_collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()


-app = FastAPI()
+app = FastAPI(lifespan=lifespan)
-@app.post("/v1/chat/completions")
-async def create_item(request: Request):
-    global model, tokenizer
-
-    json_post_raw = await request.json()
-    prompt = json_post_raw.get("messages")[-1]["content"]
-    history = json_post_raw.get("messages")[:-1]
-    max_token = json_post_raw.get("max_tokens", None)
-    top_p = json_post_raw.get("top_p", None)
-    temperature = json_post_raw.get("temperature", None)
-    stream = check_stream(json_post_raw.get("stream"))
-
-    if stream:
-        generate = predict(prompt, max_token, top_p, temperature, history)
-        return StreamingResponse(generate, media_type="text/event-stream")
-
-    input_ids = tokenizer([prompt_template.get_prompt(prompt, history, source_prefix)], return_tensors="pt")["input_ids"]
-    input_ids = input_ids.to(model.device)

+class ChatMessage(BaseModel):
+    role: Literal["system", "user", "assistant"]
+    content: str
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[Literal["system", "user", "assistant"]] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[ChatMessage]
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    max_new_tokens: Optional[int] = None
+    stream: Optional[bool] = False
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Literal["stop", "length"]
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]]
+
+
+class ChatCompletionResponse(BaseModel):
+    model: str
+    object: str
+    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
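These Pydantic models mirror OpenAI's chat-completion request and response objects. A quick sketch of how an incoming body binds to them (field values are invented for illustration, not taken from the commit):

# Illustration only: validating a request body against the models above.
request = ChatCompletionRequest(
    model="llama",
    messages=[
        ChatMessage(role="system", content="You are a helpful assistant."),
        ChatMessage(role="user", content="Hello there!")
    ],
    stream=False
)
print(request.messages[-1].content)  # -> "Hello there!"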
+@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+async def create_chat_completion(request: ChatCompletionRequest):
+    global model, tokenizer, source_prefix
+
+    query = request.messages[-1].content
+    prev_messages = request.messages[:-1]
+    if len(prev_messages) > 0 and prev_messages[0].role == "system":
+        source_prefix = prev_messages.pop(0).content
+
+    history = []
+    if len(prev_messages) % 2 == 0:
+        for i in range(0, len(prev_messages), 2):
+            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
+                history.append([prev_messages[i].content, prev_messages[i+1].content])
+
+    inputs = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")
+    inputs = inputs.to(model.device)

    gen_kwargs = generating_args.to_dict()
-    gen_kwargs["input_ids"] = input_ids
-    gen_kwargs["logits_processor"] = get_logits_processor()
-    gen_kwargs["max_new_tokens"] = max_token if max_token else gen_kwargs["max_new_tokens"]
-    gen_kwargs["top_p"] = top_p if top_p else gen_kwargs["top_p"]
-    gen_kwargs["temperature"] = temperature if temperature else gen_kwargs["temperature"]
+    gen_kwargs.update({
+        "input_ids": inputs["input_ids"],
+        "temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
+        "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
+        "max_new_tokens": request.max_new_tokens if request.max_new_tokens else gen_kwargs["max_new_tokens"],
+        "logits_processor": get_logits_processor()
+    })
+
+    if request.stream:
+        generate = predict(gen_kwargs, request.model)
+        return StreamingResponse(generate, media_type="text/event-stream")
    generation_output = model.generate(**gen_kwargs)
-    outputs = generation_output.tolist()[0][len(input_ids[0]):]
+    outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
    response = tokenizer.decode(outputs, skip_special_tokens=True)

-    now = datetime.datetime.now()
-    time = now.strftime("%Y-%m-%d %H:%M:%S")
-    answer = {
-        "choices": [
-            {
-                "message": {
-                    "role": "assistant",
-                    "content": response
-                }
-            }
-        ]
-    }
-    log = (
-        "["
-        + time
-        + "] "
-        + "\", prompt:\""
-        + prompt
-        + "\", response:\""
-        + repr(response)
-        + "\""
-    )
-    print(log)
-    torch_gc()
-    return answer
+    choice_data = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(role="assistant", content=response),
+        finish_reason="stop"
+    )
+
+    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
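For a non-streaming request the endpoint therefore returns an OpenAI-style body roughly of this shape (concrete values are invented for illustration):

# Rough shape of the JSON produced by the ChatCompletionResponse above.
example_response = {
    "model": "llama",
    "object": "chat.completion",
    "created": 1687436820,  # int(time.time()) at response time
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hi there!"},
            "finish_reason": "stop"
        }
    ]
}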
-def check_stream(stream):
-    if isinstance(stream, bool):
-        # stream is already a boolean, use it as-is
-        stream_value = stream
-    else:
-        # not a boolean, try to convert it
-        if isinstance(stream, str):
-            stream = stream.lower()
-            if stream in ["true", "false"]:
-                # convert the string value to a boolean
-                stream_value = stream == "true"
-            else:
-                # invalid string value
-                stream_value = False
-        else:
-            # neither a boolean nor a string
-            stream_value = False
-    return stream_value
-async def predict(query, max_length, top_p, temperature, history):
+async def predict(gen_kwargs: Dict[str, Any], model_id: str):
    global model, tokenizer

-    input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
-    input_ids = input_ids.to(model.device)
-
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-    gen_kwargs = {
-        "input_ids": input_ids,
-        "do_sample": generating_args.do_sample,
-        "top_p": top_p,
-        "temperature": temperature,
-        "num_beams": generating_args.num_beams,
-        "max_length": max_length,
-        "repetition_penalty": generating_args.repetition_penalty,
-        "logits_processor": get_logits_processor(),
-        "streamer": streamer
-    }
+    gen_kwargs["streamer"] = streamer

    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0,
+        delta=DeltaMessage(role="assistant"),
+        finish_reason=None
+    )
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+
    for new_text in streamer:
-        answer = {
-            "choices": [
-                {
-                    "message": {
-                        "role": "assistant",
-                        "content": new_text
-                    }
-                }
-            ]
-        }
-        yield "data: " + json.dumps(answer) + '\n\n'
+        if len(new_text) == 0:
+            continue
+
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=0,
+            delta=DeltaMessage(content=new_text),
+            finish_reason=None
+        )
+        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+
+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0,
+        delta=DeltaMessage(),
+        finish_reason="stop"
+    )
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
if __name__ == "__main__":
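When `stream` is true, `predict` emits server-sent events whose `data:` payloads are `chat.completion.chunk` objects. A sketch of a consumer (host and port are again assumptions; the `data: {...}\n\n` framing matches what `predict` yields above):

# Hypothetical streaming client for the SSE output of predict().
import json
import requests

payload = {"model": "llama", "messages": [{"role": "user", "content": "Hello there!"}], "stream": True}
with requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        chunk = json.loads(line[len("data: "):])
        delta = chunk["choices"][0]["delta"]
        print(delta.get("content", ""), end="", flush=True)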

View File

@@ -18,7 +18,6 @@ def main():
    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
    model, tokenizer = load_pretrained(model_args, finetuning_args)
-    model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
    prompt_template = Template(data_args.prompt_template)
    source_prefix = data_args.source_prefix if data_args.source_prefix else ""
@@ -29,34 +28,39 @@ def main():
        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

        gen_kwargs = generating_args.to_dict()
-        gen_kwargs["input_ids"] = input_ids
-        gen_kwargs["logits_processor"] = get_logits_processor()
-        gen_kwargs["streamer"] = streamer
+        gen_kwargs.update({
+            "input_ids": input_ids,
+            "logits_processor": get_logits_processor(),
+            "streamer": streamer
+        })

        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

-        print("{}: ".format(model_name), end="", flush=True)
+        print("Assistant: ", end="", flush=True)
        response = ""
        for new_text in streamer:
            print(new_text, end="", flush=True)
            response += new_text
        print()

        history = history + [(query, response)]
        return history

    history = []
-    print("Welcome to the {} model. Enter text to chat, `clear` to clear the history, `stop` to terminate the program.".format(model_name))
+    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

    while True:
        try:
-            query = input("\nInput: ")
+            query = input("\nUser: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        except Exception:
            raise

-        if query.strip() == "stop":
+        if query.strip() == "exit":
            break

        if query.strip() == "clear":

View File

@@ -285,6 +285,10 @@ class GeneratingArguments:
        default=1.0,
        metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
    )
+    length_penalty: Optional[float] = field(
+        default=1.0,
+        metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
+    )

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)
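Because `to_dict()` is a plain `asdict`, the new `length_penalty` field is forwarded to `model.generate` along with the other generating arguments; a rough sketch of that flow, reusing names from the demo scripts above (`model`, `inputs` and `generating_args` stand in for the objects they set up; `length_penalty` only takes effect when `num_beams > 1`):

# Sketch: length_penalty reaches generate() through the gen_kwargs dict
# built from GeneratingArguments.to_dict().
gen_kwargs = generating_args.to_dict()            # now includes length_penalty (default 1.0)
gen_kwargs.update({
    "input_ids": inputs["input_ids"],
    "logits_processor": get_logits_processor()
})
generation_output = model.generate(**gen_kwargs)  # length_penalty applies to beam search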

View File

@@ -77,7 +77,7 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
    return text


-def predict(query, chatbot, max_length, top_p, temperature, history):
+def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
    chatbot.append((parse_text(query), ""))

    input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
@@ -85,17 +85,15 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

-    gen_kwargs = {
+    gen_kwargs = generating_args.to_dict()
+    gen_kwargs.update({
        "input_ids": input_ids,
-        "do_sample": generating_args.do_sample,
        "top_p": top_p,
        "temperature": temperature,
-        "num_beams": generating_args.num_beams,
-        "max_length": max_length,
-        "repetition_penalty": generating_args.repetition_penalty,
+        "max_new_tokens": max_new_tokens,
        "logits_processor": get_logits_processor(),
        "streamer": streamer
-    }
+    })

    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
@@ -137,13 +135,16 @@ with gr.Blocks() as demo:
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
-            max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=generating_args.top_p, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1.5, value=generating_args.temperature, step=0.01, label="Temperature", interactive=True)
+            max_new_tokens = gr.Slider(10, 2048, value=generating_args.max_new_tokens, step=1.0,
+                                       label="Maximum new tokens", interactive=True)
+            top_p = gr.Slider(0.01, 1, value=generating_args.top_p, step=0.01,
+                              label="Top P", interactive=True)
+            temperature = gr.Slider(0.01, 1.5, value=generating_args.temperature, step=0.01,
+                                    label="Temperature", interactive=True)

    history = gr.State([])

-    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True)
+    submitBtn.click(predict, [user_input, chatbot, max_new_tokens, top_p, temperature, history], [chatbot, history], show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)