diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 0d669d24..e89fb9c1 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -866,7 +866,11 @@ _register_template( name="llava_next_qwen", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"), + format_observation=StringFormatter( + slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"] + ), + format_tools=ToolFormatter(tool_format="qwen"), format_separator=EmptyFormatter(slots=["\n"]), default_system="You are a helpful assistant.", stop_words=["<|im_end|>"], @@ -1050,7 +1054,11 @@ _register_template( name="qwen", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"), + format_observation=StringFormatter( + slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"] + ), + format_tools=ToolFormatter(tool_format="qwen"), format_separator=EmptyFormatter(slots=["\n"]), default_system="You are a helpful assistant.", stop_words=["<|im_end|>"], @@ -1062,7 +1070,11 @@ _register_template( name="qwen2_vl", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), - format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"), + format_observation=StringFormatter( + slots=["<|im_start|>user\n\n{{content}}\n<|im_end|>\n<|im_start|>assistant\n"] + ), + format_tools=ToolFormatter(tool_format="qwen"), format_separator=EmptyFormatter(slots=["\n"]), default_system="You are a helpful assistant.", stop_words=["<|im_end|>"], diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py index 6be48767..a92701ed 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -51,6 +51,14 @@ LLAMA3_TOOL_PROMPT = ( "Do not use variables.\n\n{tool_text}" ) +QWEN_TOOL_PROMPT = ( + "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n{tool_text}" + "\n\n\nFor each function call, return a json object with function name and arguments within " + """ XML tags:\n\n{{"name": , """ + """"arguments": }}\n<|im_end|>\n""" +) + @dataclass class ToolUtils(ABC): @@ -79,11 +87,17 @@ class ToolUtils(ABC): def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: r""" Extracts all the function calls from the assistant message. + + It should be an inverse function of `function_formatter`. """ ... class DefaultToolUtils(ToolUtils): + r""" + Default tool using template. + """ + @override @staticmethod def tool_formatter(tools: List[Dict[str, Any]]) -> str: @@ -149,6 +163,10 @@ class DefaultToolUtils(ToolUtils): class GLM4ToolUtils(ToolUtils): + r""" + GLM-4 tool using template. + """ + @override @staticmethod def tool_formatter(tools: List[Dict[str, Any]]) -> str: @@ -205,7 +223,7 @@ class Llama3ToolUtils(ToolUtils): @staticmethod def function_formatter(functions: List["FunctionCall"]) -> SLOTS: if len(functions) > 1: - raise ValueError("Llama 3 does not support parallel functions.") + raise ValueError("Llama-3 does not support parallel functions.") return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'] @@ -224,6 +242,10 @@ class Llama3ToolUtils(ToolUtils): class MistralToolUtils(ToolUtils): + r""" + Mistral v0.3 tool using template. + """ + @override @staticmethod def tool_formatter(tools: List[Dict[str, Any]]) -> str: @@ -263,11 +285,61 @@ class MistralToolUtils(ToolUtils): return results +class QwenToolUtils(ToolUtils): + r""" + Qwen 2.5 tool using template. + """ + + @override + @staticmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: + tool_text = "" + for tool in tools: + wrapped_tool = {"type": "function", "function": tool} + tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False) + + return QWEN_TOOL_PROMPT.format(tool_text=tool_text) + + @override + @staticmethod + def function_formatter(functions: List["FunctionCall"]) -> SLOTS: + function_texts = [] + for name, arguments in functions: + function_texts.append( + "\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n" + ) + + return ["\n".join(function_texts)] + + @override + @staticmethod + def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: + regex = re.compile(r"(.+?)(?=\s*|\s*$)", re.DOTALL) + tool_match: List[str] = re.findall(regex, content) + if not tool_match: + return content + + results = [] + for tool in tool_match: + try: + tool = json.loads(tool.strip()) + except json.JSONDecodeError: + return content + + if "name" not in tool or "arguments" not in tool: + return content + + results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False))) + + return results + + TOOLS = { "default": DefaultToolUtils(), "glm4": GLM4ToolUtils(), "llama3": Llama3ToolUtils(), "mistral": MistralToolUtils(), + "qwen": QwenToolUtils(), } diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py index c3865a85..a76e68a6 100644 --- a/tests/data/test_formatter.py +++ b/tests/data/test_formatter.py @@ -194,3 +194,53 @@ def test_mistral_multi_tool_extractor(): ("test_tool", """{"foo": "bar", "size": 10}"""), ("another_tool", """{"foo": "job", "size": 2}"""), ] + + +def test_qwen_function_formatter(): + formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen") + tool_calls = json.dumps(FUNCTION) + assert formatter.apply(content=tool_calls) == [ + """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n""", + "<|im_end|>", + ] + + +def test_qwen_multi_function_formatter(): + formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen") + tool_calls = json.dumps([FUNCTION] * 2) + assert formatter.apply(content=tool_calls) == [ + """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n\n""" + """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n""", + "<|im_end|>", + ] + + +def test_qwen_tool_formatter(): + formatter = ToolFormatter(tool_format="qwen") + wrapped_tool = {"type": "function", "function": TOOLS[0]} + assert formatter.apply(content=json.dumps(TOOLS)) == [ + "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n" + f"\n{json.dumps(wrapped_tool, ensure_ascii=False)}" + "\n\n\nFor each function call, return a json object with function name and arguments within " + """ XML tags:\n\n{"name": , """ + """"arguments": }\n<|im_end|>\n""" + ] + + +def test_qwen_tool_extractor(): + formatter = ToolFormatter(tool_format="qwen") + result = """\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n""" + assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] + + +def test_qwen_multi_tool_extractor(): + formatter = ToolFormatter(tool_format="qwen") + result = ( + """\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n\n""" + """\n{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}\n""" + ) + assert formatter.extract(result) == [ + ("test_tool", """{"foo": "bar", "size": 10}"""), + ("another_tool", """{"foo": "job", "size": 2}"""), + ]