diff --git a/README.md b/README.md
index 1db23423..34b7021f 100644
--- a/README.md
+++ b/README.md
@@ -9,11 +9,13 @@
 ## Changelog
 
-[23/06/15] Now we support training the baichuan-7B model in this repo. Try `--model_name_or_path baichuan-inc/baichuan-7B` argument to use the baichuan-7B model.
+[23/06/22] Now we align the [demo API](src/api_demo.py) with [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format, so you can plug the fine-tuned model into arbitrary ChatGPT-based applications.
+
+[23/06/15] Now we support training the baichuan-7B model in this repo. Try `--model_name_or_path baichuan-inc/baichuan-7B` and `--lora_target W_pack` arguments to use the baichuan-7B model.
 
 [23/06/03] Now we support quantized training and inference (aka [QLoRA](https://github.com/artidoro/qlora)). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
 
-[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
+[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments to use the BLOOMZ model.
 
 ## Supported Models
@@ -75,9 +77,9 @@ huggingface-cli login
 
 - Python 3.8+ and PyTorch 1.13.1+
 - 🤗Transformers, Datasets, Accelerate, PEFT and TRL
-- protobuf, cpm_kernels and sentencepiece
 - jieba, rouge_chinese and nltk (used at evaluation)
 - gradio and mdtex2html (used in web_demo.py)
+- uvicorn and fastapi (used in api_demo.py)
 
 And **powerful GPUs**!
@@ -99,7 +101,7 @@ cd LLaMA-Efficient-Tuning
 pip install -r requirements.txt
 ```
 
-### LLaMA Weights Preparation
+### LLaMA Weights Preparation (optional)
 
 1. Download the weights of the LLaMA models.
 2. Convert them to HF format using the following command.
@@ -216,17 +218,10 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
 
 We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
 
-### CLI Demo
+### API / CLI / Web Demo
 
 ```bash
-python src/cli_demo.py \
-    --model_name_or_path path_to_your_model \
-    --checkpoint_dir path_to_checkpoint
-```
-
-### Web Demo
-```bash
-python src/web_demo.py \
+python src/xxx_demo.py \
     --model_name_or_path path_to_your_model \
     --checkpoint_dir path_to_checkpoint
 ```
diff --git a/requirements.txt b/requirements.txt
index 4d583e8d..6079fb2e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,4 @@
 torch>=1.13.1
-protobuf
-cpm_kernels
-sentencepiece
 transformers>=4.29.1
 datasets>=2.12.0
 accelerate>=0.19.0
@@ -12,3 +9,5 @@ rouge_chinese
 nltk
 gradio
 mdtex2html
+uvicorn
+fastapi
diff --git a/src/api_demo.py b/src/api_demo.py
index a61fd632..b05157a1 100644
--- a/src/api_demo.py
+++ b/src/api_demo.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Implements API for fine-tuned models.
+# Implements API for fine-tuned models in OpenAI's format (https://platform.openai.com/docs/api-reference/chat).
 # Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 # Visit http://localhost:8000/docs for document.
@@ -7,11 +7,10 @@
 import time
 import torch
 import uvicorn
-from fastapi import FastAPI, HTTPException
 from threading import Thread
-from contextlib import asynccontextmanager
-
 from pydantic import BaseModel, Field
+from fastapi import FastAPI, HTTPException
+from contextlib import asynccontextmanager
 from transformers import TextIteratorStreamer
 from starlette.responses import StreamingResponse
 from typing import Any, Dict, List, Literal, Optional, Union
@@ -68,14 +67,14 @@ class ChatCompletionResponseStreamChoice(BaseModel):
 
 class ChatCompletionResponse(BaseModel):
     model: str
-    object: str
+    object: Literal["chat.completion", "chat.completion.chunk"]
     choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
 
 
 @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 async def create_chat_completion(request: ChatCompletionRequest):
-    global model, tokenizer, source_prefix
+    global model, tokenizer, source_prefix, generating_args
 
     if request.messages[-1].role != "user":
         raise HTTPException(status_code=400, detail="Invalid request")
@@ -83,7 +82,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
     prev_messages = request.messages[:-1]
 
     if len(prev_messages) > 0 and prev_messages[0].role == "system":
-        source_prefix = prev_messages.pop(0).content
+        prefix = prev_messages.pop(0).content
+    else:
+        prefix = source_prefix
 
     history = []
     if len(prev_messages) % 2 == 0:
@@ -91,7 +92,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
             if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
                 history.append([prev_messages[i].content, prev_messages[i+1].content])
 
-    inputs = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")
+    inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
     inputs = inputs.to(model.device)
 
     gen_kwargs = generating_args.to_dict()
@@ -134,7 +135,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(role="assistant"),
         finish_reason=None
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
     for new_text in streamer:
@@ -146,7 +147,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
             delta=DeltaMessage(content=new_text),
             finish_reason=None
         )
-        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
         yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
     choice_data = ChatCompletionResponseStreamChoice(
@@ -154,7 +155,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(),
         finish_reason="stop"
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
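For quick verification of the new endpoint, here is a minimal client sketch against the `/v1/chat/completions` route exposed by `src/api_demo.py` (started with `python src/api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint`). The host and port come from the script's docstring (`http://localhost:8000`); the `model` value and the `choices[0]["message"]["content"]` response path are assumptions based on OpenAI's chat completion schema and should be checked against `http://localhost:8000/docs`.

```python
# Minimal sketch of a non-streaming request to the demo API.
# Assumption: the request/response bodies mirror OpenAI's chat completion schema;
# the server answers with whatever checkpoint it was launched with, so the
# "model" value below is a hypothetical placeholder. Verify fields at /docs.
import requests

payload = {
    "model": "local-model",  # hypothetical name, not taken from the repo
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},  # optional system prefix
        {"role": "user", "content": "Hello, who are you?"},
    ],
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=300)
resp.raise_for_status()
data = resp.json()
# Assuming OpenAI's response layout, the reply text sits in choices[0].message.content
print(data["choices"][0]["message"]["content"])
```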
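The `predict` coroutine above emits `chat.completion.chunk` objects as server-sent `data: {...}` lines, so a streaming client can be sketched as below. Whether streaming is switched on by a `stream: true` field in the request is an assumption carried over from OpenAI's API and is not shown in this diff; confirm it against the interactive docs.

```python
# Minimal sketch of consuming the streamed chunks emitted by predict().
# Assumptions: a "stream": true request field selects the streaming path
# (borrowed from OpenAI's API, not confirmed by this diff); the chunk layout
# (choices[0].delta.content) follows the ChatCompletionResponseStreamChoice
# model shown above, with unset fields omitted via exclude_unset=True.
import json
import requests

payload = {
    "model": "local-model",  # hypothetical name
    "messages": [{"role": "user", "content": "Write a haiku about GPUs."}],
    "stream": True,  # assumption: mirrors OpenAI's streaming flag
}

with requests.post("http://localhost:8000/v1/chat/completions", json=payload, stream=True, timeout=300) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        # Each SSE event is prefixed with "data: "; skip keep-alive blank lines.
        if not line or not line.startswith("data: "):
            continue
        chunk = json.loads(line[len("data: "):])
        delta = chunk["choices"][0].get("delta", {})
        if delta.get("content"):
            print(delta["content"], end="", flush=True)
print()
```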