finish agent

Former-commit-id: 3e982cc7146b408b359bb38ed4aa840183b23d72
hiyouga 2024-01-21 01:47:33 +08:00
parent 865f48f1c3
commit 1fbf83acc2
8 changed files with 105 additions and 41 deletions

View File

@@ -7,18 +7,20 @@ from typing import Any, Dict, Sequence
 from pydantic import BaseModel
 
 from ..chat import ChatModel
+from ..data import Role as DataRole
 from ..extras.misc import torch_gc
 from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
 from .protocol import (
+    ChatCompletionMessage,
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionResponseUsage,
     ChatCompletionStreamResponse,
-    ChatMessage,
-    DeltaMessage,
     Finish,
+    Function,
+    FunctionCall,
     ModelCard,
     ModelList,
     Role,
@@ -84,7 +86,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
 
         if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
-            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
 
         messages = [dictify(message) for message in request.messages]
         if len(messages) and messages[0]["role"] == Role.SYSTEM:
@@ -96,16 +98,21 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
         for i in range(len(messages)):
-            if messages[i]["role"] == Role.USER:
-                if i % 2 == 1:
-                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
-            elif messages[i]["role"] == Role.ASSISTANT:
-                if i % 2 == 0:
-                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
-            else:
-                raise NotImplementedError
+            if i % 2 == 0 and messages[i]["role"] not in [Role.USER, Role.TOOL]:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+            elif i % 2 == 1 and messages[i]["role"] not in [Role.ASSISTANT, Role.FUNCTION]:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+            elif messages[i]["role"] == Role.TOOL:
+                messages[i]["role"] = DataRole.OBSERVATION
 
-        tools = ""  # TODO: add tools
+        tool_list = request.tools
+        if len(tool_list):
+            try:
+                tools = json.dumps([tool_list[0]["function"]], ensure_ascii=False)
+            except Exception:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
+        else:
+            tools = ""
 
         async with semaphore:
             loop = asyncio.get_running_loop()
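For reference, a minimal sketch of the tools handling added above, assuming an OpenAI-style `tools` payload (the `get_weather` schema is hypothetical):

```python
import json

# Hypothetical OpenAI-style tools payload; the "function" schema below is
# an assumption for illustration, not part of this diff.
request_tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Query the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

# Mirrors the handler: only the first tool's "function" object is serialized
# and passed down to the template as a JSON string; no tools yields "".
tools = json.dumps([request_tools[0]["function"]], ensure_ascii=False) if request_tools else ""
print(tools)
```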
@@ -130,12 +137,24 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         prompt_length, response_length = 0, 0
         choices = []
         for i, response in enumerate(responses):
-            choices.append(
-                ChatCompletionResponseChoice(
-                    index=i,
-                    message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
-                    finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH,
-                )
-            )
+            if tools:
+                result = chat_model.template.format_tools.extract(response.response_text)
+            else:
+                result = response.response_text
+
+            if isinstance(result, tuple):
+                name, arguments = result
+                function = Function(name=name, arguments=arguments)
+                response_message = ChatCompletionMessage(
+                    role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
+                )
+                finish_reason = Finish.TOOL
+            else:
+                response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
+                finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
+
+            choices.append(
+                ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
+            )
             prompt_length = response.prompt_length
             response_length += response.response_length
@@ -152,7 +171,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
     ):
         choice_data = ChatCompletionResponseStreamChoice(
-            index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
+            index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield jsonify(chunk)
@@ -170,12 +189,14 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
                 continue
 
             choice_data = ChatCompletionResponseStreamChoice(
-                index=0, delta=DeltaMessage(content=new_text), finish_reason=None
+                index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None
             )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
             yield jsonify(chunk)
 
-        choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
+        )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield jsonify(chunk)
         yield "[DONE]"

View File

@@ -11,12 +11,15 @@ class Role(str, Enum):
     USER = "user"
     ASSISTANT = "assistant"
     SYSTEM = "system"
+    FUNCTION = "function"
+    TOOL = "tool"
 
 
 @unique
 class Finish(str, Enum):
     STOP = "stop"
     LENGTH = "length"
+    TOOL = "tool_calls"
 
 
 class ModelCard(BaseModel):
@@ -31,19 +34,32 @@ class ModelList(BaseModel):
     data: List[ModelCard] = []
 
 
+class Function(BaseModel):
+    name: str
+    arguments: str
+
+
+class FunctionCall(BaseModel):
+    id: Literal["call_default"] = "call_default"
+    type: Literal["function"] = "function"
+    function: Function
+
+
 class ChatMessage(BaseModel):
     role: Role
     content: str
 
 
-class DeltaMessage(BaseModel):
+class ChatCompletionMessage(BaseModel):
     role: Optional[Role] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[FunctionCall]] = None
 
 
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
+    tools: Optional[list] = []
     do_sample: bool = True
     temperature: Optional[float] = None
     top_p: Optional[float] = None
@@ -54,13 +70,13 @@ class ChatCompletionRequest(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
-    message: ChatMessage
+    message: ChatCompletionMessage
     finish_reason: Finish
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
-    delta: DeltaMessage
+    delta: ChatCompletionMessage
     finish_reason: Optional[Finish] = None
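To illustrate the new protocol models, a self-contained sketch that re-declares them in simplified form (`role` as a plain string instead of the Role enum) and serializes a tool-call message; the definitions in the diff above are authoritative:

```python
from typing import List, Literal, Optional

from pydantic import BaseModel

# Simplified re-declaration of the new models, for illustration only.
class Function(BaseModel):
    name: str
    arguments: str

class FunctionCall(BaseModel):
    id: Literal["call_default"] = "call_default"
    type: Literal["function"] = "function"
    function: Function

class ChatCompletionMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: Optional[List[FunctionCall]] = None

msg = ChatCompletionMessage(
    role="assistant",
    tool_calls=[FunctionCall(function=Function(name="get_weather", arguments='{"city": "Beijing"}'))],
)
print(msg.dict(exclude_none=True))  # pydantic-v1 style; use model_dump() on pydantic v2
```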

View File

@@ -37,9 +37,9 @@ class ChatModel:
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> Tuple[Dict[str, Any], int]:
-        new_messages = messages + [{"role": "assistant", "content": ""}]
+        paired_messages = messages + [{"role": "assistant", "content": ""}]
         prompt, _ = self.template.encode_oneturn(
-            tokenizer=self.tokenizer, messages=new_messages, system=system, tools=tools
+            tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
         )
         prompt_length = len(prompt)
         input_ids = torch.tensor([prompt], device=self.model.device)
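The rename hints at what the template encoder expects: the history is consumed as (user, assistant) pairs, so an empty assistant turn is appended to close the final pair. A standalone sketch of that pairing (illustrative only, not project code):

```python
from typing import Dict, List, Tuple

def pair_messages(messages: List[Dict[str, str]]) -> List[Tuple[Dict[str, str], Dict[str, str]]]:
    # Append an empty assistant turn so an odd-length history pairs up,
    # then group the turns two by two.
    paired = messages + [{"role": "assistant", "content": ""}]
    return [(paired[i], paired[i + 1]) for i in range(0, len(paired) - 1, 2)]

history = [{"role": "user", "content": "Hi"}]
print(pair_messages(history))
# [({'role': 'user', 'content': 'Hi'}, {'role': 'assistant', 'content': ''})]
```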

View File

@@ -74,6 +74,9 @@ class Formatter(ABC):
     def apply(self, **kwargs) -> SLOTS:
         ...
 
+    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
+        raise NotImplementedError
+
 
 @dataclass
 class EmptyFormatter(Formatter):
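The new `extract` hook returns either the raw model text or a `(name, arguments)` tuple. A hypothetical implementation for a ReAct-style output, just to show the contract (the regex and the output format are assumptions, not the project's actual tool format):

```python
import re
from typing import Tuple, Union

def extract_tool_call(content: str) -> Union[str, Tuple[str, str]]:
    # Parse outputs like: 'Action: get_weather\nAction Input: {"city": "Beijing"}'.
    # Return (name, json_arguments) when a call is found, else the raw text.
    match = re.search(r"Action:\s*(.+?)\s*Action Input:\s*(.+)", content, re.DOTALL)
    if not match:
        return content
    return match.group(1).strip(), match.group(2).strip()

print(extract_tool_call('Action: get_weather\nAction Input: {"city": "Beijing"}'))
```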

View File

@@ -1,9 +1,11 @@
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
+import json
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
 
 import gradio as gr
 from gradio.components import Component  # cannot use TYPE_CHECKING here
 
 from ..chat import ChatModel
+from ..data import Role
 from ..extras.misc import torch_gc
 from ..hparams import GeneratingArguments
 from .common import get_save_dir
@@ -105,22 +107,37 @@ class WebChatModel(ChatModel):
         self,
         chatbot: List[Tuple[str, str]],
         query: str,
-        history: List[Tuple[str, str]],
+        messages: Sequence[Tuple[str, str]],
         system: str,
         tools: str,
         max_new_tokens: int,
         top_p: float,
         temperature: float,
-    ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
+    ) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
         chatbot.append([query, ""])
+        query_messages = messages + [{"role": Role.USER, "content": query}]
         response = ""
         for new_text in self.stream_chat(
-            query, history, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
-            new_history = history + [(query, response)]
-            chatbot[-1] = [query, self.postprocess(response)]
-            yield chatbot, new_history
+            if tools:
+                result = self.template.format_tools.extract(response)
+            else:
+                result = response
+
+            if isinstance(result, tuple):
+                name, arguments = result
+                arguments = json.loads(arguments)
+                tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
+                output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
+                bot_text = "```json\n" + tool_call + "\n```"
+            else:
+                output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
+                bot_text = result
+
+            chatbot[-1] = [query, self.postprocess(bot_text)]
+            yield chatbot, output_messages
 
     def postprocess(self, response: str) -> str:
         blocks = response.split("```")
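A small sketch of the rendering branch above: a tuple from `extract` is normalized through `json.loads`/`json.dumps` and shown as a fenced JSON block, while plain text passes through unchanged (the sample values are hypothetical):

```python
import json

# Hypothetical extract() output; mirrors the tuple branch above.
result = ("get_weather", '{"city": "Beijing"}')

if isinstance(result, tuple):
    name, arguments = result
    tool_call = json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
    fence = "`" * 3  # build the fenced block without a literal backtick run
    bot_text = fence + "json\n" + tool_call + "\n" + fence
else:
    bot_text = result
print(bot_text)
```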

View File

@@ -17,7 +17,7 @@ def create_chat_box(
 ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
     with gr.Box(visible=visible) as chat_box:
         chatbot = gr.Chatbot()
-        history = gr.State([])
+        messages = gr.State([])
         with gr.Row():
             with gr.Column(scale=4):
                 system = gr.Textbox(show_label=False)
@@ -32,21 +32,21 @@ def create_chat_box(
                 top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
                 temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
 
-    tools.input(check_json_schema, [tools])
+    tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])
 
     submit_btn.click(
         engine.chatter.predict,
-        [chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
-        [chatbot, history],
+        [chatbot, query, messages, system, tools, max_new_tokens, top_p, temperature],
+        [chatbot, messages],
         show_progress=True,
     ).then(lambda: gr.update(value=""), outputs=[query])
 
-    clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
+    clear_btn.click(lambda: ([], []), outputs=[chatbot, messages], show_progress=True)
 
     return (
         chat_box,
         chatbot,
-        history,
+        messages,
         dict(
             system=system,
             tools=tools,
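For context, a minimal standalone sketch of the `gr.State` wiring pattern used here, assuming gradio 3.x (the echo function stands in for the real chat model):

```python
import gradio as gr

def predict(chatbot, query, messages):
    response = "echo: " + query  # stand-in for the model call
    chatbot = chatbot + [[query, response]]
    messages = messages + [
        {"role": "user", "content": query},
        {"role": "assistant", "content": response},
    ]
    return chatbot, messages

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    messages = gr.State([])  # structured history kept alongside the display pairs
    query = gr.Textbox(show_label=False)
    submit_btn = gr.Button("Submit")
    submit_btn.click(predict, [chatbot, query, messages], [chatbot, messages]).then(
        lambda: gr.update(value=""), outputs=[query]
    )

demo.launch()
```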

View File

@ -208,6 +208,8 @@ ALERTS = {
"zh": "展示模式不支持训练,请先复制到私人空间。", "zh": "展示模式不支持训练,请先复制到私人空间。",
}, },
"err_device_count": {"en": "Multiple GPUs are not supported yet.", "zh": "尚不支持多 GPU 训练。"}, "err_device_count": {"en": "Multiple GPUs are not supported yet.", "zh": "尚不支持多 GPU 训练。"},
"err_tool_name": {"en": "Tool name not found.", "zh": "工具名称未找到。"},
"err_json_schema": {"en": "Invalid JSON schema.", "zh": "Json 格式错误。"},
"info_aborting": {"en": "Aborted, wait for terminating...", "zh": "训练中断,正在等待线程结束……"}, "info_aborting": {"en": "Aborted, wait for terminating...", "zh": "训练中断,正在等待线程结束……"},
"info_aborted": {"en": "Ready.", "zh": "准备就绪。"}, "info_aborted": {"en": "Ready.", "zh": "准备就绪。"},
"info_finished": {"en": "Finished.", "zh": "训练完毕。"}, "info_finished": {"en": "Finished.", "zh": "训练完毕。"},

View File

@@ -8,6 +8,7 @@ import gradio as gr
 
 from ..extras.packages import is_matplotlib_available
 from ..extras.ploting import smooth
 from .common import get_save_dir
+from .locales import ALERTS
 
 
 if TYPE_CHECKING:
@@ -40,11 +41,15 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
     return gr.update(interactive=True)
 
 
-def check_json_schema(text: str) -> None:
+def check_json_schema(text: str, lang: str) -> None:
     try:
-        json.loads(text)
+        tools = json.loads(text)
+        for tool in tools:
+            assert "name" in tool
+    except AssertionError:
+        gr.Warning(ALERTS["err_tool_name"][lang])
     except json.JSONDecodeError:
-        gr.Warning("Invalid JSON schema")
+        gr.Warning(ALERTS["err_json_schema"][lang])
 
 
 def gen_cmd(args: Dict[str, Any]) -> str:
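An example of input the revised check accepts and rejects, mirroring its validation logic outside gradio (the tool definitions are hypothetical):

```python
import json

# A value that passes: a JSON list whose entries each carry a "name" key.
valid_tools = json.dumps(
    [{"name": "get_weather", "description": "Query the weather", "parameters": {"type": "object", "properties": {}}}]
)
invalid_tools = '[{"description": "missing the name key"}]'

for text in (valid_tools, invalid_tools):
    try:
        for tool in json.loads(text):
            assert "name" in tool
        print("ok")
    except AssertionError:
        print("err_tool_name")   # would surface ALERTS["err_tool_name"][lang]
    except json.JSONDecodeError:
        print("err_json_schema")  # would surface ALERTS["err_json_schema"][lang]
```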