fix tool formatter, allow parallel function calls #4362

Former-commit-id: cd75b1fe9d91fb52a9ae6de7435302ff06b4d933
This commit is contained in:
hiyouga 2024-06-19 03:23:51 +08:00
parent 6db02615d4
commit bccc852f76
5 changed files with 207 additions and 86 deletions

View File

@ -92,9 +92,11 @@ def _process_request(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls): if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
name = message.tool_calls[0].function.name tool_calls = [
arguments = message.tool_calls[0].function.arguments {"name": tool_call.function.name, "argument": tool_call.function.arguments}
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False) for tool_call in message.tool_calls
]
content = json.dumps(tool_calls, ensure_ascii=False)
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content}) input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
elif isinstance(message.content, list): elif isinstance(message.content, list):
for input_item in message.content: for input_item in message.content:
@ -118,7 +120,7 @@ def _process_request(
if isinstance(tool_list, list) and len(tool_list): if isinstance(tool_list, list) and len(tool_list):
try: try:
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False) tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
except Exception: except json.JSONDecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else: else:
tools = None tools = None
@ -160,17 +162,16 @@ async def create_chat_completion_response(
choices = [] choices = []
for i, response in enumerate(responses): for i, response in enumerate(responses):
if tools: if tools:
result = chat_model.engine.template.format_tools.extract(response.response_text) result = chat_model.engine.template.extract_tool(response.response_text)
else: else:
result = response.response_text result = response.response_text
if isinstance(result, list): if isinstance(result, list):
tool_calls = [] tool_calls = []
for tool in result: for tool in result:
name, arguments = tool function = Function(name=tool[0], arguments=tool[1])
function = Function(name=name, arguments=arguments) tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
tool_calls.append(tool_call)
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls) response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
finish_reason = Finish.TOOL finish_reason = Finish.TOOL
else: else:

View File

