Merge pull request #6369 from hiyouga/hiyouga/template

[template] support qwen2 tool template

Former-commit-id: af336275021cd6aee3fe9f67b9ac9bcd1276de7c
This commit is contained in:
hoshi-hiyouga 2024-12-18 04:23:49 +08:00 committed by GitHub
commit 5f0dd86c15
3 changed files with 143 additions and 21 deletions

View File

@ -216,7 +216,7 @@ def _register_template(
stop_words: Sequence[str] = [],
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = True,
replace_jinja_template: bool = False,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
) -> None:
r"""
@ -416,6 +416,7 @@ _register_template(
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
),
replace_jinja_template=True,
)
@ -510,6 +511,7 @@ _register_template(
format_separator=EmptyFormatter(slots=["\n"]),
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
replace_jinja_template=True,
)
@ -523,6 +525,7 @@ _register_template(
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
stop_words=["<|im_end|>", "<|im_start|>"],
replace_eos=True,
replace_jinja_template=True,
)
@ -676,7 +679,6 @@ _register_template(
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
efficient_eos=True,
replace_jinja_template=False,
)
@ -763,8 +765,6 @@ _register_template(
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"],
replace_eos=True,
replace_jinja_template=False,
)
@ -792,8 +792,6 @@ _register_template(
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
)
@ -846,8 +844,6 @@ _register_template(
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot_id|>", "<|eom_id|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
@ -870,12 +866,15 @@ _register_template(
name="llava_next_qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)
@ -996,8 +995,6 @@ _register_template(
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are OpenCoder, created by OpenCoder Team.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
)
@ -1057,12 +1054,14 @@ _register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
)
@ -1071,12 +1070,14 @@ _register_template(
name="qwen2_vl",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
format_separator=EmptyFormatter(slots=["\n"]),
default_system="You are a helpful assistant.",
stop_words=["<|im_end|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
@ -1091,7 +1092,6 @@ _register_template(
"Your answer should be friendly, unbiased, faithful, informative and detailed."
),
stop_words=["<|im_end|>"],
replace_eos=True,
)
@ -1124,7 +1124,6 @@ _register_template(
"in your response."
),
stop_words=["<|eot_id|>"],
replace_eos=True,
)
@ -1162,6 +1161,7 @@ _register_template(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
replace_jinja_template=True,
)

View File

@ -51,6 +51,14 @@ LLAMA3_TOOL_PROMPT = (
"Do not use variables.\n\n{tool_text}"
)
QWEN_TOOL_PROMPT = (
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
"\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
"""<tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, """
""""arguments": <args-json-object>}}\n</tool_call><|im_end|>\n"""
)
@dataclass
class ToolUtils(ABC):
@ -79,11 +87,17 @@ class ToolUtils(ABC):
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts all the function calls from the assistant message.
It should be an inverse function of `function_formatter`.
"""
...
class DefaultToolUtils(ToolUtils):
r"""
Default tool using template.
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
@ -149,6 +163,10 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils):
r"""
GLM-4 tool using template.
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
@ -205,7 +223,7 @@ class Llama3ToolUtils(ToolUtils):
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
if len(functions) > 1:
raise ValueError("Llama 3 does not support parallel functions.")
raise ValueError("Llama-3 does not support parallel functions.")
return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
@ -224,6 +242,10 @@ class Llama3ToolUtils(ToolUtils):
class MistralToolUtils(ToolUtils):
r"""
Mistral v0.3 tool using template.
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
@ -263,11 +285,61 @@ class MistralToolUtils(ToolUtils):
return results
class QwenToolUtils(ToolUtils):
r"""
Qwen 2.5 tool using template.
"""
@override
@staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = ""
for tool in tools:
wrapped_tool = {"type": "function", "function": tool}
tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
return QWEN_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
function_texts = []
for name, arguments in functions:
function_texts.append(
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
)
return ["\n".join(function_texts)]
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
tool_match: List[str] = re.findall(regex, content)
if not tool_match:
return content
results = []
for tool in tool_match:
try:
tool = json.loads(tool.strip())
except json.JSONDecodeError:
return content
if "name" not in tool or "arguments" not in tool:
return content
results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
return results
TOOLS = {
"default": DefaultToolUtils(),
"glm4": GLM4ToolUtils(),
"llama3": Llama3ToolUtils(),
"mistral": MistralToolUtils(),
"qwen": QwenToolUtils(),
}

View File

@ -194,3 +194,53 @@ def test_mistral_multi_tool_extractor():
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]
def test_qwen_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
"<|im_end|>",
]
def test_qwen_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
"<|im_end|>",
]
def test_qwen_tool_formatter():
formatter = ToolFormatter(tool_format="qwen")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
f"\n{json.dumps(wrapped_tool, ensure_ascii=False)}"
"\n</tools>\n\nFor each function call, return a json object with function name and arguments within "
"""<tool_call></tool_call> XML tags:\n<tool_call>\n{"name": <function-name>, """
""""arguments": <args-json-object>}\n</tool_call><|im_end|>\n"""
]
def test_qwen_tool_extractor():
formatter = ToolFormatter(tool_format="qwen")
result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
def test_qwen_multi_tool_extractor():
formatter = ToolFormatter(tool_format="qwen")
result = (
"""<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
"""<tool_call>\n{"name": "another_tool", "arguments": {"foo": "job", "size": 2}}\n</tool_call>"""
)
assert formatter.extract(result) == [
("test_tool", """{"foo": "bar", "size": 10}"""),
("another_tool", """{"foo": "job", "size": 2}"""),
]