diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 0d669d24..e89fb9c1 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -866,7 +866,11 @@ _register_template(
name="llava_next_qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
- format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
+ format_observation=StringFormatter(
+ slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"]
+ ),
+ format_tools=ToolFormatter(tool_format="qwen"),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
@@ -1050,7 +1054,11 @@ _register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
- format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
+ format_observation=StringFormatter(
+ slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"]
+ ),
+ format_tools=ToolFormatter(tool_format="qwen"),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
@@ -1062,7 +1070,11 @@ _register_template(
name="qwen2_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
- format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+ format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
+ format_observation=StringFormatter(
+ slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"]
+ ),
+ format_tools=ToolFormatter(tool_format="qwen"),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py
index 6be48767..a92701ed 100644
--- a/src/llamafactory/data/tool_utils.py
+++ b/src/llamafactory/data/tool_utils.py
@@ -51,6 +51,14 @@ LLAMA3_TOOL_PROMPT = (
"Do not use variables.\n\n{tool_text}"
)
+# Tool section appended to the system prompt for Qwen 2.5 style function calling.
+# The text is rendered with `str.format`: `{tool_text}` is the only placeholder, and
+# the doubled braces around the JSON example are literal braces after formatting.
+# Tool signatures are listed inside <tools></tools>; the model is asked to answer
+# with one <tool_call></tool_call> element per function call.
+QWEN_TOOL_PROMPT = (
+    "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
+    "You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
+    "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
+    """<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
+    """"arguments": <args-json-object>}}\n</tool_call><|im_end|>\n"""
+)
+
@dataclass
class ToolUtils(ABC):
@@ -79,11 +87,17 @@ class ToolUtils(ABC):
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts all the function calls from the assistant message.
+
+ It should be an inverse function of `function_formatter`.
"""
...
class DefaultToolUtils(ToolUtils):
+ r"""
+ Default tool using template.
+ """
+
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
@@ -149,6 +163,10 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils):
+ r"""
+ GLM-4 tool using template.
+ """
+
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
@@ -205,7 +223,7 @@ class Llama3ToolUtils(ToolUtils):
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
if len(functions) > 1:
- raise ValueError("Llama 3 does not support parallel functions.")
+ raise ValueError("Llama-3 does not support parallel functions.")
return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
@@ -224,6 +242,10 @@ class Llama3ToolUtils(ToolUtils):
class MistralToolUtils(ToolUtils):
+ r"""
+ Mistral v0.3 tool using template.
+ """
+
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
@@ -263,11 +285,61 @@ class MistralToolUtils(ToolUtils):
return results
+class QwenToolUtils(ToolUtils):
+    r"""
+    Qwen 2.5 tool using template.
+
+    Function calls are serialized as JSON objects wrapped in `<tool_call>` / `</tool_call>`
+    tags, one element per call, and parsed back from the same representation.
+    """
+
+    @override
+    @staticmethod
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+        r"""
+        Renders the tool definitions into the Qwen tool prompt.
+
+        Each tool is wrapped as `{"type": "function", "function": tool}` and dumped as
+        JSON on its own line (each entry is preceded by a newline) before being
+        substituted into `QWEN_TOOL_PROMPT`.
+        """
+        tool_text = "".join(
+            "\n" + json.dumps({"type": "function", "function": tool}, ensure_ascii=False) for tool in tools
+        )
+        return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
+
+    @override
+    @staticmethod
+    def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+        r"""
+        Formats the function calls as `<tool_call>` elements.
+
+        Returns a single slot in which the elements are joined by a newline.
+        `arguments` is inserted verbatim, so it is expected to already be a JSON string.
+        """
+        function_texts = [
+            "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
+            for name, arguments in functions
+        ]
+        return ["\n".join(function_texts)]
+
+    @override
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+        r"""
+        Extracts all `<tool_call>` elements from the assistant message.
+
+        Returns `content` unchanged when no element is found, when a payload is not
+        valid JSON, or when a payload is not an object with "name" and "arguments" keys.
+        """
+        # The lookahead requires every call to be followed, after optional whitespace,
+        # by another call or by the end of the text.
+        regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
+        tool_match: List[str] = regex.findall(content)
+        if not tool_match:
+            return content
+
+        results = []
+        for raw_call in tool_match:
+            try:
+                call = json.loads(raw_call.strip())
+            except json.JSONDecodeError:
+                return content
+
+            # json.loads may yield a list or a scalar; only an object can carry the keys.
+            if not isinstance(call, dict) or "name" not in call or "arguments" not in call:
+                return content
+
+            results.append(FunctionCall(call["name"], json.dumps(call["arguments"], ensure_ascii=False)))
+
+        return results
+
+
TOOLS = {
"default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(),
"llama3": Llama3ToolUtils(),
"mistral": MistralToolUtils(),
+ "qwen": QwenToolUtils(),
}
diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py
index c3865a85..a76e68a6 100644
--- a/tests/data/test_formatter.py
+++ b/tests/data/test_formatter.py
@@ -194,3 +194,53 @@ def test_mistral_multi_tool_extractor():
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]
+
+
+def test_qwen_function_formatter():
+    # A single call is rendered as one <tool_call> element, followed by the end-of-turn slot.
+    formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
+    tool_calls = json.dumps(FUNCTION)
+    assert formatter.apply(content=tool_calls) == [
+        """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
+        "<|im_end|>",
+    ]
+
+
+def test_qwen_multi_function_formatter():
+    # Parallel calls are rendered as newline-separated <tool_call> elements in one slot.
+    formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
+    tool_calls = json.dumps([FUNCTION] * 2)
+    assert formatter.apply(content=tool_calls) == [
+        """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
+        """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
+        "<|im_end|>",
+    ]
+
+
+def test_qwen_tool_formatter():
+    # The tool list is placed inside <tools></tools>, one wrapped JSON signature per line.
+    formatter = ToolFormatter(tool_format="qwen")
+    wrapped_tool = {"type": "function", "function": TOOLS[0]}
+    assert formatter.apply(content=json.dumps(TOOLS)) == [
+        "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
+        "You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
+        f"\n{json.dumps(wrapped_tool, ensure_ascii=False)}"
+        "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
+        """<tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, """
+        """"arguments": <args-json-object>}\n</tool_call><|im_end|>\n"""
+    ]
+
+
+def test_qwen_tool_extractor():
+    # A single <tool_call> element is parsed back into a (name, arguments-json) pair.
+    formatter = ToolFormatter(tool_format="qwen")
+    result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
+    assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
+
+
+def test_qwen_multi_tool_extractor():
+    # Newline-separated <tool_call> elements are parsed back in order of appearance.
+    formatter = ToolFormatter(tool_format="qwen")
+    result = (
+        """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
+        """<tool_call>\n{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}\n</tool_call>"""
+    )
+    assert formatter.extract(result) == [
+        ("test_tool", """{"foo": "bar", "size": 10}"""),
+        ("another_tool", """{"foo": "job", "size": 2}"""),
+    ]