@ -22,29 +22,20 @@ from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Uni
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
JSON_FORMAT_PROMPT = ( DEFAULT_TOOL_PROMPT = (
""", in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
)
TOOL_SYSTEM_PROMPT = (
"You have access to the following tools:\n{tool_text}" "You have access to the following tools:\n{tool_text}"
"Use the following format if using a tool:\n" "Use the following format if using a tool:\n"
"```\n" "```\n"
"Action: tool name (one of [{tool_names}]).\n" "Action: tool name (one of [{tool_names}]).\n"
"Action Input: the input to the tool{format_prompt}.\n" "Action Input: the input to the tool, in a JSON format representing the kwargs "
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
"```\n" "```\n"
) )
GLM4_TOOL_SUFFIX_PROMPT = (
"在调用上述函数时,请使用 Json 格式表示调用的参数。"
)
GLM4_TOOL_PROMPT = ( GLM4_TOOL_PROMPT = (
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持," "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"{tool_text}" "你的任务是针对用户的问题和要求提供适当的答复和支持。{tool_text}"
) )
@ -73,32 +64,19 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
) )
tool_names.append(tool["name"]) tool_names.append(tool["name"])
return TOOL_SYSTEM_PROMPT.format( return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
)
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
tool_name = tool["name"]
tool_text += f"\n\n## {tool_name}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n{GLM4_TOOL_SUFFIX_PROMPT}"
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*({.*?})(?=\nAction:|\Z)", re.DOTALL) regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|$)", re.DOTALL)
action_match = re.findall(regex, content) action_match: List[Tuple[str, str]] = re.findall(regex, content)
if not action_match: if not action_match:
return content return content
results = [] results = []
for match in action_match: for match in action_match:
tool_name, tool_input = match tool_name = match[0].strip()
tool_name = tool_name.strip() tool_input = match[1].strip().strip('"').strip("```")
tool_input = tool_input.strip().strip('"').strip("```")
try: try:
arguments = json.loads(tool_input) arguments = json.loads(tool_input)
results.append((tool_name, json.dumps(arguments, ensure_ascii=False))) results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
@ -108,18 +86,27 @@ def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
return results return results
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
    r"""
    Renders the tool descriptions into the GLM-4 tool prompt.

    Each tool contributes one section: a ``## <name>`` heading, the tool
    dumped as JSON (4-space indent, non-ASCII characters kept as-is) and a
    fixed sentence asking for JSON-formatted call arguments. The joined
    sections are substituted for ``{tool_text}`` in ``GLM4_TOOL_PROMPT``.

    Args:
        tools: tool descriptions; every entry must provide a ``"name"`` key.

    Returns:
        The formatted prompt string.
    """
    sections = []
    for tool in tools:
        name = tool["name"]
        body = json.dumps(tool, indent=4, ensure_ascii=False)
        sections.append(
            "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(name=name, body=body)
        )

    return GLM4_TOOL_PROMPT.format(tool_text="".join(sections))
def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
lines = content.strip().split("\n") if "\n" not in content:
if len(lines) != 2:
return content return content
tool_name = lines[0].strip()
tool_input = lines[1].strip() tool_name, tool_input = content.split("\n", maxsplit=1)
try: try:
arguments = json.loads(tool_input) arguments = json.loads(tool_input)
except json.JSONDecodeError: except json.JSONDecodeError:
return content return content
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
@dataclass @dataclass
@ -193,22 +180,28 @@ class FunctionFormatter(Formatter):
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content") content = kwargs.pop("content")
functions: List[Tuple[str, str]] = []
try: try:
function = json.loads(content) tool_calls = json.loads(content)
name = function["name"] if not isinstance(tool_calls, list): # parallel function call
arguments = json.dumps(function["arguments"], ensure_ascii=False) tool_calls = [tool_calls]
except Exception:
name, arguments = "", "" for tool_call in tool_calls:
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
except json.JSONDecodeError:
functions = []
elements = [] elements = []
for slot in self.slots: for name, arguments in functions:
if isinstance(slot, str): for slot in self.slots:
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments) if isinstance(slot, str):
elements.append(slot) slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elif isinstance(slot, (dict, set)): elements.append(slot)
elements.append(slot) elif isinstance(slot, (dict, set)):
else: elements.append(slot)
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) else:
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
return elements return elements
@ -216,29 +209,22 @@ class FunctionFormatter(Formatter):
@dataclass @dataclass
class ToolFormatter(Formatter): class ToolFormatter(Formatter):
def __post_init__(self): def __post_init__(self):
if self.tool_format is None: if self.tool_format == "default":
self._tool_formatter = default_tool_formatter
self._tool_extractor = default_tool_extractor
elif self.tool_format == "glm4":
self._tool_formatter = glm4_tool_formatter
self._tool_extractor = glm4_tool_extractor
else:
raise ValueError("Tool format was not found.") raise ValueError("Tool format was not found.")
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content") content = kwargs.pop("content")
try: try:
tools = json.loads(content) tools = json.loads(content)
if not len(tools): return [self._tool_formatter(tools) if len(tools) != 0 else ""]
return [""] except json.JSONDecodeError:
if self.tool_format == "default":
return [default_tool_formatter(tools)]
elif self.tool_format == "glm4":
return [glm4_tool_formatter(tools)]
else:
raise NotImplementedError
except Exception:
return [""] return [""]
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
if self.tool_format == "default": return self._tool_extractor(content)
return default_tool_extractor(content)
elif self.tool_format == "glm4":
return glm4_tool_extractor(content)
else:
raise NotImplementedError

View File

@ -79,6 +79,12 @@ class Template:
""" """
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
    r"""
    Extracts tool calls from a piece of text.

    Delegates to the ``extract`` method of this template's ``format_tools``
    formatter and returns its result unchanged.
    """
    tool_formatter = self.format_tools
    return tool_formatter.extract(content)
def _encode( def _encode(
self, self,
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
@ -100,7 +106,8 @@ class Template:
if i == 0 and (system or tools or self.force_system): if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else "" tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
elements += self.format_system.apply(content=(system + tool_text)) elements += self.format_system.apply(content=(system + tool_text))
elif i > 0 and i % 2 == 0:
if i > 0 and i % 2 == 0:
elements += self.format_separator.apply() elements += self.format_separator.apply()
if message["role"] == Role.USER.value: if message["role"] == Role.USER.value:
@ -191,7 +198,8 @@ class Llama2Template(Template):
if i == 0 and (system or tools or self.force_system): if i == 0 and (system or tools or self.force_system):
tool_text = self.format_tools.apply(content=tools)[0] if tools else "" tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
system_text = self.format_system.apply(content=(system + tool_text))[0] system_text = self.format_system.apply(content=(system + tool_text))[0]
elif i > 0 and i % 2 == 0:
if i > 0 and i % 2 == 0:
elements += self.format_separator.apply() elements += self.format_separator.apply()
if message["role"] == Role.USER.value: if message["role"] == Role.USER.value:
@ -259,7 +267,9 @@ def _register_template(
template_class = Llama2Template if name.startswith("llama2") else Template template_class = Llama2Template if name.startswith("llama2") else Template
default_user_formatter = StringFormatter(slots=["{{content}}"]) default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots) default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots) default_function_formatter = FunctionFormatter(
slots=["Action: {{name}}\nAction Input: {{arguments}}\n"] + eos_slots
)
default_tool_formatter = ToolFormatter(tool_format="default") default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter() default_separator_formatter = EmptyFormatter()
TEMPLATES[name] = template_class( TEMPLATES[name] = template_class(

View File

@ -140,16 +140,15 @@ class WebChatModel(ChatModel):
): ):
response += new_text response += new_text
if tools: if tools:
result = self.engine.template.format_tools.extract(response) result = self.engine.template.extract_tool(response)
else: else:
result = response result = response
if isinstance(result, tuple): if isinstance(result, list):
name, arguments = result tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result]
arguments = json.loads(arguments) tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False) output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_call}] bot_text = "```json\n" + tool_calls + "\n```"
bot_text = "```json\n" + tool_call + "\n```"
else: else:
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}] output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
bot_text = result bot_text = result

View File

@ -0,0 +1,125 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
def test_empty_formatter():
    """Check that EmptyFormatter.apply() hands back its slots when called without arguments."""
    rendered = EmptyFormatter(slots=["\n"]).apply()
    assert rendered == ["\n"]
def test_string_formatter():
    """Check that ``{{content}}`` is filled in while a placeholder-free slot passes through."""
    slots = ["<s>", "Human: {{content}}\nAssistant:"]
    rendered = StringFormatter(slots=slots).apply(content="Hi")
    assert rendered == ["<s>", "Human: Hi\nAssistant:"]
def test_function_formatter():
    """Check that a single tool call given as a JSON object renders into one slot."""
    call = {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}
    formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"])

    rendered = formatter.apply(content=json.dumps(call))

    expected = ['Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n']
    assert rendered == expected
def test_multi_function_formatter():
    """Check that a JSON list of two tool calls renders one slot per call, in order."""
    call = {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}
    formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"])

    rendered = formatter.apply(content=json.dumps([call, call]))

    rendered_call = 'Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n'
    assert rendered == [rendered_call, rendered_call]
def test_default_tool_formatter():
    """Check the prompt text produced by the "default" tool format for a single tool."""
    parameters = {
        "type": "object",
        "properties": {
            "foo": {"type": "string", "description": "foo_desc"},
            "bar": {"type": "number", "description": "bar_desc"},
        },
        "required": ["foo"],
    }
    tools = [{"name": "test_tool", "description": "tool_desc", "parameters": parameters}]

    expected_prompt = (
        "You have access to the following tools:\n"
        "> Tool Name: test_tool\n"
        "Tool Description: tool_desc\n"
        "Tool Args:\n"
        " - foo (string, required): foo_desc\n"
        " - bar (number): bar_desc\n\n"
        "Use the following format if using a tool:\n"
        "```\n"
        "Action: tool name (one of [test_tool]).\n"
        "Action Input: the input to the tool, in a JSON format representing the kwargs "
        """(e.g. ```{"input": "hello world", "num_beams": 5}```).\n"""
        "```\n"
    )

    rendered = ToolFormatter(tool_format="default").apply(content=json.dumps(tools))
    assert rendered == [expected_prompt]
def test_default_tool_extractor():
    """Check that one Action / Action Input pair is extracted as a (name, arguments) tuple."""
    generated = 'Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n'
    extracted = ToolFormatter(tool_format="default").extract(generated)
    assert extracted == [("test_tool", '{"foo": "bar", "size": 10}')]
def test_default_multi_tool_extractor():
    """Check that two consecutive Action blocks are extracted as two tuples, in order."""
    first_call = 'Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n'
    second_call = 'Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n'

    extracted = ToolFormatter(tool_format="default").extract(first_call + second_call)

    expected = [
        ("test_tool", '{"foo": "bar", "size": 10}'),
        ("another_tool", '{"foo": "job", "size": 2}'),
    ]
    assert extracted == expected
def test_glm4_tool_formatter():
    """Check the prompt text produced by the "glm4" tool format for a single tool."""
    parameters = {
        "type": "object",
        "properties": {
            "foo": {"type": "string", "description": "foo_desc"},
            "bar": {"type": "number", "description": "bar_desc"},
        },
        "required": ["foo"],
    }
    tools = [{"name": "test_tool", "description": "tool_desc", "parameters": parameters}]

    # The expected tool section embeds the tool itself, pretty-printed with a 4-space indent.
    tool_json = json.dumps(tools[0], indent=4)
    expected_prompt = (
        "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
        "你的任务是针对用户的问题和要求提供适当的答复和支持。"
        "\n\n## test_tool\n\n" + tool_json + "\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
    )

    rendered = ToolFormatter(tool_format="glm4").apply(content=json.dumps(tools))
    assert rendered == [expected_prompt]
def test_glm4_tool_extractor():
    """Check that a "name, newline, JSON arguments" response is extracted as one tuple."""
    generated = 'test_tool\n{"foo": "bar", "size": 10}\n'
    extracted = ToolFormatter(tool_format="glm4").extract(generated)
    assert extracted == [("test_tool", '{"foo": "bar", "size": 10}')]