create chat model

Former-commit-id: 657cf0f55a7f0886bc837bdd44528971dc5e5caa
hiyouga 2023-07-15 19:26:20 +08:00
parent 8ba0996a53
commit b8b38a9ade
8 changed files with 117 additions and 89 deletions

View File

@@ -3,7 +3,6 @@
 # Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
 # Visit http://localhost:8000/docs for document.
 import uvicorn
 from llmtuner import create_app
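
The API demo itself only loses an unused import here; for orientation, a minimal launcher in its spirit could look like the sketch below. Only create_app is taken from this diff; the main() wrapper, host, and port 8000 are assumptions inferred from the http://localhost:8000/docs comment above.

import uvicorn
from llmtuner import create_app

def main():
    # create_app() wires the fine-tuned model into a FastAPI app (see the app.py changes below).
    app = create_app()
    # Port 8000 is assumed from the /docs URL in the header comment.
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)

if __name__ == "__main__":
    main()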

View File

@@ -2,46 +2,11 @@
 # Implements stream chat in command line for fine-tuned models.
 # Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
-from threading import Thread
-from transformers import TextIteratorStreamer
-from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor
+from llmtuner import ChatModel, get_infer_args
 def main():
-    model_args, data_args, finetuning_args, generating_args = get_infer_args()
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-    prompt_template = Template(data_args.prompt_template)
-    source_prefix = data_args.source_prefix if data_args.source_prefix else ""
-    def predict_and_print(query, history: list) -> list:
-        input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
-        input_ids = input_ids.to(model.device)
-        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-        gen_kwargs = generating_args.to_dict()
-        gen_kwargs.update({
-            "input_ids": input_ids,
-            "logits_processor": get_logits_processor(),
-            "streamer": streamer
-        })
-        thread = Thread(target=model.generate, kwargs=gen_kwargs)
-        thread.start()
-        print("Assistant: ", end="", flush=True)
-        response = ""
-        for new_text in streamer:
-            print(new_text, end="", flush=True)
-            response += new_text
-        print()
-        history = history + [(query, response)]
-        return history
+    chat_model = ChatModel(*get_infer_args())
     history = []
     print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

@@ -62,7 +27,15 @@ def main():
             print("History has been removed.")
             continue
-        history = predict_and_print(query, history)
+        print("Assistant: ", end="", flush=True)
+        response = ""
+        for new_text in chat_model.stream_chat(query, history):
+            print(new_text, end="", flush=True)
+            response += new_text
+        print()
+        history = history + [(query, response)]
 if __name__ == "__main__":

View File

@@ -1,6 +1,5 @@
 from llmtuner.api import create_app
-from llmtuner.extras.misc import get_logits_processor
-from llmtuner.extras.template import Template
+from llmtuner.chat import ChatModel
 from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo

View File

@@ -1,15 +1,13 @@
 import uvicorn
-from threading import Thread
 from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
-from transformers import TextIteratorStreamer
 from contextlib import asynccontextmanager
 from sse_starlette import EventSourceResponse
-from typing import Any, Dict
+from typing import List, Tuple
-from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
-from llmtuner.extras.misc import get_logits_processor, torch_gc
-from llmtuner.extras.template import Template
+from llmtuner.tuner import get_infer_args
+from llmtuner.extras.misc import torch_gc
+from llmtuner.chat.stream_chat import ChatModel
 from llmtuner.api.protocol import (
     ModelCard,
     ModelList,
@@ -31,11 +29,7 @@ async def lifespan(app: FastAPI): # collects GPU memory
 def create_app():
-    model_args, data_args, finetuning_args, generating_args = get_infer_args()
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-    prompt_template = Template(data_args.prompt_template)
-    source_prefix = data_args.source_prefix if data_args.source_prefix else ""
+    chat_model = ChatModel(*get_infer_args())
     app = FastAPI(lifespan=lifespan)
@@ -49,7 +43,6 @@ def create_app():
     @app.get("/v1/models", response_model=ModelList)
     async def list_models():
-        global model_args
         model_card = ModelCard(id="gpt-3.5-turbo")
         return ModelList(data=[model_card])
@@ -63,7 +56,7 @@ def create_app():
         if len(prev_messages) > 0 and prev_messages[0].role == "system":
             prefix = prev_messages.pop(0).content
         else:
-            prefix = source_prefix
+            prefix = None
         history = []
         if len(prev_messages) % 2 == 0:
@@ -71,33 +64,18 @@ def create_app():
                 if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
                     history.append([prev_messages[i].content, prev_messages[i+1].content])
-        inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt")
-        inputs = inputs.to(model.device)
-        gen_kwargs = generating_args.to_dict()
-        gen_kwargs.update({
-            "input_ids": inputs["input_ids"],
-            "temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
-            "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
-            "logits_processor": get_logits_processor()
-        })
-        if request.max_tokens:
-            gen_kwargs.pop("max_length", None)
-            gen_kwargs["max_new_tokens"] = request.max_tokens
         if request.stream:
-            generate = predict(gen_kwargs, request.model)
+            generate = predict(query, history, prefix, request)
             return EventSourceResponse(generate, media_type="text/event-stream")
-        generation_output = model.generate(**gen_kwargs)
-        outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
-        response = tokenizer.decode(outputs, skip_special_tokens=True)
+        response, (prompt_length, response_length) = chat_model.chat(
+            query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+        )
         usage = ChatCompletionResponseUsage(
-            prompt_tokens=len(inputs["input_ids"][0]),
-            completion_tokens=len(outputs),
-            total_tokens=len(inputs["input_ids"][0]) + len(outputs)
+            prompt_tokens=prompt_length,
+            completion_tokens=response_length,
+            total_tokens=prompt_length+response_length
         )
         choice_data = ChatCompletionResponseChoice(
@@ -108,22 +86,18 @@ def create_app():
         return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
-    async def predict(gen_kwargs: Dict[str, Any], model_id: str):
-        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-        gen_kwargs["streamer"] = streamer
-        thread = Thread(target=model.generate, kwargs=gen_kwargs)
-        thread.start()
+    async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(role="assistant"),
             finish_reason=None
         )
-        chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
         yield chunk.json(exclude_unset=True, ensure_ascii=False)
-        for new_text in streamer:
+        for new_text in chat_model.stream_chat(
+            query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+        ):
             if len(new_text) == 0:
                 continue
@@ -132,7 +106,7 @@ def create_app():
                 delta=DeltaMessage(content=new_text),
                 finish_reason=None
             )
-            chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+            chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
             yield chunk.json(exclude_unset=True, ensure_ascii=False)
         choice_data = ChatCompletionResponseStreamChoice(
@@ -140,7 +114,7 @@ def create_app():
             delta=DeltaMessage(),
             finish_reason="stop"
         )
-        chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
         yield chunk.json(exclude_unset=True, ensure_ascii=False)
         yield "[DONE]"

View File

@@ -0,0 +1 @@
+from llmtuner.chat.stream_chat import ChatModel

View File

@@ -0,0 +1,82 @@
+from typing import Any, Dict, Generator, List, Optional, Tuple
+from threading import Thread
+from transformers import TextIteratorStreamer
+
+from llmtuner.extras.misc import get_logits_processor
+from llmtuner.extras.template import Template
+from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+from llmtuner.tuner import load_model_and_tokenizer
+
+
+class ChatModel:
+
+    def __init__(
+        self,
+        model_args: ModelArguments,
+        data_args: DataArguments,
+        finetuning_args: FinetuningArguments,
+        generating_args: GeneratingArguments
+    ) -> None:
+        self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+        self.template = Template(data_args.prompt_template)
+        self.source_prefix = data_args.source_prefix if data_args.source_prefix else ""
+        self.generating_args = generating_args
+
+    def process_args(
+        self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
+    ) -> Tuple[Dict[str, Any], int]:
+        prefix = prefix if prefix else self.source_prefix
+
+        inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], return_tensors="pt")
+        inputs = inputs.to(self.model.device)
+        prompt_length = len(inputs["input_ids"][0])
+
+        temperature = input_kwargs.pop("temperature", None)
+        top_p = input_kwargs.pop("top_p", None)
+        top_k = input_kwargs.pop("top_k", None)
+        repetition_penalty = input_kwargs.pop("repetition_penalty", None)
+        max_length = input_kwargs.pop("max_length", None)
+        max_new_tokens = input_kwargs.pop("max_new_tokens", None)
+
+        gen_kwargs = self.generating_args.to_dict()
+        gen_kwargs.update(dict(
+            input_ids=inputs["input_ids"],
+            temperature=temperature if temperature else gen_kwargs["temperature"],
+            top_p=top_p if top_p else gen_kwargs["top_p"],
+            top_k=top_k if top_k else gen_kwargs["top_k"],
+            repetition_penalty=repetition_penalty if repetition_penalty else gen_kwargs["repetition_penalty"],
+            logits_processor=get_logits_processor()
+        ))
+
+        if max_length:
+            gen_kwargs.pop("max_new_tokens", None)
+            gen_kwargs["max_length"] = max_length
+
+        if max_new_tokens:
+            gen_kwargs.pop("max_length", None)
+            gen_kwargs["max_new_tokens"] = max_new_tokens
+
+        return gen_kwargs, prompt_length
+
+    def chat(
+        self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
+    ) -> Tuple[str, Tuple[int, int]]:
+        gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
+        generation_output = self.model.generate(**gen_kwargs)
+        outputs = generation_output.tolist()[0][prompt_length:]
+        response = self.tokenizer.decode(outputs, skip_special_tokens=True)
+        response_length = len(outputs)
+        return response, (prompt_length, response_length)
+
+    def stream_chat(
+        self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs
+    ) -> Generator[str, None, None]:
+        gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
+        streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+        gen_kwargs["streamer"] = streamer
+
+        thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
+        thread.start()
+
+        for new_text in streamer:
+            yield new_text
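
To make the new interface concrete, here is a short usage sketch. The prompts and sampling values are illustrative; the constructor, chat(), stream_chat(), and their return shapes are exactly as defined above.

from llmtuner import get_infer_args
from llmtuner.chat import ChatModel

chat_model = ChatModel(*get_infer_args())

# Blocking call: returns the decoded response plus (prompt_length, response_length),
# which app.py uses to fill the usage block of the OpenAI-style reply.
response, (prompt_len, resp_len) = chat_model.chat(
    "What is LoRA?", history=[], temperature=0.5, max_new_tokens=128
)
print(f"{response} ({prompt_len} prompt tokens, {resp_len} generated tokens)")

# Streaming call: generate() runs in a background thread and the streamer
# yields decoded chunks as they are produced.
for chunk in chat_model.stream_chat("And QLoRA?", history=[("What is LoRA?", response)]):
    print(chunk, end="", flush=True)
print()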

View File

@@ -29,7 +29,7 @@ class DataArguments:
     """
     dataset: Optional[str] = field(
         default="alpaca_zh",
-        metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
+        metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
     )
     dataset_dir: Optional[str] = field(
         default="data",

View File

@@ -45,7 +45,7 @@ class FinetuningArguments:
     )
     lora_target: Optional[str] = field(
         default="q_proj,v_proj",
-        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
+        metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
                   LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                   BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
                   Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
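
Both help-string fixes document the same convention: multi-valued options such as dataset and lora_target are passed as one comma-separated string and split into individual names downstream. A minimal illustration of the convention (the actual parsing lives elsewhere in the tuner code, not in this diff):

# e.g. --lora_target q_proj,k_proj,v_proj,o_proj on the command line
lora_target = "q_proj,k_proj,v_proj,o_proj"
target_modules = [name.strip() for name in lora_target.split(",")]
assert target_modules == ["q_proj", "k_proj", "v_proj", "o_proj"]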