diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py index 8ad1466a..2c34e1a6 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -93,6 +93,7 @@ class DefaultToolUtils(ToolUtils): tool_text = "" tool_names = [] for tool in tools: + tool = tool.get("function", "") if tool.get("type") == "function" else tool param_text = "" for name, param in tool["parameters"]["properties"].items(): required, enum, items = "", "", "" @@ -159,6 +160,7 @@ class GLM4ToolUtils(ToolUtils): def tool_formatter(tools: list[dict[str, Any]]) -> str: tool_text = "" for tool in tools: + tool = tool.get("function", "") if tool.get("type") == "function" else tool tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format( name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False) ) @@ -200,7 +202,7 @@ class Llama3ToolUtils(ToolUtils): date = datetime.now().strftime("%d %b %Y") tool_text = "" for tool in tools: - wrapped_tool = {"type": "function", "function": tool} + wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool} tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n" return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text) @@ -235,7 +237,9 @@ class MistralToolUtils(ToolUtils): def tool_formatter(tools: list[dict[str, Any]]) -> str: wrapped_tools = [] for tool in tools: - wrapped_tools.append({"type": "function", "function": tool}) + wrapped_tools.append( + tool if tool.get("type") == "function" else {"type": "function", "function": tool} + ) return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]" @@ -277,7 +281,7 @@ class QwenToolUtils(ToolUtils): def tool_formatter(tools: list[dict[str, Any]]) -> str: tool_text = "" for tool in tools: - wrapped_tool = {"type": "function", "function": tool} + wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool} tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False) return QWEN_TOOL_PROMPT.format(tool_text=tool_text)