update readme

Former-commit-id: 0697643358ade295f3c6eb239765d231b46afe0b
hiyouga 2023-06-23 00:17:05 +08:00
parent 0c7eb90f6b
commit cf29a9af35
3 changed files with 22 additions and 27 deletions

README.md

@@ -9,11 +9,13 @@
 ## Changelog

-[23/06/15] Now we support training the baichuan-7B model in this repo. Try `--model_name_or_path baichuan-inc/baichuan-7B` argument to use the baichuan-7B model.
+[23/06/22] Now we align the [demo API](src/api_demo.py) with [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format, so you can plug the fine-tuned model into arbitrary ChatGPT-based applications.
+
+[23/06/15] Now we support training the baichuan-7B model in this repo. Try `--model_name_or_path baichuan-inc/baichuan-7B` and `--lora_target W_pack` arguments to use the baichuan-7B model.

 [23/06/03] Now we support quantized training and inference (aka [QLoRA](https://github.com/artidoro/qlora)). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)

-[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
+[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments to use the BLOOMZ model.

 ## Supported Models
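Since the demo API now mirrors OpenAI's chat completion endpoint, an OpenAI-style client can point at it directly. A minimal sketch, assuming the server was started with `python src/api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint` and listens on `http://localhost:8000` as the usage comment in api_demo.py suggests; the model name is a placeholder and a pre-1.0 `openai` client is assumed:

```python
# Minimal sketch: talk to the local OpenAI-compatible demo API.
# Assumes api_demo.py is running on http://localhost:8000 and a pre-1.0 `openai` package.
import openai

openai.api_base = "http://localhost:8000/v1"  # point the client at the local demo server
openai.api_key = "none"                       # placeholder; the demo server is not expected to check it

response = openai.ChatCompletion.create(
    model="llama-7b-sft",                     # hypothetical identifier; the server serves its loaded checkpoint
    messages=[{"role": "user", "content": "Hello, who are you?"}],
)
print(response["choices"][0]["message"]["content"])
```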
@@ -75,9 +77,9 @@ huggingface-cli login
 - Python 3.8+ and PyTorch 1.13.1+
 - 🤗Transformers, Datasets, Accelerate, PEFT and TRL
-- protobuf, cpm_kernels and sentencepiece
 - jieba, rouge_chinese and nltk (used at evaluation)
 - gradio and mdtex2html (used in web_demo.py)
+- uvicorn and fastapi (used in api_demo.py)

 And **powerful GPUs**!
@@ -99,7 +101,7 @@ cd LLaMA-Efficient-Tuning
 pip install -r requirements.txt
 ```

-### LLaMA Weights Preparation
+### LLaMA Weights Preparation (optional)

 1. Download the weights of the LLaMA models.
 2. Convert them to HF format using the following command.
@@ -216,17 +218,10 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \

 We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.

-### CLI Demo
+### API / CLI / Web Demo

 ```bash
-python src/cli_demo.py \
-    --model_name_or_path path_to_your_model \
-    --checkpoint_dir path_to_checkpoint
-```
-
-### Web Demo
-
-```bash
-python src/web_demo.py \
+python src/xxx_demo.py \
     --model_name_or_path path_to_your_model \
     --checkpoint_dir path_to_checkpoint
 ```

requirements.txt

@@ -1,7 +1,4 @@
 torch>=1.13.1
-protobuf
-cpm_kernels
-sentencepiece
 transformers>=4.29.1
 datasets>=2.12.0
 accelerate>=0.19.0
@@ -12,3 +9,5 @@ rouge_chinese
 nltk
 gradio
 mdtex2html
+uvicorn
+fastapi

src/api_demo.py

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Implements API for fine-tuned models.
+# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
 # Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 # Visit http://localhost:8000/docs for document.
@@ -7,11 +7,10 @@
 import time
 import torch
 import uvicorn
-from fastapi import FastAPI, HTTPException
 from threading import Thread
-from contextlib import asynccontextmanager
 from pydantic import BaseModel, Field
+from fastapi import FastAPI, HTTPException
+from contextlib import asynccontextmanager
 from transformers import TextIteratorStreamer
 from starlette.responses import StreamingResponse
 from typing import Any, Dict, List, Literal, Optional, Union
@@ -68,14 +67,14 @@ class ChatCompletionResponseStreamChoice(BaseModel):

 class ChatCompletionResponse(BaseModel):
     model: str
-    object: str
+    object: Literal["chat.completion", "chat.completion.chunk"]
     choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))


 @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 async def create_chat_completion(request: ChatCompletionRequest):
-    global model, tokenizer, source_prefix
+    global model, tokenizer, source_prefix, generating_args

     if request.messages[-1].role != "user":
         raise HTTPException(status_code=400, detail="Invalid request")
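For context, a standalone sketch (class name here is illustrative, not taken from api_demo.py) of what the `Literal` annotation on `object` enforces: pydantic now rejects any value other than the two OpenAI object types.

```python
# Illustrative only: the Literal type constrains `object` to the two allowed values.
from typing import Literal
from pydantic import BaseModel, ValidationError

class Response(BaseModel):
    object: Literal["chat.completion", "chat.completion.chunk"]

print(Response(object="chat.completion.chunk"))  # accepted
try:
    Response(object="text_completion")           # rejected: not a chat completion object type
except ValidationError as err:
    print(err)
```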
@@ -83,7 +82,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
     prev_messages = request.messages[:-1]
     if len(prev_messages) > 0 and prev_messages[0].role == "system":
-        source_prefix = prev_messages.pop(0).content
+        prefix = prev_messages.pop(0).content
+    else:
+        prefix = source_prefix

     history = []
     if len(prev_messages) % 2 == 0:
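With this change, a system message in the request is used as the prompt prefix for that call only, while the server-wide `source_prefix` default is left untouched. A client-side sketch of that behavior; the URL and model name are assumptions, as above:

```python
# Sketch: supply a per-request system prompt; it overrides the server's default prefix
# for this call only. URL and model name are placeholders.
import requests

payload = {
    "model": "llama-7b-sft",
    "messages": [
        {"role": "system", "content": "You are a terse assistant. Answer in one sentence."},
        {"role": "user", "content": "What does LoRA fine-tuning do?"},
    ],
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```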
@@ -91,7 +92,7 @@
             if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
                 history.append([prev_messages[i].content, prev_messages[i+1].content])

-    inputs = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")
+    inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
     inputs = inputs.to(model.device)

     gen_kwargs = generating_args.to_dict()
@@ -134,7 +135,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(role="assistant"),
         finish_reason=None
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))

     for new_text in streamer:
@@ -146,7 +147,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
             delta=DeltaMessage(content=new_text),
             finish_reason=None
         )
-        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
         yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))

     choice_data = ChatCompletionResponseStreamChoice(
@@ -154,7 +155,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(),
         finish_reason="stop"
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))