From b8b38a9adea8fa7cc229790ec13bb41860a2b655 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sat, 15 Jul 2023 19:26:20 +0800 Subject: [PATCH] create chat model Former-commit-id: 657cf0f55a7f0886bc837bdd44528971dc5e5caa --- src/api_demo.py | 1 - src/cli_demo.py | 49 ++++----------- src/llmtuner/__init__.py | 3 +- src/llmtuner/api/app.py | 66 ++++++-------------- src/llmtuner/chat/__init__.py | 1 + src/llmtuner/chat/stream_chat.py | 82 +++++++++++++++++++++++++ src/llmtuner/hparams/data_args.py | 2 +- src/llmtuner/hparams/finetuning_args.py | 2 +- 8 files changed, 117 insertions(+), 89 deletions(-) create mode 100644 src/llmtuner/chat/__init__.py create mode 100644 src/llmtuner/chat/stream_chat.py diff --git a/src/api_demo.py b/src/api_demo.py index f27df455..3041b2e1 100644 --- a/src/api_demo.py +++ b/src/api_demo.py @@ -3,7 +3,6 @@ # Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint # Visit http://localhost:8000/docs for document. - import uvicorn from llmtuner import create_app diff --git a/src/cli_demo.py b/src/cli_demo.py index 1a32e35c..dd91da77 100644 --- a/src/cli_demo.py +++ b/src/cli_demo.py @@ -2,46 +2,11 @@ # Implements stream chat in command line for fine-tuned models. # Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint -from threading import Thread -from transformers import TextIteratorStreamer - -from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor +from llmtuner import ChatModel, get_infer_args def main(): - model_args, data_args, finetuning_args, generating_args = get_infer_args() - model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) - - prompt_template = Template(data_args.prompt_template) - source_prefix = data_args.source_prefix if data_args.source_prefix else "" - - def predict_and_print(query, history: list) -> list: - input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"] - input_ids = input_ids.to(model.device) - - streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) - - gen_kwargs = generating_args.to_dict() - gen_kwargs.update({ - "input_ids": input_ids, - "logits_processor": get_logits_processor(), - "streamer": streamer - }) - - thread = Thread(target=model.generate, kwargs=gen_kwargs) - thread.start() - - print("Assistant: ", end="", flush=True) - - response = "" - for new_text in streamer: - print(new_text, end="", flush=True) - response += new_text - print() - - history = history + [(query, response)] - return history - + chat_model = ChatModel(*get_infer_args()) history = [] print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") @@ -62,7 +27,15 @@ def main(): print("History has been removed.") continue - history = predict_and_print(query, history) + print("Assistant: ", end="", flush=True) + + response = "" + for new_text in chat_model.stream_chat(query, history): + print(new_text, end="", flush=True) + response += new_text + print() + + history = history + [(query, response)] if __name__ == "__main__": diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py index bcbac9db..9785981a 100644 --- a/src/llmtuner/__init__.py +++ b/src/llmtuner/__init__.py @@ -1,6 +1,5 @@ from llmtuner.api import create_app -from llmtuner.extras.misc import get_logits_processor -from llmtuner.extras.template import Template +from llmtuner.chat import 
ChatModel from llmtuner.tuner import get_train_args, get_infer_args, load_model_and_tokenizer, run_pt, run_sft, run_rm, run_ppo diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index 3f31cb9a..12a4d95c 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -1,15 +1,13 @@ import uvicorn -from threading import Thread from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware -from transformers import TextIteratorStreamer from contextlib import asynccontextmanager from sse_starlette import EventSourceResponse -from typing import Any, Dict +from typing import List, Tuple -from llmtuner.tuner import get_infer_args, load_model_and_tokenizer -from llmtuner.extras.misc import get_logits_processor, torch_gc -from llmtuner.extras.template import Template +from llmtuner.tuner import get_infer_args +from llmtuner.extras.misc import torch_gc +from llmtuner.chat.stream_chat import ChatModel from llmtuner.api.protocol import ( ModelCard, ModelList, @@ -31,11 +29,7 @@ async def lifespan(app: FastAPI): # collects GPU memory def create_app(): - model_args, data_args, finetuning_args, generating_args = get_infer_args() - model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) - - prompt_template = Template(data_args.prompt_template) - source_prefix = data_args.source_prefix if data_args.source_prefix else "" + chat_model = ChatModel(*get_infer_args()) app = FastAPI(lifespan=lifespan) @@ -49,7 +43,6 @@ def create_app(): @app.get("/v1/models", response_model=ModelList) async def list_models(): - global model_args model_card = ModelCard(id="gpt-3.5-turbo") return ModelList(data=[model_card]) @@ -63,7 +56,7 @@ def create_app(): if len(prev_messages) > 0 and prev_messages[0].role == "system": prefix = prev_messages.pop(0).content else: - prefix = source_prefix + prefix = None history = [] if len(prev_messages) % 2 == 0: @@ -71,33 +64,18 @@ def create_app(): if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant": history.append([prev_messages[i].content, prev_messages[i+1].content]) - inputs = tokenizer([prompt_template.get_prompt(query, history, prefix)], return_tensors="pt") - inputs = inputs.to(model.device) - - gen_kwargs = generating_args.to_dict() - gen_kwargs.update({ - "input_ids": inputs["input_ids"], - "temperature": request.temperature if request.temperature else gen_kwargs["temperature"], - "top_p": request.top_p if request.top_p else gen_kwargs["top_p"], - "logits_processor": get_logits_processor() - }) - - if request.max_tokens: - gen_kwargs.pop("max_length", None) - gen_kwargs["max_new_tokens"] = request.max_tokens - if request.stream: - generate = predict(gen_kwargs, request.model) + generate = predict(query, history, prefix, request) return EventSourceResponse(generate, media_type="text/event-stream") - generation_output = model.generate(**gen_kwargs) - outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):] - response = tokenizer.decode(outputs, skip_special_tokens=True) + response, (prompt_length, response_length) = chat_model.chat( + query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens + ) usage = ChatCompletionResponseUsage( - prompt_tokens=len(inputs["input_ids"][0]), - completion_tokens=len(outputs), - total_tokens=len(inputs["input_ids"][0]) + len(outputs) + prompt_tokens=prompt_length, + completion_tokens=response_length, + total_tokens=prompt_length+response_length ) choice_data = ChatCompletionResponseChoice( @@ 
-108,22 +86,18 @@ def create_app(): return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion") - async def predict(gen_kwargs: Dict[str, Any], model_id: str): - streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) - gen_kwargs["streamer"] = streamer - - thread = Thread(target=model.generate, kwargs=gen_kwargs) - thread.start() - + async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest): choice_data = ChatCompletionResponseStreamChoice( index=0, delta=DeltaMessage(role="assistant"), finish_reason=None ) - chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk") yield chunk.json(exclude_unset=True, ensure_ascii=False) - for new_text in streamer: + for new_text in chat_model.stream_chat( + query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens + ): if len(new_text) == 0: continue @@ -132,7 +106,7 @@ def create_app(): delta=DeltaMessage(content=new_text), finish_reason=None ) - chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk") yield chunk.json(exclude_unset=True, ensure_ascii=False) choice_data = ChatCompletionResponseStreamChoice( @@ -140,7 +114,7 @@ def create_app(): delta=DeltaMessage(), finish_reason="stop" ) - chunk = ChatCompletionStreamResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk") yield chunk.json(exclude_unset=True, ensure_ascii=False) yield "[DONE]" diff --git a/src/llmtuner/chat/__init__.py b/src/llmtuner/chat/__init__.py new file mode 100644 index 00000000..ba240d05 --- /dev/null +++ b/src/llmtuner/chat/__init__.py @@ -0,0 +1 @@ +from llmtuner.chat.stream_chat import ChatModel diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py new file mode 100644 index 00000000..07fc7bf0 --- /dev/null +++ b/src/llmtuner/chat/stream_chat.py @@ -0,0 +1,82 @@ +from typing import Any, Dict, Generator, List, Optional, Tuple +from threading import Thread +from transformers import TextIteratorStreamer + +from llmtuner.extras.misc import get_logits_processor +from llmtuner.extras.template import Template +from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments +from llmtuner.tuner import load_model_and_tokenizer + + +class ChatModel: + + def __init__( + self, + model_args: ModelArguments, + data_args: DataArguments, + finetuning_args: FinetuningArguments, + generating_args: GeneratingArguments + ) -> None: + self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args) + self.template = Template(data_args.prompt_template) + self.source_prefix = data_args.source_prefix if data_args.source_prefix else "" + self.generating_args = generating_args + + def process_args( + self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs + ) -> Tuple[Dict[str, Any], int]: + prefix = prefix if prefix else self.source_prefix + + inputs = self.tokenizer([self.template.get_prompt(query, history, prefix)], 
return_tensors="pt") + inputs = inputs.to(self.model.device) + prompt_length = len(inputs["input_ids"][0]) + + temperature = input_kwargs.pop("temperature", None) + top_p = input_kwargs.pop("top_p", None) + top_k = input_kwargs.pop("top_k", None) + repetition_penalty = input_kwargs.pop("repetition_penalty", None) + max_length = input_kwargs.pop("max_length", None) + max_new_tokens = input_kwargs.pop("max_new_tokens", None) + + gen_kwargs = self.generating_args.to_dict() + gen_kwargs.update(dict( + input_ids=inputs["input_ids"], + temperature=temperature if temperature else gen_kwargs["temperature"], + top_p=top_p if top_p else gen_kwargs["top_p"], + top_k=top_k if top_k else gen_kwargs["top_k"], + repetition_penalty=repetition_penalty if repetition_penalty else gen_kwargs["repetition_penalty"], + logits_processor=get_logits_processor() + )) + + if max_length: + gen_kwargs.pop("max_new_tokens", None) + gen_kwargs["max_length"] = max_length + + if max_new_tokens: + gen_kwargs.pop("max_length", None) + gen_kwargs["max_new_tokens"] = max_new_tokens + + return gen_kwargs, prompt_length + + def chat( + self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs + ) -> Tuple[str, Tuple[int, int]]: + gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs) + generation_output = self.model.generate(**gen_kwargs) + outputs = generation_output.tolist()[0][prompt_length:] + response = self.tokenizer.decode(outputs, skip_special_tokens=True) + response_length = len(outputs) + return response, (prompt_length, response_length) + + def stream_chat( + self, query: str, history: List[Tuple[str, str]], prefix: Optional[str] = None, **input_kwargs + ) -> Generator[str, None, None]: + gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs) + streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) + gen_kwargs["streamer"] = streamer + + thread = Thread(target=self.model.generate, kwargs=gen_kwargs) + thread.start() + + for new_text in streamer: + yield new_text diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index df4c0557..7e68486b 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -29,7 +29,7 @@ class DataArguments: """ dataset: Optional[str] = field( default="alpaca_zh", - metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."} + metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."} ) dataset_dir: Optional[str] = field( default="data", diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py index 6f01ef29..23bb324b 100644 --- a/src/llmtuner/hparams/finetuning_args.py +++ b/src/llmtuner/hparams/finetuning_args.py @@ -45,7 +45,7 @@ class FinetuningArguments: ) lora_target: Optional[str] = field( default="q_proj,v_proj", - metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \ + metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \ LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \ Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
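
Editor's note: the sketch below is not part of the patch; it is a minimal usage example of the `ChatModel` API this commit introduces, assuming the script is launched with `--model_name_or_path` (and optionally `--checkpoint_dir`) on the command line, exactly as `src/cli_demo.py` does. The prompt strings are placeholders; `chat()` returning `(response, (prompt_length, response_length))` and `stream_chat()` yielding text chunks follow the signatures added in `src/llmtuner/chat/stream_chat.py`.

```python
# Minimal usage sketch for the ChatModel introduced by this patch.
# Assumes CLI args such as --model_name_or_path are provided, mirroring src/cli_demo.py.
from llmtuner import ChatModel, get_infer_args


def main():
    # get_infer_args() parses (model_args, data_args, finetuning_args, generating_args).
    chat_model = ChatModel(*get_infer_args())
    history = []

    # Blocking generation: returns the response plus (prompt_length, response_length).
    query = "Hello, who are you?"  # placeholder prompt
    response, (prompt_len, resp_len) = chat_model.chat(
        query, history, temperature=0.7, max_new_tokens=128
    )
    print(f"{response}  [{prompt_len} prompt tokens, {resp_len} generated tokens]")
    history.append((query, response))

    # Streaming generation: yields text chunks as they are produced.
    for new_text in chat_model.stream_chat("Tell me a joke.", history):
        print(new_text, end="", flush=True)
    print()


if __name__ == "__main__":
    main()
```

Both `api/app.py` and `cli_demo.py` in this patch consume the class the same way: construct once from `get_infer_args()`, then call `chat()` for one-shot completions (used by the non-streaming API path to fill the usage counters) or `stream_chat()` when incremental output is wanted.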