mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
fix tool formatter, allow parallel function #4362
Former-commit-id: cd75b1fe9d91fb52a9ae6de7435302ff06b4d933
This commit is contained in:
parent
6db02615d4
commit
bccc852f76
@ -92,9 +92,11 @@ def _process_request(
|
|||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
|
||||||
|
|
||||||
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
|
||||||
name = message.tool_calls[0].function.name
|
tool_calls = [
|
||||||
arguments = message.tool_calls[0].function.arguments
|
{"name": tool_call.function.name, "argument": tool_call.function.arguments}
|
||||||
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
|
for tool_call in message.tool_calls
|
||||||
|
]
|
||||||
|
content = json.dumps(tool_calls, ensure_ascii=False)
|
||||||
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
|
input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
|
||||||
elif isinstance(message.content, list):
|
elif isinstance(message.content, list):
|
||||||
for input_item in message.content:
|
for input_item in message.content:
|
||||||
@ -118,7 +120,7 @@ def _process_request(
|
|||||||
if isinstance(tool_list, list) and len(tool_list):
|
if isinstance(tool_list, list) and len(tool_list):
|
||||||
try:
|
try:
|
||||||
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
|
tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
|
||||||
except Exception:
|
except json.JSONDecodeError:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
|
||||||
else:
|
else:
|
||||||
tools = None
|
tools = None
|
||||||
@ -160,17 +162,16 @@ async def create_chat_completion_response(
|
|||||||
choices = []
|
choices = []
|
||||||
for i, response in enumerate(responses):
|
for i, response in enumerate(responses):
|
||||||
if tools:
|
if tools:
|
||||||
result = chat_model.engine.template.format_tools.extract(response.response_text)
|
result = chat_model.engine.template.extract_tool(response.response_text)
|
||||||
else:
|
else:
|
||||||
result = response.response_text
|
result = response.response_text
|
||||||
|
|
||||||
if isinstance(result, list):
|
if isinstance(result, list):
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
for tool in result:
|
for tool in result:
|
||||||
name, arguments = tool
|
function = Function(name=tool[0], arguments=tool[1])
|
||||||
function = Function(name=name, arguments=arguments)
|
tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
|
||||||
tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
|
|
||||||
tool_calls.append(tool_call)
|
|
||||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
|
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
|
||||||
finish_reason = Finish.TOOL
|
finish_reason = Finish.TOOL
|
||||||
else:
|
else:
|
||||||
|
@ -22,29 +22,20 @@ from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Uni
|
|||||||
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
||||||
|
|
||||||
|
|
||||||
JSON_FORMAT_PROMPT = (
|
DEFAULT_TOOL_PROMPT = (
|
||||||
""", in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)"""
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
TOOL_SYSTEM_PROMPT = (
|
|
||||||
"You have access to the following tools:\n{tool_text}"
|
"You have access to the following tools:\n{tool_text}"
|
||||||
"Use the following format if using a tool:\n"
|
"Use the following format if using a tool:\n"
|
||||||
"```\n"
|
"```\n"
|
||||||
"Action: tool name (one of [{tool_names}]).\n"
|
"Action: tool name (one of [{tool_names}]).\n"
|
||||||
"Action Input: the input to the tool{format_prompt}.\n"
|
"Action Input: the input to the tool, in a JSON format representing the kwargs "
|
||||||
|
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
|
||||||
"```\n"
|
"```\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
GLM4_TOOL_SUFFIX_PROMPT = (
|
|
||||||
"在调用上述函数时,请使用 Json 格式表示调用的参数。"
|
|
||||||
)
|
|
||||||
|
|
||||||
GLM4_TOOL_PROMPT = (
|
GLM4_TOOL_PROMPT = (
|
||||||
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持,"
|
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
||||||
"{tool_text}"
|
"你的任务是针对用户的问题和要求提供适当的答复和支持。{tool_text}"
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -73,32 +64,19 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
|||||||
)
|
)
|
||||||
tool_names.append(tool["name"])
|
tool_names.append(tool["name"])
|
||||||
|
|
||||||
return TOOL_SYSTEM_PROMPT.format(
|
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
||||||
tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
|
||||||
tool_text = ""
|
|
||||||
for tool in tools:
|
|
||||||
tool_name = tool["name"]
|
|
||||||
tool_text += f"\n\n## {tool_name}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n{GLM4_TOOL_SUFFIX_PROMPT}"
|
|
||||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
|
||||||
|
|
||||||
|
|
||||||
def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*({.*?})(?=\nAction:|\Z)", re.DOTALL)
|
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|$)", re.DOTALL)
|
||||||
action_match = re.findall(regex, content)
|
action_match: List[Tuple[str, str]] = re.findall(regex, content)
|
||||||
if not action_match:
|
if not action_match:
|
||||||
return content
|
return content
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
for match in action_match:
|
for match in action_match:
|
||||||
tool_name, tool_input = match
|
tool_name = match[0].strip()
|
||||||
tool_name = tool_name.strip()
|
tool_input = match[1].strip().strip('"').strip("```")
|
||||||
tool_input = tool_input.strip().strip('"').strip("```")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
arguments = json.loads(tool_input)
|
arguments = json.loads(tool_input)
|
||||||
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
|
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
|
||||||
@ -108,18 +86,27 @@ def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||||
|
tool_text = ""
|
||||||
|
for tool in tools:
|
||||||
|
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
|
||||||
|
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||||
|
|
||||||
|
|
||||||
def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||||
lines = content.strip().split("\n")
|
if "\n" not in content:
|
||||||
if len(lines) != 2:
|
|
||||||
return content
|
return content
|
||||||
tool_name = lines[0].strip()
|
|
||||||
tool_input = lines[1].strip()
|
tool_name, tool_input = content.split("\n", maxsplit=1)
|
||||||
try:
|
try:
|
||||||
arguments = json.loads(tool_input)
|
arguments = json.loads(tool_input)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return content
|
return content
|
||||||
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
|
||||||
|
|
||||||
|
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -193,22 +180,28 @@ class FunctionFormatter(Formatter):
|
|||||||
|
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
content = kwargs.pop("content")
|
content = kwargs.pop("content")
|
||||||
|
functions: List[Tuple[str, str]] = []
|
||||||
try:
|
try:
|
||||||
function = json.loads(content)
|
tool_calls = json.loads(content)
|
||||||
name = function["name"]
|
if not isinstance(tool_calls, list): # parallel function call
|
||||||
arguments = json.dumps(function["arguments"], ensure_ascii=False)
|
tool_calls = [tool_calls]
|
||||||
except Exception:
|
|
||||||
name, arguments = "", ""
|
for tool_call in tool_calls:
|
||||||
|
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
functions = []
|
||||||
|
|
||||||
elements = []
|
elements = []
|
||||||
for slot in self.slots:
|
for name, arguments in functions:
|
||||||
if isinstance(slot, str):
|
for slot in self.slots:
|
||||||
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
if isinstance(slot, str):
|
||||||
elements.append(slot)
|
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
||||||
elif isinstance(slot, (dict, set)):
|
elements.append(slot)
|
||||||
elements.append(slot)
|
elif isinstance(slot, (dict, set)):
|
||||||
else:
|
elements.append(slot)
|
||||||
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
else:
|
||||||
|
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
||||||
|
|
||||||
return elements
|
return elements
|
||||||
|
|
||||||
@ -216,29 +209,22 @@ class FunctionFormatter(Formatter):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ToolFormatter(Formatter):
|
class ToolFormatter(Formatter):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tool_format is None:
|
if self.tool_format == "default":
|
||||||
|
self._tool_formatter = default_tool_formatter
|
||||||
|
self._tool_extractor = default_tool_extractor
|
||||||
|
elif self.tool_format == "glm4":
|
||||||
|
self._tool_formatter = glm4_tool_formatter
|
||||||
|
self._tool_extractor = glm4_tool_extractor
|
||||||
|
else:
|
||||||
raise ValueError("Tool format was not found.")
|
raise ValueError("Tool format was not found.")
|
||||||
|
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
content = kwargs.pop("content")
|
content = kwargs.pop("content")
|
||||||
try:
|
try:
|
||||||
tools = json.loads(content)
|
tools = json.loads(content)
|
||||||
if not len(tools):
|
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
|
||||||
return [""]
|
except json.JSONDecodeError:
|
||||||
|
|
||||||
if self.tool_format == "default":
|
|
||||||
return [default_tool_formatter(tools)]
|
|
||||||
elif self.tool_format == "glm4":
|
|
||||||
return [glm4_tool_formatter(tools)]
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
except Exception:
|
|
||||||
return [""]
|
return [""]
|
||||||
|
|
||||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||||
if self.tool_format == "default":
|
return self._tool_extractor(content)
|
||||||
return default_tool_extractor(content)
|
|
||||||
elif self.tool_format == "glm4":
|
|
||||||
return glm4_tool_extractor(content)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
@ -79,6 +79,12 @@ class Template:
|
|||||||
"""
|
"""
|
||||||
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len)
|
||||||
|
|
||||||
|
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||||
|
r"""
|
||||||
|
Extracts tool message.
|
||||||
|
"""
|
||||||
|
return self.format_tools.extract(content)
|
||||||
|
|
||||||
def _encode(
|
def _encode(
|
||||||
self,
|
self,
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
@ -100,7 +106,8 @@ class Template:
|
|||||||
if i == 0 and (system or tools or self.force_system):
|
if i == 0 and (system or tools or self.force_system):
|
||||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||||
elements += self.format_system.apply(content=(system + tool_text))
|
elements += self.format_system.apply(content=(system + tool_text))
|
||||||
elif i > 0 and i % 2 == 0:
|
|
||||||
|
if i > 0 and i % 2 == 0:
|
||||||
elements += self.format_separator.apply()
|
elements += self.format_separator.apply()
|
||||||
|
|
||||||
if message["role"] == Role.USER.value:
|
if message["role"] == Role.USER.value:
|
||||||
@ -191,7 +198,8 @@ class Llama2Template(Template):
|
|||||||
if i == 0 and (system or tools or self.force_system):
|
if i == 0 and (system or tools or self.force_system):
|
||||||
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
||||||
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
||||||
elif i > 0 and i % 2 == 0:
|
|
||||||
|
if i > 0 and i % 2 == 0:
|
||||||
elements += self.format_separator.apply()
|
elements += self.format_separator.apply()
|
||||||
|
|
||||||
if message["role"] == Role.USER.value:
|
if message["role"] == Role.USER.value:
|
||||||
@ -259,7 +267,9 @@ def _register_template(
|
|||||||
template_class = Llama2Template if name.startswith("llama2") else Template
|
template_class = Llama2Template if name.startswith("llama2") else Template
|
||||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||||
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
||||||
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
|
default_function_formatter = FunctionFormatter(
|
||||||
|
slots=["Action: {{name}}\nAction Input: {{arguments}}\n"] + eos_slots
|
||||||
|
)
|
||||||
default_tool_formatter = ToolFormatter(tool_format="default")
|
default_tool_formatter = ToolFormatter(tool_format="default")
|
||||||
default_separator_formatter = EmptyFormatter()
|
default_separator_formatter = EmptyFormatter()
|
||||||
TEMPLATES[name] = template_class(
|
TEMPLATES[name] = template_class(
|
||||||
|
@ -140,16 +140,15 @@ class WebChatModel(ChatModel):
|
|||||||
):
|
):
|
||||||
response += new_text
|
response += new_text
|
||||||
if tools:
|
if tools:
|
||||||
result = self.engine.template.format_tools.extract(response)
|
result = self.engine.template.extract_tool(response)
|
||||||
else:
|
else:
|
||||||
result = response
|
result = response
|
||||||
|
|
||||||
if isinstance(result, tuple):
|
if isinstance(result, list):
|
||||||
name, arguments = result
|
tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result]
|
||||||
arguments = json.loads(arguments)
|
tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
|
||||||
tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
|
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
|
||||||
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_call}]
|
bot_text = "```json\n" + tool_calls + "\n```"
|
||||||
bot_text = "```json\n" + tool_call + "\n```"
|
|
||||||
else:
|
else:
|
||||||
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
|
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
|
||||||
bot_text = result
|
bot_text = result
|
||||||
|
125
tests/data/test_formatter.py
Normal file
125
tests/data/test_formatter.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
# Copyright 2024 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_formatter():
|
||||||
|
formatter = EmptyFormatter(slots=["\n"])
|
||||||
|
assert formatter.apply() == ["\n"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_formatter():
|
||||||
|
formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
|
||||||
|
assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_formatter():
|
||||||
|
formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"])
|
||||||
|
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
|
||||||
|
assert formatter.apply(content=tool_calls) == [
|
||||||
|
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n"""
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_function_formatter():
|
||||||
|
formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"])
|
||||||
|
tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2)
|
||||||
|
assert formatter.apply(content=tool_calls) == [
|
||||||
|
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
|
||||||
|
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_tool_formatter():
|
||||||
|
formatter = ToolFormatter(tool_format="default")
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"name": "test_tool",
|
||||||
|
"description": "tool_desc",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"foo": {"type": "string", "description": "foo_desc"},
|
||||||
|
"bar": {"type": "number", "description": "bar_desc"},
|
||||||
|
},
|
||||||
|
"required": ["foo"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
assert formatter.apply(content=json.dumps(tools)) == [
|
||||||
|
"You have access to the following tools:\n"
|
||||||
|
"> Tool Name: test_tool\n"
|
||||||
|
"Tool Description: tool_desc\n"
|
||||||
|
"Tool Args:\n"
|
||||||
|
" - foo (string, required): foo_desc\n"
|
||||||
|
" - bar (number): bar_desc\n\n"
|
||||||
|
"Use the following format if using a tool:\n"
|
||||||
|
"```\n"
|
||||||
|
"Action: tool name (one of [test_tool]).\n"
|
||||||
|
"Action Input: the input to the tool, in a JSON format representing the kwargs "
|
||||||
|
"""(e.g. ```{"input": "hello world", "num_beams": 5}```).\n"""
|
||||||
|
"```\n"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_tool_extractor():
|
||||||
|
formatter = ToolFormatter(tool_format="default")
|
||||||
|
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
|
||||||
|
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_multi_tool_extractor():
|
||||||
|
formatter = ToolFormatter(tool_format="default")
|
||||||
|
result = (
|
||||||
|
"""Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
|
||||||
|
"""Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
|
||||||
|
)
|
||||||
|
assert formatter.extract(result) == [
|
||||||
|
("test_tool", """{"foo": "bar", "size": 10}"""),
|
||||||
|
("another_tool", """{"foo": "job", "size": 2}"""),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_glm4_tool_formatter():
|
||||||
|
formatter = ToolFormatter(tool_format="glm4")
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"name": "test_tool",
|
||||||
|
"description": "tool_desc",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"foo": {"type": "string", "description": "foo_desc"},
|
||||||
|
"bar": {"type": "number", "description": "bar_desc"},
|
||||||
|
},
|
||||||
|
"required": ["foo"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
assert formatter.apply(content=json.dumps(tools)) == [
|
||||||
|
"你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
||||||
|
"你的任务是针对用户的问题和要求提供适当的答复和支持。"
|
||||||
|
"\n\n## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
|
||||||
|
json.dumps(tools[0], indent=4)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_glm4_tool_extractor():
|
||||||
|
formatter = ToolFormatter(tool_format="glm4")
|
||||||
|
result = """test_tool\n{"foo": "bar", "size": 10}\n"""
|
||||||
|
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
Loading…
x
Reference in New Issue
Block a user