mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 06:12:50 +08:00)

commit cf29a9af35 (parent 0c7eb90f6b)

    update readme

    Former-commit-id: 0697643358ade295f3c6eb239765d231b46afe0b

README.md (21 lines changed):
@@ -9,11 +9,13 @@
 ## Changelog
 
-[23/06/15] Now we support training the baichuan-7B model in this repo. Try `--model_name_or_path baichuan-inc/baichuan-7B` argument to use the baichuan-7B model.
+[23/06/22] Now we align the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in arbitrary ChatGPT-based applications.
+
+[23/06/15] Now we support training the baichuan-7B model in this repo. Try `--model_name_or_path baichuan-inc/baichuan-7B` and `--lora_target W_pack` arguments to use the baichuan-7B model.
 
 [23/06/03] Now we support quantized training and inference (aka [QLoRA](https://github.com/artidoro/qlora)). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
 
-[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
+[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments to use the BLOOMZ model.
 
 ## Supported Models
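With the demo API aligned to OpenAI's chat completion format (the new [23/06/22] entry above), the fine-tuned model can be queried like any OpenAI endpoint. A minimal sketch, assuming the server was started with `python src/api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint` and listens on http://localhost:8000 as stated in the file's usage comment; the `model` label and the prompt are placeholders:

```python
# Sketch: call the OpenAI-style endpoint exposed by src/api_demo.py.
# Assumes the non-streaming choice carries an OpenAI-style "message" object.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "local-llama",  # hypothetical label
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What does LoRA fine-tuning do?"},
        ],
    },
)
print(resp.json()["choices"][0]["message"]["content"])
```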
@@ -75,9 +77,9 @@ huggingface-cli login
 
 - Python 3.8+ and PyTorch 1.13.1+
 - 🤗Transformers, Datasets, Accelerate, PEFT and TRL
-- protobuf, cpm_kernels and sentencepiece
 - jieba, rouge_chinese and nltk (used at evaluation)
 - gradio and mdtex2html (used in web_demo.py)
+- uvicorn and fastapi (used in api_demo.py)
 
 And **powerful GPUs**!
@@ -99,7 +101,7 @@ cd LLaMA-Efficient-Tuning
 pip install -r requirements.txt
 ```
 
-### LLaMA Weights Preparation
+### LLaMA Weights Preparation (optional)
 
 1. Download the weights of the LLaMA models.
 2. Convert them to HF format using the following command.
@@ -216,17 +218,10 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
 
 We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
 
-### CLI Demo
+### API / CLI / Web Demo
 
 ```bash
-python src/cli_demo.py \
-    --model_name_or_path path_to_your_model \
-    --checkpoint_dir path_to_checkpoint
-```
-
-### Web Demo
-
-```bash
-python src/web_demo.py \
+python src/xxx_demo.py \
     --model_name_or_path path_to_your_model \
     --checkpoint_dir path_to_checkpoint
 ```
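Because the demo API follows the OpenAI schema, an existing ChatGPT-style client can simply be pointed at it instead of api.openai.com. A sketch using the legacy `openai` Python package (pre-1.0 `ChatCompletion` interface; the package is not part of this repo's requirements, and the key and model values are placeholders):

```python
# Sketch: drop the locally served model into an OpenAI-compatible client.
import openai  # assumes openai<1.0 installed separately

openai.api_base = "http://localhost:8000/v1"  # point the client at api_demo.py
openai.api_key = "none"                       # placeholder; assumed not to be validated by the demo server

completion = openai.ChatCompletion.create(
    model="local-llama",  # hypothetical label
    messages=[{"role": "user", "content": "Summarize QLoRA in one sentence."}],
)
print(completion["choices"][0]["message"]["content"])
```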
requirements.txt:

@@ -1,7 +1,4 @@
 torch>=1.13.1
-protobuf
-cpm_kernels
-sentencepiece
 transformers>=4.29.1
 datasets>=2.12.0
 accelerate>=0.19.0
@@ -12,3 +9,5 @@ rouge_chinese
 nltk
 gradio
 mdtex2html
+uvicorn
+fastapi
src/api_demo.py:

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Implements API for fine-tuned models.
+# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
 # Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 # Visit http://localhost:8000/docs for document.
 
@@ -7,11 +7,10 @@
 import time
 import torch
 import uvicorn
-from fastapi import FastAPI, HTTPException
 from threading import Thread
-from contextlib import asynccontextmanager
 
 from pydantic import BaseModel, Field
+from fastapi import FastAPI, HTTPException
+from contextlib import asynccontextmanager
 from transformers import TextIteratorStreamer
 from starlette.responses import StreamingResponse
 from typing import Any, Dict, List, Literal, Optional, Union
@@ -68,14 +67,14 @@ class ChatCompletionResponseStreamChoice(BaseModel):
 
 class ChatCompletionResponse(BaseModel):
     model: str
-    object: str
+    object: Literal["chat.completion", "chat.completion.chunk"]
     choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
 
 
 @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
 async def create_chat_completion(request: ChatCompletionRequest):
-    global model, tokenizer, source_prefix
+    global model, tokenizer, source_prefix, generating_args
 
     if request.messages[-1].role != "user":
         raise HTTPException(status_code=400, detail="Invalid request")
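Typing `object` as a Literal means pydantic now rejects any value other than the two payload kinds this endpoint actually emits, instead of accepting an arbitrary string. A standalone sketch of that behaviour (not the repo's model, just the same pattern):

```python
# Sketch: a Literal-typed field narrows the schema at validation time.
from typing import Literal
from pydantic import BaseModel, ValidationError

class Response(BaseModel):
    object: Literal["chat.completion", "chat.completion.chunk"]

print(Response(object="chat.completion.chunk"))  # accepted
try:
    Response(object="text_completion")           # any other string is rejected
except ValidationError as err:
    print(err)
```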
@@ -83,7 +82,9 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     prev_messages = request.messages[:-1]
     if len(prev_messages) > 0 and prev_messages[0].role == "system":
-        source_prefix = prev_messages.pop(0).content
+        prefix = prev_messages.pop(0).content
+    else:
+        prefix = source_prefix
 
     history = []
     if len(prev_messages) % 2 == 0:
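The added else branch is the substance of this hunk: a system message now only overrides the prefix for the request that carries it, whereas the old code rebound the module-level `source_prefix` (declared global in the handler) so one request's system prompt leaked into every later one. A minimal before/after sketch, using plain dicts in place of the request objects:

```python
# Sketch of the fix: old_handler mutates the server-wide default,
# new_handler keeps it intact and uses a per-request variable.
source_prefix = "You are a helpful assistant."  # server-wide default

def old_handler(messages):
    global source_prefix
    if messages and messages[0]["role"] == "system":
        source_prefix = messages.pop(0)["content"]  # overwrites the default for everyone
    return source_prefix

def new_handler(messages):
    if messages and messages[0]["role"] == "system":
        prefix = messages.pop(0)["content"]         # per-request override
    else:
        prefix = source_prefix                      # fall back to the default
    return prefix

print(new_handler([{"role": "system", "content": "Answer in French."}]))  # "Answer in French."
print(new_handler([{"role": "user", "content": "Hi"}]))                   # default is still intact

print(old_handler([{"role": "system", "content": "Answer in French."}]))  # "Answer in French."
print(old_handler([{"role": "user", "content": "Hi"}]))                   # default has leaked
```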
@@ -91,7 +92,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
             if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
                 history.append([prev_messages[i].content, prev_messages[i+1].content])
 
-    inputs = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")
+    inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
     inputs = inputs.to(model.device)
 
     gen_kwargs = generating_args.to_dict()
@@ -134,7 +135,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(role="assistant"),
         finish_reason=None
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
     for new_text in streamer:
@@ -146,7 +147,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
             delta=DeltaMessage(content=new_text),
             finish_reason=None
         )
-        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
         yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
     choice_data = ChatCompletionResponseStreamChoice(
@@ -154,7 +155,7 @@ async def predict(gen_kwargs: Dict[str, Any], model_id: str):
         delta=DeltaMessage(),
         finish_reason="stop"
     )
-    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object = "chat.completion.chunk")
+    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
     yield "data: {}\n\n".format(chunk.json(exclude_unset=True, ensure_ascii=False))
 
 
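These last three hunks only normalize keyword-argument spacing (`object = ...` to `object=...`); the streaming protocol itself is unchanged: `predict()` emits Server-Sent-Events-style `data: {...}` lines of `chat.completion.chunk` objects and closes with a `finish_reason="stop"` chunk. A client-side sketch of consuming that stream, assuming the request schema exposes an OpenAI-style `stream` flag (not visible in these hunks) and the server runs on localhost:8000:

```python
# Sketch: read the streamed chat.completion.chunk events from the demo API.
import json
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "local-llama",  # hypothetical label
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,          # assumption: OpenAI-style stream switch
    },
    stream=True,
)
for raw in resp.iter_lines():
    line = raw.decode("utf-8")
    if not line.startswith("data: "):
        continue                 # skip blank keep-alive lines
    data = line[len("data: "):]
    if data == "[DONE]":         # guard for an OpenAI-style end sentinel, if the server sends one
        break
    delta = json.loads(data)["choices"][0]["delta"]
    print(delta.get("content", ""), end="", flush=True)
```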