mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-09-13 00:22:48 +08:00
Merge pull request #5473 from AlongWY/mistral
Support Mistral format tools Former-commit-id: 2fad3792d98f4181ae23e861c3d050fe1bcd8e4e
This commit is contained in:
commit
d8f6569be1
@ -168,7 +168,7 @@ async def create_chat_completion_response(
|
|||||||
if isinstance(result, list):
|
if isinstance(result, list):
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
for tool in result:
|
for tool in result:
|
||||||
function = Function(name=tool[0], arguments=tool[1])
|
function = Function(name=tool.name, arguments=tool.arguments)
|
||||||
tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
|
tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))
|
||||||
|
|
||||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
|
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
|
||||||
|
@ -16,16 +16,12 @@ import json
|
|||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from .data_utils import SLOTS
|
from .data_utils import SLOTS
|
||||||
from .tool_utils import get_tool_utils
|
from .tool_utils import FunctionCall, get_tool_utils
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .tool_utils import FunctionCall
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -98,19 +94,21 @@ class StringFormatter(Formatter):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class FunctionFormatter(Formatter):
|
class FunctionFormatter(Formatter):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.function_slots = get_tool_utils(self.tool_format).get_function_slots()
|
self.tool_utils = get_tool_utils(self.tool_format)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
content = kwargs.pop("content")
|
content = kwargs.pop("content")
|
||||||
functions: List[Tuple[str, str]] = []
|
functions: List["FunctionCall"] = []
|
||||||
try:
|
try:
|
||||||
tool_calls = json.loads(content)
|
tool_calls = json.loads(content)
|
||||||
if not isinstance(tool_calls, list): # parallel function call
|
if not isinstance(tool_calls, list): # parallel function call
|
||||||
tool_calls = [tool_calls]
|
tool_calls = [tool_calls]
|
||||||
|
|
||||||
for tool_call in tool_calls:
|
for tool_call in tool_calls:
|
||||||
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
functions.append(
|
||||||
|
FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
|
||||||
|
)
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
|
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
|
||||||
@ -118,15 +116,7 @@ class FunctionFormatter(Formatter):
|
|||||||
elements = []
|
elements = []
|
||||||
for slot in self.slots:
|
for slot in self.slots:
|
||||||
if slot == "{{content}}":
|
if slot == "{{content}}":
|
||||||
for name, arguments in functions:
|
elements += self.tool_utils.function_formatter(functions)
|
||||||
for slot in self.function_slots:
|
|
||||||
if isinstance(slot, str):
|
|
||||||
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
|
||||||
elements.append(slot)
|
|
||||||
elif isinstance(slot, (dict, set)):
|
|
||||||
elements.append(slot)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
|
||||||
else:
|
else:
|
||||||
elements.append(slot)
|
elements.append(slot)
|
||||||
|
|
||||||
|
@ -30,6 +30,7 @@ if TYPE_CHECKING:
|
|||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
from .formatter import SLOTS, Formatter
|
from .formatter import SLOTS, Formatter
|
||||||
from .mm_plugin import BasePlugin
|
from .mm_plugin import BasePlugin
|
||||||
|
from .tool_utils import FunctionCall
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@ -83,7 +84,7 @@ class Template:
|
|||||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||||
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
|
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
|
||||||
|
|
||||||
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||||
r"""
|
r"""
|
||||||
Extracts tool message.
|
Extracts tool message.
|
||||||
"""
|
"""
|
||||||
@ -244,7 +245,7 @@ def _register_template(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
template_class = Llama2Template if name.startswith("llama2") else Template
|
template_class = Llama2Template if any(k in name for k in ("llama2", "mistral")) else Template
|
||||||
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
|
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
|
||||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||||
default_assistant_formatter = StringFormatter(slots=default_slots)
|
default_assistant_formatter = StringFormatter(slots=default_slots)
|
||||||
@ -854,7 +855,11 @@ _register_template(
|
|||||||
# copied from mistral template
|
# copied from mistral template
|
||||||
_register_template(
|
_register_template(
|
||||||
name="llava_next_mistral",
|
name="llava_next_mistral",
|
||||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
|
||||||
|
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
|
||||||
|
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
|
||||||
|
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
|
||||||
|
format_tools=ToolFormatter(tool_format="mistral"),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||||
)
|
)
|
||||||
@ -902,7 +907,11 @@ _register_template(
|
|||||||
# copied from mistral template
|
# copied from mistral template
|
||||||
_register_template(
|
_register_template(
|
||||||
name="llava_next_video_mistral",
|
name="llava_next_video_mistral",
|
||||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
|
||||||
|
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
|
||||||
|
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
|
||||||
|
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
|
||||||
|
format_tools=ToolFormatter(tool_format="mistral"),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
|
mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
|
||||||
)
|
)
|
||||||
@ -939,7 +948,11 @@ _register_template(
|
|||||||
|
|
||||||
_register_template(
|
_register_template(
|
||||||
name="mistral",
|
name="mistral",
|
||||||
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
|
||||||
|
format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
|
||||||
|
format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
|
||||||
|
format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
|
||||||
|
format_tools=ToolFormatter(tool_format="mistral"),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -15,17 +15,18 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import namedtuple
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Any, Dict, List, NamedTuple, Tuple, Union
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from .data_utils import SLOTS
|
from .data_utils import SLOTS
|
||||||
|
|
||||||
|
|
||||||
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
|
class FunctionCall(NamedTuple):
|
||||||
|
name: str
|
||||||
|
arguments: str
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_TOOL_PROMPT = (
|
DEFAULT_TOOL_PROMPT = (
|
||||||
@ -38,13 +39,11 @@ DEFAULT_TOOL_PROMPT = (
|
|||||||
"```\n"
|
"```\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
GLM4_TOOL_PROMPT = (
|
GLM4_TOOL_PROMPT = (
|
||||||
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
||||||
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
|
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
LLAMA3_TOOL_PROMPT = (
|
LLAMA3_TOOL_PROMPT = (
|
||||||
"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
|
"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
|
||||||
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
|
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
|
||||||
@ -59,14 +58,6 @@ class ToolUtils(ABC):
|
|||||||
Base class for tool utilities.
|
Base class for tool utilities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@abstractmethod
|
|
||||||
def get_function_slots() -> SLOTS:
|
|
||||||
r"""
|
|
||||||
Gets a list of slots corresponding to a single function call.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||||
@ -75,21 +66,24 @@ class ToolUtils(ABC):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||||
|
r"""
|
||||||
|
Generates the assistant message including all the tool calls.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||||
r"""
|
r"""
|
||||||
Extracts all the function calls from the response message.
|
Extracts all the function calls from the assistant message.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class DefaultToolUtils(ToolUtils):
|
class DefaultToolUtils(ToolUtils):
|
||||||
@override
|
|
||||||
@staticmethod
|
|
||||||
def get_function_slots() -> SLOTS:
|
|
||||||
return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||||
@ -124,6 +118,15 @@ class DefaultToolUtils(ToolUtils):
|
|||||||
|
|
||||||
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||||
|
function_text = ""
|
||||||
|
for name, arguments in functions:
|
||||||
|
function_text += f"Action: {name}\nAction Input: {arguments}\n"
|
||||||
|
|
||||||
|
return [function_text]
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||||
@ -138,7 +141,7 @@ class DefaultToolUtils(ToolUtils):
|
|||||||
tool_input = match[1].strip().strip('"').strip("```")
|
tool_input = match[1].strip().strip('"').strip("```")
|
||||||
try:
|
try:
|
||||||
arguments = json.loads(tool_input)
|
arguments = json.loads(tool_input)
|
||||||
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
|
results.append(FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False)))
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return content
|
return content
|
||||||
|
|
||||||
@ -146,11 +149,6 @@ class DefaultToolUtils(ToolUtils):
|
|||||||
|
|
||||||
|
|
||||||
class GLM4ToolUtils(ToolUtils):
|
class GLM4ToolUtils(ToolUtils):
|
||||||
@override
|
|
||||||
@staticmethod
|
|
||||||
def get_function_slots() -> SLOTS:
|
|
||||||
return ["{{name}}\n{{arguments}}"]
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||||
@ -162,6 +160,14 @@ class GLM4ToolUtils(ToolUtils):
|
|||||||
|
|
||||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||||
|
if len(functions) > 1:
|
||||||
|
raise ValueError("GLM-4 does not support parallel functions.")
|
||||||
|
|
||||||
|
return [f"{functions[0].name}\n{functions[0].arguments}"]
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||||
@ -174,7 +180,7 @@ class GLM4ToolUtils(ToolUtils):
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return content
|
return content
|
||||||
|
|
||||||
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
return [FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
||||||
|
|
||||||
|
|
||||||
class Llama3ToolUtils(ToolUtils):
|
class Llama3ToolUtils(ToolUtils):
|
||||||
@ -184,11 +190,6 @@ class Llama3ToolUtils(ToolUtils):
|
|||||||
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
|
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@override
|
|
||||||
@staticmethod
|
|
||||||
def get_function_slots() -> SLOTS:
|
|
||||||
return ["""{"name": "{{name}}", "parameters": {{arguments}}}"""]
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||||
@ -200,6 +201,14 @@ class Llama3ToolUtils(ToolUtils):
|
|||||||
|
|
||||||
return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
|
return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||||
|
if len(functions) > 1:
|
||||||
|
raise ValueError("Llama 3 does not support parallel functions.")
|
||||||
|
|
||||||
|
return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||||
@ -211,13 +220,54 @@ class Llama3ToolUtils(ToolUtils):
|
|||||||
if "name" not in tool or "parameters" not in tool:
|
if "name" not in tool or "parameters" not in tool:
|
||||||
return content
|
return content
|
||||||
|
|
||||||
return [(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
|
return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
|
||||||
|
|
||||||
|
|
||||||
|
class MistralToolUtils(ToolUtils):
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||||
|
wrapped_tools = []
|
||||||
|
for tool in tools:
|
||||||
|
wrapped_tools.append({"type": "function", "function": tool})
|
||||||
|
|
||||||
|
return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||||
|
function_texts = []
|
||||||
|
for name, arguments in functions:
|
||||||
|
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
|
||||||
|
|
||||||
|
return ["[" + ", ".join(function_texts) + "]"]
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||||
|
try:
|
||||||
|
tools = json.loads(content.strip())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return content
|
||||||
|
|
||||||
|
if not isinstance(tools, list):
|
||||||
|
tools = [tools]
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for tool in tools:
|
||||||
|
if "name" not in tool or "arguments" not in tool:
|
||||||
|
return content
|
||||||
|
|
||||||
|
results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
TOOLS = {
|
TOOLS = {
|
||||||
"default": DefaultToolUtils(),
|
"default": DefaultToolUtils(),
|
||||||
"glm4": GLM4ToolUtils(),
|
"glm4": GLM4ToolUtils(),
|
||||||
"llama3": Llama3ToolUtils(),
|
"llama3": Llama3ToolUtils(),
|
||||||
|
"mistral": MistralToolUtils(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -157,7 +157,7 @@ class WebChatModel(ChatModel):
|
|||||||
result = response
|
result = response
|
||||||
|
|
||||||
if isinstance(result, list):
|
if isinstance(result, list):
|
||||||
tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result]
|
tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
|
||||||
tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
|
tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False)
|
||||||
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
|
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
|
||||||
bot_text = "```json\n" + tool_calls + "\n```"
|
bot_text = "```json\n" + tool_calls + "\n```"
|
||||||
|
@ -59,7 +59,7 @@ def test_multi_function_formatter():
|
|||||||
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
|
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
|
||||||
tool_calls = json.dumps([FUNCTION] * 2)
|
tool_calls = json.dumps([FUNCTION] * 2)
|
||||||
assert formatter.apply(content=tool_calls) == [
|
assert formatter.apply(content=tool_calls) == [
|
||||||
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
|
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
|
||||||
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
|
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
|
||||||
"</s>",
|
"</s>",
|
||||||
]
|
]
|
||||||
@ -112,7 +112,7 @@ def test_glm4_tool_formatter():
|
|||||||
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||||
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
||||||
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
|
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
|
||||||
f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
|
f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4, ensure_ascii=False)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -139,7 +139,7 @@ def test_llama3_tool_formatter():
|
|||||||
f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
|
f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
|
||||||
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
|
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
|
||||||
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
|
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
|
||||||
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4)}\n\n"
|
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4, ensure_ascii=False)}\n\n"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -147,3 +147,50 @@ def test_llama3_tool_extractor():
|
|||||||
formatter = ToolFormatter(tool_format="llama3")
|
formatter = ToolFormatter(tool_format="llama3")
|
||||||
result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
|
result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
|
||||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mistral_function_formatter():
|
||||||
|
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
|
||||||
|
tool_calls = json.dumps(FUNCTION)
|
||||||
|
assert formatter.apply(content=tool_calls) == [
|
||||||
|
"[TOOL_CALLS] ",
|
||||||
|
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
|
||||||
|
"</s>",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mistral_multi_function_formatter():
|
||||||
|
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
|
||||||
|
tool_calls = json.dumps([FUNCTION] * 2)
|
||||||
|
assert formatter.apply(content=tool_calls) == [
|
||||||
|
"[TOOL_CALLS] ",
|
||||||
|
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
|
||||||
|
"""{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
|
||||||
|
"</s>",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mistral_tool_formatter():
|
||||||
|
formatter = ToolFormatter(tool_format="mistral")
|
||||||
|
wrapped_tool = {"type": "function", "function": TOOLS[0]}
|
||||||
|
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||||
|
"[AVAILABLE_TOOLS] " + json.dumps([wrapped_tool], ensure_ascii=False) + "[/AVAILABLE_TOOLS]"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mistral_tool_extractor():
|
||||||
|
formatter = ToolFormatter(tool_format="mistral")
|
||||||
|
result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
|
||||||
|
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mistral_multi_tool_extractor():
|
||||||
|
formatter = ToolFormatter(tool_format="mistral")
|
||||||
|
result = (
|
||||||
|
"""[{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}, """
|
||||||
|
"""{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}]"""
|
||||||
|
)
|
||||||
|
assert formatter.extract(result) == [
|
||||||
|
("test_tool", """{"foo": "bar", "size": 10}"""),
|
||||||
|
("another_tool", """{"foo": "job", "size": 2}"""),
|
||||||
|
]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user