diff --git a/src/api_demo.py b/src/api_demo.py
index 28125db7..899aec11 100644
--- a/src/api_demo.py
+++ b/src/api_demo.py
@@ -2,157 +2,157 @@
 # Implements API for fine-tuned models.
 # Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 
-# Request:
-# curl http://127.0.0.1:8000 --header 'Content-Type: application/json' --data '{"prompt": "Hello there!", "history": []}'
-
-# Response:
-# {
-#   "response": "'Hi there!'",
-#   "history": "[('Hello there!', 'Hi there!')]",
-#   "status": 200,
-#   "time": "2000-00-00 00:00:00"
-# }
-
 import json
-import datetime
+import time
 import torch
 import uvicorn
+from fastapi import FastAPI
 from threading import Thread
-from fastapi import FastAPI, Request
-from starlette.responses import StreamingResponse
+from contextlib import asynccontextmanager
+
+from pydantic import BaseModel, Field
 from transformers import TextIteratorStreamer
+from starlette.responses import StreamingResponse
+from typing import Any, Dict, List, Literal, Optional, Union
 
-from utils import Template, load_pretrained, prepare_infer_args, get_logits_processor
+from utils import (
+    Template,
+    load_pretrained,
+    prepare_infer_args,
+    get_logits_processor
+)
 
 
-def torch_gc():
+@asynccontextmanager
+async def lifespan(app: FastAPI): # collects GPU memory
+    yield
     if torch.cuda.is_available():
-        num_gpus = torch.cuda.device_count()
-        for device_id in range(num_gpus):
-            with torch.cuda.device(device_id):
-                torch.cuda.empty_cache()
-                torch.cuda.ipc_collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
 
 
-app = FastAPI()
+app = FastAPI(lifespan=lifespan)
 
 
-@app.post("/v1/chat/completions")
-async def create_item(request: Request):
-    global model, tokenizer
+class ChatMessage(BaseModel):
+    role: Literal["system", "user", "assistant"]
+    content: str
 
-    json_post_raw = await request.json()
-    prompt = json_post_raw.get("messages")[-1]["content"]
-    history = json_post_raw.get("messages")[:-1]
-    max_token = json_post_raw.get("max_tokens", None)
-    top_p = json_post_raw.get("top_p", None)
-    temperature = json_post_raw.get("temperature", None)
-    stream = check_stream(json_post_raw.get("stream"))
 
-    if stream:
-        generate = predict(prompt, max_token, top_p, temperature, history)
-        return StreamingResponse(generate, media_type="text/event-stream")
+class DeltaMessage(BaseModel):
+    role: Optional[Literal["system", "user", "assistant"]] = None
+    content: Optional[str] = None
 
-    input_ids = tokenizer([prompt_template.get_prompt(prompt, history, source_prefix)], return_tensors="pt")[
-        "input_ids"]
-    input_ids = input_ids.to(model.device)
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[ChatMessage]
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    max_new_tokens: Optional[int] = None
+    stream: Optional[bool] = False
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Literal["stop", "length"]
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]]
+
+
+class ChatCompletionResponse(BaseModel):
+    model: str
+    object: str
+    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+
+
+@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+async def create_chat_completion(request: ChatCompletionRequest):
+    global model, tokenizer, source_prefix
+
+    query = request.messages[-1].content
+    prev_messages = request.messages[:-1]
+    if len(prev_messages) > 0 and prev_messages[0].role == "system":
+        source_prefix = prev_messages.pop(0).content
+
+    history = []
+    if len(prev_messages) % 2 == 0:
+        for i in range(0, len(prev_messages), 2):
+            if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
+                history.append([prev_messages[i].content, prev_messages[i+1].content])
+
+    inputs = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")
+    inputs = inputs.to(model.device)
 
     gen_kwargs = generating_args.to_dict()
-    gen_kwargs["input_ids"] = input_ids
-    gen_kwargs["logits_processor"] = get_logits_processor()
-    gen_kwargs["max_new_tokens"] = max_token if max_token else gen_kwargs["max_new_tokens"]
-    gen_kwargs["top_p"] = top_p if top_p else gen_kwargs["top_p"]
-    gen_kwargs["temperature"] = temperature if temperature else gen_kwargs["temperature"]
+    gen_kwargs.update({
+        "input_ids": inputs["input_ids"],
+        "temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
+        "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
+        "max_new_tokens": request.max_new_tokens if request.max_new_tokens else gen_kwargs["max_new_tokens"],
+        "logits_processor": get_logits_processor()
+    })
+
+    if request.stream:
+        generate = predict(gen_kwargs, request.model)
+        return StreamingResponse(generate, media_type="text/event-stream")
 
     generation_output = model.generate(**gen_kwargs)
-
-    outputs = generation_output.tolist()[0][len(input_ids[0]):]
+    outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
     response = tokenizer.decode(outputs, skip_special_tokens=True)
 
-    now = datetime.datetime.now()
-    time = now.strftime("%Y-%m-%d %H:%M:%S")
-
-    answer = {
-        "choices": [
-            {
-                "message": {
-                    "role": "assistant",
-                    "content": response
-                }
-            }
-        ]
-    }
-
-    log = (
-        "["
-        + time
-        + "] "
-        + "\", prompt:\""
-        + prompt
-        + "\", response:\""
-        + repr(response)
-        + "\""
+    choice_data = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(role="assistant", content=response),
+        finish_reason="stop"
    )
-    print(log)
-    torch_gc()
 
-    return answer
+    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion")
 
 
-def check_stream(stream):
-    if isinstance(stream, bool):
-        # stream 是布尔类型,直接使用
-        stream_value = stream
-    else:
-        # 不是布尔类型,尝试进行类型转换
-        if isinstance(stream, str):
-            stream = stream.lower()
-            if stream in ["true", "false"]:
-                # 使用字符串值转换为布尔值
-                stream_value = stream == "true"
-            else:
-                # 非法的字符串值
-                stream_value = False
-        else:
-            # 非布尔类型也非字符串类型
-            stream_value = False
-    return stream_value
-
-
-async def predict(query, max_length, top_p, temperature, history):
+async def predict(gen_kwargs: Dict[str, Any], model_id: str):
     global model, tokenizer
 
-    input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
-    input_ids = input_ids.to(model.device)
     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
-    gen_kwargs = {
-        "input_ids": input_ids,
-        "do_sample": generating_args.do_sample,
-        "top_p": top_p,
-        "temperature": temperature,
-        "num_beams": generating_args.num_beams,
-        "max_length": max_length,
-        "repetition_penalty": generating_args.repetition_penalty,
-        "logits_processor": get_logits_processor(),
-        "streamer": streamer
-    }
+    gen_kwargs["streamer"] = streamer
 
     thread = Thread(target=model.generate, kwargs=gen_kwargs)
     thread.start()
 
+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0,
+        delta=DeltaMessage(role="assistant"),
+        finish_reason=None
+    )
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+
     for new_text in streamer:
-        answer = {
-            "choices": [
-                {
-                    "message": {
-                        "role": "assistant",
-                        "content": new_text
-                    }
-                }
-            ]
-        }
-
-        yield "data: " + json.dumps(answer) + '\n\n'
+        if len(new_text) == 0:
+            continue
+
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=0,
+            delta=DeltaMessage(content=new_text),
+            finish_reason=None
+        )
+        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+        yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+
+    choice_data = ChatCompletionResponseStreamChoice(
+        index=0,
+        delta=DeltaMessage(),
+        finish_reason="stop"
+    )
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
 
 if __name__ == "__main__":
diff --git a/src/cli_demo.py b/src/cli_demo.py
index 2a501f08..752b42c8 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -18,7 +18,6 @@
 def main():
     model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
     model, tokenizer = load_pretrained(model_args, finetuning_args)
-    model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
     prompt_template = Template(data_args.prompt_template)
     source_prefix = data_args.source_prefix if data_args.source_prefix else ""
 
@@ -29,34 +28,39 @@ def main():
         streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
         gen_kwargs = generating_args.to_dict()
-        gen_kwargs["input_ids"] = input_ids
-        gen_kwargs["logits_processor"] = get_logits_processor()
-        gen_kwargs["streamer"] = streamer
+        gen_kwargs.update({
+            "input_ids": input_ids,
+            "logits_processor": get_logits_processor(),
+            "streamer": streamer
+        })
 
         thread = Thread(target=model.generate, kwargs=gen_kwargs)
         thread.start()
 
-        print("{}: ".format(model_name), end="", flush=True)
+        print("Assistant: ", end="", flush=True)
+
         response = ""
         for new_text in streamer:
             print(new_text, end="", flush=True)
             response += new_text
         print()
 
+        history = history + [(query, response)]
         return history
 
     history = []
-    print("欢迎使用 {} 模型,输入内容即可对话,clear清空对话历史,stop终止程序".format(model_name))
+    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
+
     while True:
         try:
-            query = input("\nInput: ")
+            query = input("\nUser: ")
         except UnicodeDecodeError:
             print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
             continue
         except Exception:
             raise
 
-        if query.strip() == "stop":
+        if query.strip() == "exit":
             break
 
         if query.strip() == "clear":
diff --git a/src/utils/config.py b/src/utils/config.py
index d7a6cc86..dfe09392 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -285,6 +285,10 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
     )
+    length_penalty: Optional[float] = field(
+        default=1.0,
+        metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
+    )
 
     def to_dict(self) -> Dict[str, Any]:
         return asdict(self)
diff --git a/src/web_demo.py b/src/web_demo.py
index d081a9ee..9fcd906d 100644
--- a/src/web_demo.py
+++ b/src/web_demo.py
@@ -77,7 +77,7 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
     return text
 
 
-def predict(query, chatbot, max_length, top_p, temperature, history):
+def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
     chatbot.append((parse_text(query), ""))
 
     input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
@@ -85,17 +85,15 @@ def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
 
     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
-    gen_kwargs = {
+    gen_kwargs = generating_args.to_dict()
+    gen_kwargs.update({
         "input_ids": input_ids,
-        "do_sample": generating_args.do_sample,
         "top_p": top_p,
         "temperature": temperature,
-        "num_beams": generating_args.num_beams,
-        "max_length": max_length,
-        "repetition_penalty": generating_args.repetition_penalty,
+        "max_new_tokens": max_new_tokens,
         "logits_processor": get_logits_processor(),
         "streamer": streamer
-    }
+    })
 
     thread = Thread(target=model.generate, kwargs=gen_kwargs)
     thread.start()
@@ -137,13 +135,16 @@ with gr.Blocks() as demo:
 
         with gr.Column(scale=1):
             emptyBtn = gr.Button("Clear History")
-            max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True)
-            top_p = gr.Slider(0, 1, value=generating_args.top_p, step=0.01, label="Top P", interactive=True)
-            temperature = gr.Slider(0, 1.5, value=generating_args.temperature, step=0.01, label="Temperature", interactive=True)
+            max_new_tokens = gr.Slider(10, 2048, value=generating_args.max_new_tokens, step=1.0,
+                                       label="Maximum new tokens", interactive=True)
+            top_p = gr.Slider(0.01, 1, value=generating_args.top_p, step=0.01,
+                              label="Top P", interactive=True)
+            temperature = gr.Slider(0.01, 1.5, value=generating_args.temperature, step=0.01,
+                                    label="Temperature", interactive=True)
 
     history = gr.State([])
 
-    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True)
+    submitBtn.click(predict, [user_input, chatbot, max_new_tokens, top_p, temperature, history], [chatbot, history], show_progress=True)
     submitBtn.click(reset_user_input, [], [user_input])
 
     emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
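
Note: the diff removes the sample request/response comments from the api_demo.py header without adding an updated example for the new OpenAI-style route. The snippet below is a minimal illustrative client call, assuming the server is launched with uvicorn on 127.0.0.1:8000 as in the removed curl example; the model name and message contents are placeholders, not part of this diff.

# Hypothetical client for the new /v1/chat/completions endpoint (illustration only).
# Host, port, and model name are assumptions; ChatCompletionRequest only requires
# "model" (any string, echoed back in the response) and "messages".
import requests

payload = {
    "model": "llama-7b",  # placeholder
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},  # consumed as source_prefix
        {"role": "user", "content": "Hello there!"}
    ],
    "stream": False
}

resp = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])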