match api with OpenAI format

Former-commit-id: 76ecb8c222cec34fa6dbcef71e3907c95f67c22f
hiyouga 2023-06-22 20:27:00 +08:00
parent 993d005242
commit 620cd2eb7e
4 changed files with 144 additions and 135 deletions

View File

@@ -2,157 +2,157 @@
 # Implements API for fine-tuned models.
 # Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
-# Request:
-# curl http://127.0.0.1:8000 --header 'Content-Type: application/json' --data '{"prompt": "Hello there!", "history": []}'
-# Response:
-# {
-#   "response": "'Hi there!'",
-#   "history": "[('Hello there!', 'Hi there!')]",
-#   "status": 200,
-#   "time": "2000-00-00 00:00:00"
-# }

 import json
-import datetime
+import time
 import torch
 import uvicorn
+from fastapi import FastAPI
 from threading import Thread
-from fastapi import FastAPI, Request
-from starlette.responses import StreamingResponse
+from contextlib import asynccontextmanager
+from pydantic import BaseModel, Field
 from transformers import TextIteratorStreamer
+from starlette.responses import StreamingResponse
+from typing import Any, Dict, List, Literal, Optional, Union

-from utils import Template, load_pretrained, prepare_infer_args, get_logits_processor
+from utils import (
+    Template,
+    load_pretrained,
+    prepare_infer_args,
+    get_logits_processor
+)


-def torch_gc():
+@asynccontextmanager
+async def lifespan(app: FastAPI): # collects GPU memory
+    yield
     if torch.cuda.is_available():
-        num_gpus = torch.cuda.device_count()
-        for device_id in range(num_gpus):
-            with torch.cuda.device(device_id):
-                torch.cuda.empty_cache()
-                torch.cuda.ipc_collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()


-app = FastAPI()
+app = FastAPI(lifespan=lifespan)


-@app.post("/v1/chat/completions")
-async def create_item(request: Request):
-    global model, tokenizer
-    json_post_raw = await request.json()
-    prompt = json_post_raw.get("messages")[-1]["content"]
-    history = json_post_raw.get("messages")[:-1]
-    max_token = json_post_raw.get("max_tokens", None)
-    top_p = json_post_raw.get("top_p", None)
-    temperature = json_post_raw.get("temperature", None)
-    stream = check_stream(json_post_raw.get("stream"))
-    if stream:
-        generate = predict(prompt, max_token, top_p, temperature, history)
-        return StreamingResponse(generate, media_type="text/event-stream")
-    input_ids = tokenizer([prompt_template.get_prompt(prompt, history, source_prefix)], return_tensors="pt")[
-        "input_ids"]
-    input_ids = input_ids.to(model.device)
+class ChatMessage(BaseModel):
+    role: Literal["system", "user", "assistant"]
+    content: str
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[Literal["system", "user", "assistant"]] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[ChatMessage]
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    max_new_tokens: Optional[int] = None
+    stream: Optional[bool] = False
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Literal["stop", "length"]
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]]
+
+
+class ChatCompletionResponse(BaseModel):
+    model: str
+    object: str
+    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+
+
+@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+async def create_chat_completion(request: ChatCompletionRequest):
+    global model, tokenizer, source_prefix
+
+    query = request.messages[-1].content
+    prev_messages = request.messages[:-1]
+    if len(prev_messages) > 0 and prev_messages[0].role == "system":
+        source_prefix = prev_messages.pop(0).content
+
+    history = []
+    if len(prev_messages) % 2 == 0:
+        for i in range(0, len(prev_messages), 2):
+            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
+                history.append([prev_messages[i].content, prev_messages[i+1].content])
+
+    inputs = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")
+    inputs = inputs.to(model.device)

     gen_kwargs = generating_args.to_dict()
-    gen_kwargs["input_ids"] = input_ids
-    gen_kwargs["logits_processor"] = get_logits_processor()
-    gen_kwargs["max_new_tokens"] = max_token if max_token else gen_kwargs["max_new_tokens"]
-    gen_kwargs["top_p"] = top_p if top_p else gen_kwargs["top_p"]
-    gen_kwargs["temperature"] = temperature if temperature else gen_kwargs["temperature"]
+    gen_kwargs.update({
+        "input_ids": inputs["input_ids"],
+        "temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
+        "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
+        "max_new_tokens": request.max_new_tokens if request.max_new_tokens else gen_kwargs["max_new_tokens"],
+        "logits_processor": get_logits_processor()
+    })
+
+    if request.stream:
+        generate = predict(gen_kwargs, request.model)
+        return StreamingResponse(generate, media_type="text/event-stream")

     generation_output = model.generate(**gen_kwargs)
-    outputs = generation_output.tolist()[0][len(input_ids[0]):]
+    outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
     response = tokenizer.decode(outputs, skip_special_tokens=True)

-    now = datetime.datetime.now()
-    time = now.strftime("%Y-%m-%d %H:%M:%S")
-    answer = {
-        "choices": [
-            {
-                "message": {
-                    "role": "assistant",
-                    "content": response
-                }
-            }
-        ]
-    }
-    log = (
-        "["
-        + time
-        + "] "
-        + "\", prompt:\""
-        + prompt
-        + "\", response:\""
-        + repr(response)
-        + "\""
-    )
-    print(log)
-    torch_gc()
-    return answer
+    choice_data = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(role="assistant", content=response),
+        finish_reason="stop"
+    )
+
+    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")


-def check_stream(stream):
-    if isinstance(stream, bool):
-        # stream is already a boolean, use it directly
-        stream_value = stream
-    else:
-        # not a boolean, try converting the type
-        if isinstance(stream, str):
-            stream = stream.lower()
-            if stream in ["true", "false"]:
-                # convert the string value to a boolean
-                stream_value = stream == "true"
-            else:
-                # invalid string value
-                stream_value = False
-        else:
-            # neither a boolean nor a string
-            stream_value = False
-    return stream_value
-
-
-async def predict(query, max_length, top_p, temperature, history):
+async def predict(gen_kwargs: Dict[str, Any], model_id: str):
     global model, tokenizer

-    input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
-    input_ids = input_ids.to(model.device)
-
     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
-    gen_kwargs = {
-        "input_ids": input_ids,
-        "do_sample": generating_args.do_sample,
-        "top_p": top_p,
-        "temperature": temperature,
-        "num_beams": generating_args.num_beams,
-        "max_length": max_length,
-        "repetition_penalty": generating_args.repetition_penalty,
-        "logits_processor": get_logits_processor(),
-        "streamer": streamer
-    }
+    gen_kwargs["streamer"] = streamer

     thread = Thread(target=model.generate, kwargs=gen_kwargs)
     thread.start()

+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0,
+        delta=DeltaMessage(role="assistant"),
+        finish_reason=None
+    )
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+
     for new_text in streamer:
-        answer = {
-            "choices": [
-                {
-                    "message": {
-                        "role": "assistant",
-                        "content": new_text
-                    }
-                }
-            ]
-        }
-        yield "data: " + json.dumps(answer) + '\n\n'
+        if len(new_text) == 0:
+            continue
+
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=0,
+            delta=DeltaMessage(content=new_text),
+            finish_reason=None
+        )
+        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+
+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0,
+        delta=DeltaMessage(),
+        finish_reason="stop"
+    )
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))


 if __name__ == "__main__":
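
As a quick reference for the rewritten endpoint, here is a minimal client sketch. It is not part of the commit: the host and port (127.0.0.1:8000) are assumed from the removed curl example, the model string is an arbitrary placeholder, and the plain requests library is used instead of an OpenAI SDK. The request body mirrors ChatCompletionRequest above, and the streaming branch parses the "data: ..." chunks yielded by predict().

import json
import requests

# Hypothetical client for the OpenAI-style endpoint sketched above.
# Assumes the server listens on http://127.0.0.1:8000 (the launch block is truncated in this hunk).
payload = {
    "model": "llama",  # echoed back by the handler; any identifier works here (assumption)
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello there!"}
    ],
    "temperature": 0.7,
    "stream": False
}

# Non-streaming request: the body mirrors ChatCompletionResponse.
resp = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])

# Streaming request: the server yields "data: {...}" chunks (see predict() above).
payload["stream"] = True
with requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload, stream=True) as r:
    for line in r.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            chunk = json.loads(line[len("data: "):])
            delta = chunk["choices"][0]["delta"]
            print(delta.get("content", ""), end="", flush=True)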

View File

@@ -18,7 +18,6 @@ def main():
     model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
     model, tokenizer = load_pretrained(model_args, finetuning_args)
-    model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
     prompt_template = Template(data_args.prompt_template)
     source_prefix = data_args.source_prefix if data_args.source_prefix else ""

@@ -29,34 +28,39 @@ def main():
         streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

         gen_kwargs = generating_args.to_dict()
-        gen_kwargs["input_ids"] = input_ids
-        gen_kwargs["logits_processor"] = get_logits_processor()
-        gen_kwargs["streamer"] = streamer
+        gen_kwargs.update({
+            "input_ids": input_ids,
+            "logits_processor": get_logits_processor(),
+            "streamer": streamer
+        })

         thread = Thread(target=model.generate, kwargs=gen_kwargs)
         thread.start()

-        print("{}: ".format(model_name), end="", flush=True)
+        print("Assistant: ", end="", flush=True)

         response = ""
         for new_text in streamer:
             print(new_text, end="", flush=True)
             response += new_text
         print()

         history = history + [(query, response)]
         return history

     history = []
-    print("Welcome to the {} model. Enter text to chat, clear to clear the history, stop to stop the program.".format(model_name))
+    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

     while True:
         try:
-            query = input("\nInput: ")
+            query = input("\nUser: ")
         except UnicodeDecodeError:
             print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
             continue
         except Exception:
             raise

-        if query.strip() == "stop":
+        if query.strip() == "exit":
             break

         if query.strip() == "clear":

View File

@@ -285,6 +285,10 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
     )
+    length_penalty: Optional[float] = field(
+        default=1.0,
+        metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
+    )

     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
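
Because the demos build their generation kwargs from generating_args.to_dict() (see the hunks above), the new length_penalty field is picked up without further changes. A minimal sketch of that flow, where generating_args, model, input_ids and get_logits_processor stand for the objects already defined in the demo scripts:

# Sketch only: how the new dataclass field reaches model.generate().
gen_kwargs = generating_args.to_dict()          # now also contains "length_penalty": 1.0
gen_kwargs.update({
    "input_ids": input_ids,
    "logits_processor": get_logits_processor()
})
generation_output = model.generate(**gen_kwargs)  # transformers forwards length_penalty, used by beam search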

View File

@@ -77,7 +77,7 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
     return text


-def predict(query, chatbot, max_length, top_p, temperature, history):
+def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
     chatbot.append((parse_text(query), ""))

     input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
@@ -85,17 +85,15 @@ def predict(query, chatbot, max_length, top_p, temperature, history):
     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

-    gen_kwargs = {
+    gen_kwargs = generating_args.to_dict()
+    gen_kwargs.update({
         "input_ids": input_ids,
-        "do_sample": generating_args.do_sample,
         "top_p": top_p,
         "temperature": temperature,
-        "num_beams": generating_args.num_beams,
-        "max_length": max_length,
-        "repetition_penalty": generating_args.repetition_penalty,
+        "max_new_tokens": max_new_tokens,
         "logits_processor": get_logits_processor(),
         "streamer": streamer
-    }
+    })

     thread = Thread(target=model.generate, kwargs=gen_kwargs)
     thread.start()

@@ -137,13 +135,16 @@ with gr.Blocks() as demo:
             with gr.Column(scale=1):
                 emptyBtn = gr.Button("Clear History")

-                max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True)
-                top_p = gr.Slider(0, 1, value=generating_args.top_p, step=0.01, label="Top P", interactive=True)
-                temperature = gr.Slider(0, 1.5, value=generating_args.temperature, step=0.01, label="Temperature", interactive=True)
+                max_new_tokens = gr.Slider(10, 2048, value=generating_args.max_new_tokens, step=1.0,
+                                           label="Maximum new tokens", interactive=True)
+                top_p = gr.Slider(0.01, 1, value=generating_args.top_p, step=0.01,
+                                  label="Top P", interactive=True)
+                temperature = gr.Slider(0.01, 1.5, value=generating_args.temperature, step=0.01,
+                                        label="Temperature", interactive=True)

     history = gr.State([])

-    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True)
+    submitBtn.click(predict, [user_input, chatbot, max_new_tokens, top_p, temperature, history], [chatbot, history], show_progress=True)
     submitBtn.click(reset_user_input, [], [user_input])
     emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
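
A side note on the slider change above: max_length caps the prompt plus the completion, while max_new_tokens caps only the generated continuation, so the Gradio control now maps directly onto the generation budget. An illustrative sketch of the difference, where model and input_ids stand for the objects built in the web demo:

# Illustrative only: the two stopping criteria accepted by transformers' generate().
# With max_length, a long prompt can leave little or no room for new tokens.
output_a = model.generate(input_ids=input_ids, max_length=1024)

# With max_new_tokens, the reply always gets the full budget,
# regardless of how long the prompt already is.
output_b = model.generate(input_ids=input_ids, max_new_tokens=1024)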