finish agent

This commit is contained in:
hiyouga
2024-01-21 01:47:33 +08:00
parent 55f707196e
commit 3e982cc714
8 changed files with 105 additions and 41 deletions

View File

@@ -1,9 +1,11 @@
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
import json
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from ..chat import ChatModel
from ..data import Role
from ..extras.misc import torch_gc
from ..hparams import GeneratingArguments
from .common import get_save_dir
@@ -105,22 +107,37 @@ class WebChatModel(ChatModel):
self,
chatbot: List[Tuple[str, str]],
query: str,
history: List[Tuple[str, str]],
messages: Sequence[Tuple[str, str]],
system: str,
tools: str,
max_new_tokens: int,
top_p: float,
temperature: float,
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
chatbot.append([query, ""])
query_messages = messages + [{"role": Role.USER, "content": query}]
response = ""
for new_text in self.stream_chat(
query, history, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
):
response += new_text
new_history = history + [(query, response)]
chatbot[-1] = [query, self.postprocess(response)]
yield chatbot, new_history
if tools:
result = self.template.format_tools.extract(response)
else:
result = response
if isinstance(result, tuple):
name, arguments = result
arguments = json.loads(arguments)
tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
bot_text = "```json\n" + tool_call + "\n```"
else:
output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
bot_text = result
chatbot[-1] = [query, self.postprocess(bot_text)]
yield chatbot, output_messages
def postprocess(self, response: str) -> str:
blocks = response.split("```")

View File

@@ -17,7 +17,7 @@ def create_chat_box(
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
with gr.Box(visible=visible) as chat_box:
chatbot = gr.Chatbot()
history = gr.State([])
messages = gr.State([])
with gr.Row():
with gr.Column(scale=4):
system = gr.Textbox(show_label=False)
@@ -32,21 +32,21 @@ def create_chat_box(
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
tools.input(check_json_schema, [tools])
tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])
submit_btn.click(
engine.chatter.predict,
[chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
[chatbot, history],
[chatbot, query, messages, system, tools, max_new_tokens, top_p, temperature],
[chatbot, messages],
show_progress=True,
).then(lambda: gr.update(value=""), outputs=[query])
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages], show_progress=True)
return (
chat_box,
chatbot,
history,
messages,
dict(
system=system,
tools=tools,

View File

@@ -208,6 +208,8 @@ ALERTS = {
"zh": "展示模式不支持训练,请先复制到私人空间。",
},
"err_device_count": {"en": "Multiple GPUs are not supported yet.", "zh": "尚不支持多 GPU 训练。"},
"err_tool_name": {"en": "Tool name not found.", "zh": "工具名称未找到。"},
"err_json_schema": {"en": "Invalid JSON schema.", "zh": "Json 格式错误。"},
"info_aborting": {"en": "Aborted, wait for terminating...", "zh": "训练中断,正在等待线程结束……"},
"info_aborted": {"en": "Ready.", "zh": "准备就绪。"},
"info_finished": {"en": "Finished.", "zh": "训练完毕。"},

View File

@@ -8,6 +8,7 @@ import gradio as gr
from ..extras.packages import is_matplotlib_available
from ..extras.ploting import smooth
from .common import get_save_dir
from .locales import ALERTS
if TYPE_CHECKING:
@@ -40,11 +41,15 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
return gr.update(interactive=True)
def check_json_schema(text: str, lang: str) -> None:
    """Validate the tools definition entered in the Web UI's tools textbox.

    Emits a localized ``gr.Warning`` (never raises to the caller) when the
    text is not valid JSON, or when any tool definition lacks a "name" key.

    Args:
        text: JSON string expected to describe a list of tool definitions.
        lang: UI language key ("en"/"zh") used to look up the localized
            alert message in ``ALERTS``.
    """
    try:
        tools = json.loads(text)
        for tool in tools:
            # NOTE(review): assert is stripped under `python -O`; the warning
            # would then be skipped — confirm -O is never used, or raise instead.
            assert "name" in tool
    except AssertionError:
        gr.Warning(ALERTS["err_tool_name"][lang])
    except json.JSONDecodeError:
        gr.Warning(ALERTS["err_json_schema"][lang])
def gen_cmd(args: Dict[str, Any]) -> str: