mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 19:52:50 +08:00
support qwen tool format
Former-commit-id: 98795854e3fda7b0c0bc209b3e2496b0036e154e
This commit is contained in:
parent
acd62fddb8
commit
a421113466
@ -866,7 +866,11 @@ _register_template(
|
|||||||
name="llava_next_qwen",
|
name="llava_next_qwen",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
|
||||||
|
format_observation=StringFormatter(
|
||||||
|
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||||
|
),
|
||||||
|
format_tools=ToolFormatter(tool_format="qwen"),
|
||||||
format_separator=EmptyFormatter(slots=["\n"]),
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
default_system="You are a helpful assistant.",
|
default_system="You are a helpful assistant.",
|
||||||
stop_words=["<|im_end|>"],
|
stop_words=["<|im_end|>"],
|
||||||
@ -1050,7 +1054,11 @@ _register_template(
|
|||||||
name="qwen",
|
name="qwen",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
|
||||||
|
format_observation=StringFormatter(
|
||||||
|
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||||
|
),
|
||||||
|
format_tools=ToolFormatter(tool_format="qwen"),
|
||||||
format_separator=EmptyFormatter(slots=["\n"]),
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
default_system="You are a helpful assistant.",
|
default_system="You are a helpful assistant.",
|
||||||
stop_words=["<|im_end|>"],
|
stop_words=["<|im_end|>"],
|
||||||
@ -1062,7 +1070,11 @@ _register_template(
|
|||||||
name="qwen2_vl",
|
name="qwen2_vl",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
|
||||||
|
format_observation=StringFormatter(
|
||||||
|
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||||
|
),
|
||||||
|
format_tools=ToolFormatter(tool_format="qwen"),
|
||||||
format_separator=EmptyFormatter(slots=["\n"]),
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
default_system="You are a helpful assistant.",
|
default_system="You are a helpful assistant.",
|
||||||
stop_words=["<|im_end|>"],
|
stop_words=["<|im_end|>"],
|
||||||
|
@ -51,6 +51,14 @@ LLAMA3_TOOL_PROMPT = (
|
|||||||
"Do not use variables.\n\n{tool_text}"
|
"Do not use variables.\n\n{tool_text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# System-prompt section advertising the available tools in the Qwen format.
# `{tool_text}` is filled through `str.format`, which is why the literal braces
# of the example JSON object are doubled (`{{` / `}}`).
QWEN_TOOL_PROMPT = (
    "\n\n# Tools\n\n"
    "You may call one or more functions to assist with the user query.\n\n"
    "You are provided with function signatures within <tools></tools> XML tags:\n"
    "<tools>{tool_text}\n</tools>\n\n"
    "For each function call, return a json object with function name and arguments within "
    "<tool_call></tool_call> XML tags:\n"
    '<tool_call>\n{{"name": <function-name>, "arguments": <args-json-object>}}\n</tool_call>'
    "<|im_end|>\n"
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolUtils(ABC):
|
class ToolUtils(ABC):
|
||||||
@ -79,11 +87,17 @@ class ToolUtils(ABC):
|
|||||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||||
r"""
|
r"""
|
||||||
Extracts all the function calls from the assistant message.
|
Extracts all the function calls from the assistant message.
|
||||||
|
|
||||||
|
It should be an inverse function of `function_formatter`.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class DefaultToolUtils(ToolUtils):
|
class DefaultToolUtils(ToolUtils):
|
||||||
|
r"""
|
||||||
|
Default tool using template.
|
||||||
|
"""
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||||
@ -149,6 +163,10 @@ class DefaultToolUtils(ToolUtils):
|
|||||||
|
|
||||||
|
|
||||||
class GLM4ToolUtils(ToolUtils):
|
class GLM4ToolUtils(ToolUtils):
|
||||||
|
r"""
|
||||||
|
GLM-4 tool using template.
|
||||||
|
"""
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||||
@ -205,7 +223,7 @@ class Llama3ToolUtils(ToolUtils):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||||
if len(functions) > 1:
|
if len(functions) > 1:
|
||||||
raise ValueError("Llama 3 does not support parallel functions.")
|
raise ValueError("Llama-3 does not support parallel functions.")
|
||||||
|
|
||||||
return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
|
return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
|
||||||
|
|
||||||
@ -224,6 +242,10 @@ class Llama3ToolUtils(ToolUtils):
|
|||||||
|
|
||||||
|
|
||||||
class MistralToolUtils(ToolUtils):
|
class MistralToolUtils(ToolUtils):
|
||||||
|
r"""
|
||||||
|
Mistral v0.3 tool using template.
|
||||||
|
"""
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||||
@ -263,11 +285,61 @@ class MistralToolUtils(ToolUtils):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
class QwenToolUtils(ToolUtils):
    r"""
    Qwen 2.5 tool using template.

    Tool definitions are rendered as JSON function signatures inside the prompt, and every
    function call is rendered as a JSON object wrapped in `<tool_call></tool_call>` tags.
    """

    @override
    @staticmethod
    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
        r"""
        Builds the tool section of the prompt.

        Each tool is wrapped as `{"type": "function", "function": tool}`, dumped as JSON on its
        own line (non-ASCII characters are kept as-is) and substituted into `QWEN_TOOL_PROMPT`.
        """
        tool_text = "".join(
            "\n" + json.dumps({"type": "function", "function": tool}, ensure_ascii=False) for tool in tools
        )
        return QWEN_TOOL_PROMPT.format(tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
        r"""
        Renders the function calls as `<tool_call>` blocks joined by newlines.

        The arguments of each call are inserted verbatim into the JSON object, so they are
        presumably expected to already be JSON text (TODO confirm against the callers).
        """
        function_texts = [
            "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
            for name, arguments in functions
        ]
        return ["\n".join(function_texts)]

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
        r"""
        Extracts all the function calls from the assistant message.

        Returns the list of parsed calls, or the unmodified `content` when no `<tool_call>`
        block is matched or when any block is not a JSON object holding both the `name` and
        the `arguments` keys.
        """
        # A block only counts when it is followed (ignoring whitespace) by another block or by
        # the end of the message.
        regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
        tool_match: List[str] = regex.findall(content)
        if not tool_match:
            return content

        results = []
        for tool_text in tool_match:
            try:
                tool = json.loads(tool_text.strip())
            except json.JSONDecodeError:
                return content

            # `json.loads` may also yield a scalar or a list; treat those like any other
            # malformed call instead of failing on the key lookups below.
            if not isinstance(tool, dict) or "name" not in tool or "arguments" not in tool:
                return content

            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))

        return results
|
||||||
|
|
||||||
|
|
||||||
TOOLS = {
|
TOOLS = {
|
||||||
"default": DefaultToolUtils(),
|
"default": DefaultToolUtils(),
|
||||||
"glm4": GLM4ToolUtils(),
|
"glm4": GLM4ToolUtils(),
|
||||||
"llama3": Llama3ToolUtils(),
|
"llama3": Llama3ToolUtils(),
|
||||||
"mistral": MistralToolUtils(),
|
"mistral": MistralToolUtils(),
|
||||||
|
"qwen": QwenToolUtils(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -194,3 +194,53 @@ def test_mistral_multi_tool_extractor():
|
|||||||
("test_tool", """{"foo": "bar", "size": 10}"""),
|
("test_tool", """{"foo": "bar", "size": 10}"""),
|
||||||
("another_tool", """{"foo": "job", "size": 2}"""),
|
("another_tool", """{"foo": "job", "size": 2}"""),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_qwen_function_formatter():
    """Check that one function call becomes one <tool_call> block plus the end-of-turn slot."""
    qwen_formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
    expected_slots = [
        """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
        "<|im_end|>",
    ]
    assert qwen_formatter.apply(content=json.dumps(FUNCTION)) == expected_slots
|
||||||
|
|
||||||
|
|
||||||
|
def test_qwen_multi_function_formatter():
    """Check that two function calls become two newline-separated <tool_call> blocks."""
    qwen_formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
    single_call = """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
    expected_slots = [single_call + "\n" + single_call, "<|im_end|>"]
    assert qwen_formatter.apply(content=json.dumps([FUNCTION] * 2)) == expected_slots
|
||||||
|
|
||||||
|
|
||||||
|
def test_qwen_tool_formatter():
    """Check the rendered tool prompt for the tool list, with the first tool wrapped as a function."""
    qwen_formatter = ToolFormatter(tool_format="qwen")
    signature = json.dumps({"type": "function", "function": TOOLS[0]}, ensure_ascii=False)
    expected_prompt = (
        "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
        "You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
        f"\n{signature}"
        "\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
        """<tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, """
        """"arguments": <args-json-object>}\n</tool_call><|im_end|>\n"""
    )
    assert qwen_formatter.apply(content=json.dumps(TOOLS)) == [expected_prompt]
|
||||||
|
|
||||||
|
|
||||||
|
def test_qwen_tool_extractor():
    """Check that a single <tool_call> block is parsed into one (name, arguments) pair."""
    qwen_formatter = ToolFormatter(tool_format="qwen")
    message = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
    expected_calls = [("test_tool", """{"foo": "bar", "size": 10}""")]
    assert qwen_formatter.extract(message) == expected_calls
|
||||||
|
|
||||||
|
|
||||||
|
def test_qwen_multi_tool_extractor():
    """Check that consecutive <tool_call> blocks are parsed in order into (name, arguments) pairs."""
    qwen_formatter = ToolFormatter(tool_format="qwen")
    first_call = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
    second_call = """<tool_call>\n{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}\n</tool_call>"""
    expected_calls = [
        ("test_tool", """{"foo": "bar", "size": 10}"""),
        ("another_tool", """{"foo": "job", "size": 2}"""),
    ]
    assert qwen_formatter.extract(first_call + second_call) == expected_calls
|
||||||
|
Loading…
x
Reference in New Issue
Block a user