From 1fbf83acc2f053a21f7173870ac2880995e4ef02 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sun, 21 Jan 2024 01:47:33 +0800 Subject: [PATCH] finish agent Former-commit-id: 3e982cc7146b408b359bb38ed4aa840183b23d72 --- src/llmtuner/api/app.py | 61 ++++++++++++++++-------- src/llmtuner/api/protocol.py | 22 +++++++-- src/llmtuner/chat/chat_model.py | 4 +- src/llmtuner/data/formatter.py | 3 ++ src/llmtuner/webui/chatter.py | 31 +++++++++--- src/llmtuner/webui/components/chatbot.py | 12 ++--- src/llmtuner/webui/locales.py | 2 + src/llmtuner/webui/utils.py | 11 +++-- 8 files changed, 105 insertions(+), 41 deletions(-) diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index 973620af..2147a1db 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -7,18 +7,20 @@ from typing import Any, Dict, Sequence from pydantic import BaseModel from ..chat import ChatModel +from ..data import Role as DataRole from ..extras.misc import torch_gc from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available from .protocol import ( + ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionResponseUsage, ChatCompletionStreamResponse, - ChatMessage, - DeltaMessage, Finish, + Function, + FunctionCall, ModelCard, ModelList, Role, @@ -84,7 +86,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") if len(request.messages) == 0 or request.messages[-1].role != Role.USER: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length") messages = [dictify(message) for message in request.messages] if len(messages) and messages[0]["role"] == Role.SYSTEM: @@ -96,16 +98,21 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") for i in range(len(messages)): - if messages[i]["role"] == Role.USER: - if i % 2 == 1: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") - elif messages[i]["role"] == Role.ASSISTANT: - if i % 2 == 0: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") - else: - raise NotImplementedError + if i % 2 == 0 and messages[i]["role"] not in [Role.USER, Role.TOOL]: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") + elif i % 2 == 1 and messages[i]["role"] not in [Role.ASSISTANT, Role.FUNCTION]: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") + elif messages[i]["role"] == Role.TOOL: + messages[i]["role"] = DataRole.OBSERVATION - tools = "" # TODO: add tools + tool_list = request.tools + if len(tool_list): + try: + tools = json.dumps([tool_list[0]["function"]], ensure_ascii=False) + except Exception: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") + else: + tools = "" async with semaphore: loop = asyncio.get_running_loop() @@ -130,12 +137,24 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": prompt_length, response_length = 0, 0 choices = [] for i, response in enumerate(responses): - choices.append( - ChatCompletionResponseChoice( - index=i, - message=ChatMessage(role=Role.ASSISTANT, content=response.response_text), - finish_reason=Finish.STOP if response.finish_reason == 
"stop" else Finish.LENGTH, + if tools: + result = chat_model.template.format_tools.extract(response.response_text) + else: + result = response.response_text + + if isinstance(result, tuple): + name, arguments = result + function = Function(name=name, arguments=arguments) + response_message = ChatCompletionMessage( + role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)] ) + finish_reason = Finish.TOOL + else: + response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result) + finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH + + choices.append( + ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason) ) prompt_length = response.prompt_length response_length += response.response_length @@ -152,7 +171,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest ): choice_data = ChatCompletionResponseStreamChoice( - index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None + index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None ) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) yield jsonify(chunk) @@ -170,12 +189,14 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": continue choice_data = ChatCompletionResponseStreamChoice( - index=0, delta=DeltaMessage(content=new_text), finish_reason=None + index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None ) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) yield jsonify(chunk) - choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP) + choice_data = ChatCompletionResponseStreamChoice( + index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP + ) chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) yield jsonify(chunk) yield "[DONE]" diff --git a/src/llmtuner/api/protocol.py b/src/llmtuner/api/protocol.py index 94c9acce..7a69acb4 100644 --- a/src/llmtuner/api/protocol.py +++ b/src/llmtuner/api/protocol.py @@ -11,12 +11,15 @@ class Role(str, Enum): USER = "user" ASSISTANT = "assistant" SYSTEM = "system" + FUNCTION = "function" + TOOL = "tool" @unique class Finish(str, Enum): STOP = "stop" LENGTH = "length" + TOOL = "tool_calls" class ModelCard(BaseModel): @@ -31,19 +34,32 @@ class ModelList(BaseModel): data: List[ModelCard] = [] +class Function(BaseModel): + name: str + arguments: str + + +class FunctionCall(BaseModel): + id: Literal["call_default"] = "call_default" + type: Literal["function"] = "function" + function: Function + + class ChatMessage(BaseModel): role: Role content: str -class DeltaMessage(BaseModel): +class ChatCompletionMessage(BaseModel): role: Optional[Role] = None content: Optional[str] = None + tool_calls: Optional[List[FunctionCall]] = None class ChatCompletionRequest(BaseModel): model: str messages: List[ChatMessage] + tools: Optional[list] = [] do_sample: bool = True temperature: Optional[float] = None top_p: Optional[float] = None @@ -54,13 +70,13 @@ class ChatCompletionRequest(BaseModel): class ChatCompletionResponseChoice(BaseModel): index: int - message: ChatMessage + message: ChatCompletionMessage finish_reason: Finish class ChatCompletionResponseStreamChoice(BaseModel): index: int - delta: DeltaMessage + delta: ChatCompletionMessage finish_reason: Optional[Finish] = None diff --git 
a/src/llmtuner/chat/chat_model.py b/src/llmtuner/chat/chat_model.py index 1a8f95aa..cbc831b2 100644 --- a/src/llmtuner/chat/chat_model.py +++ b/src/llmtuner/chat/chat_model.py @@ -37,9 +37,9 @@ class ChatModel: tools: Optional[str] = None, **input_kwargs, ) -> Tuple[Dict[str, Any], int]: - new_messages = messages + [{"role": "assistant", "content": ""}] + paired_messages = messages + [{"role": "assistant", "content": ""}] prompt, _ = self.template.encode_oneturn( - tokenizer=self.tokenizer, messages=new_messages, system=system, tools=tools + tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools ) prompt_length = len(prompt) input_ids = torch.tensor([prompt], device=self.model.device) diff --git a/src/llmtuner/data/formatter.py b/src/llmtuner/data/formatter.py index 934cb904..07438489 100644 --- a/src/llmtuner/data/formatter.py +++ b/src/llmtuner/data/formatter.py @@ -74,6 +74,9 @@ class Formatter(ABC): def apply(self, **kwargs) -> SLOTS: ... + def extract(self, content: str) -> Union[str, Tuple[str, str]]: + raise NotImplementedError + @dataclass class EmptyFormatter(Formatter): diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py index e16328a0..de8b0ca0 100644 --- a/src/llmtuner/webui/chatter.py +++ b/src/llmtuner/webui/chatter.py @@ -1,9 +1,11 @@ -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple +import json +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple import gradio as gr from gradio.components import Component # cannot use TYPE_CHECKING here from ..chat import ChatModel +from ..data import Role from ..extras.misc import torch_gc from ..hparams import GeneratingArguments from .common import get_save_dir @@ -105,22 +107,37 @@ class WebChatModel(ChatModel): self, chatbot: List[Tuple[str, str]], query: str, - history: List[Tuple[str, str]], + messages: Sequence[Tuple[str, str]], system: str, tools: str, max_new_tokens: int, top_p: float, temperature: float, - ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]: + ) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]: chatbot.append([query, ""]) + query_messages = messages + [{"role": Role.USER, "content": query}] response = "" for new_text in self.stream_chat( - query, history, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature + query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature ): response += new_text - new_history = history + [(query, response)] - chatbot[-1] = [query, self.postprocess(response)] - yield chatbot, new_history + if tools: + result = self.template.format_tools.extract(response) + else: + result = response + + if isinstance(result, tuple): + name, arguments = result + arguments = json.loads(arguments) + tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False) + output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}] + bot_text = "```json\n" + tool_call + "\n```" + else: + output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}] + bot_text = result + + chatbot[-1] = [query, self.postprocess(bot_text)] + yield chatbot, output_messages def postprocess(self, response: str) -> str: blocks = response.split("```") diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py index 554fb686..fe739dd0 100644 --- a/src/llmtuner/webui/components/chatbot.py +++ 
b/src/llmtuner/webui/components/chatbot.py @@ -17,7 +17,7 @@ def create_chat_box( ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]: with gr.Box(visible=visible) as chat_box: chatbot = gr.Chatbot() - history = gr.State([]) + messages = gr.State([]) with gr.Row(): with gr.Column(scale=4): system = gr.Textbox(show_label=False) @@ -32,21 +32,21 @@ def create_chat_box( top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01) temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01) - tools.input(check_json_schema, [tools]) + tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")]) submit_btn.click( engine.chatter.predict, - [chatbot, query, history, system, tools, max_new_tokens, top_p, temperature], - [chatbot, history], + [chatbot, query, messages, system, tools, max_new_tokens, top_p, temperature], + [chatbot, messages], show_progress=True, ).then(lambda: gr.update(value=""), outputs=[query]) - clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True) + clear_btn.click(lambda: ([], []), outputs=[chatbot, messages], show_progress=True) return ( chat_box, chatbot, - history, + messages, dict( system=system, tools=tools, diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py index 77f5854e..718dc57c 100644 --- a/src/llmtuner/webui/locales.py +++ b/src/llmtuner/webui/locales.py @@ -208,6 +208,8 @@ ALERTS = { "zh": "展示模式不支持训练,请先复制到私人空间。", }, "err_device_count": {"en": "Multiple GPUs are not supported yet.", "zh": "尚不支持多 GPU 训练。"}, + "err_tool_name": {"en": "Tool name not found.", "zh": "工具名称未找到。"}, + "err_json_schema": {"en": "Invalid JSON schema.", "zh": "Json 格式错误。"}, "info_aborting": {"en": "Aborted, wait for terminating...", "zh": "训练中断,正在等待线程结束……"}, "info_aborted": {"en": "Ready.", "zh": "准备就绪。"}, "info_finished": {"en": "Finished.", "zh": "训练完毕。"}, diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index c94fa3e6..1d63f23c 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -8,6 +8,7 @@ import gradio as gr from ..extras.packages import is_matplotlib_available from ..extras.ploting import smooth from .common import get_save_dir +from .locales import ALERTS if TYPE_CHECKING: @@ -40,11 +41,15 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]: return gr.update(interactive=True) -def check_json_schema(text: str) -> None: +def check_json_schema(text: str, lang: str) -> None: try: - json.loads(text) + tools = json.loads(text) + for tool in tools: + assert "name" in tool + except AssertionError: + gr.Warning(ALERTS["err_tool_name"][lang]) except json.JSONDecodeError: - gr.Warning("Invalid JSON schema") + gr.Warning(ALERTS["err_json_schema"][lang]) def gen_cmd(args: Dict[str, Any]) -> str:
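
Below is a minimal client-side sketch (not part of the patch above) of how the new tool-calling fields in ChatCompletionRequest and ChatCompletionMessage could be exercised end to end. The host/port, the /v1/chat/completions route, the model name, and the get_weather tool schema are illustrative assumptions; only the field names, the fixed "call_default" id, and the "tool_calls" finish reason come from the protocol changes in this commit.

# Hypothetical client sketch; assumes the API server from create_app() is running
# locally and that the `requests` package is available on the client side.
import json

import requests

payload = {
    "model": "default",  # placeholder model name
    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    # Matches ChatCompletionRequest.tools; the server serializes only the first
    # tool's "function" entry into the template's tools string.
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload).json()
choice = resp["choices"][0]

if choice["finish_reason"] == "tool_calls":
    # FunctionCall objects carry the fixed id "call_default" and a Function body
    # with "name" and JSON-encoded "arguments".
    call = choice["message"]["tool_calls"][0]["function"]
    print("tool:", call["name"], "args:", json.loads(call["arguments"]))
else:
    print(choice["message"]["content"])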