[test] align test cases (#6865)

* align test cases

* fix function formatter

Former-commit-id: a68f5e22d0391c80a9a826dc83967255be572032
This commit is contained in:
hoshi-hiyouga 2025-02-09 01:03:49 +08:00 committed by GitHub
parent 94726bdc8d
commit 72d5b06b08
3 changed files with 32 additions and 42 deletions

View File

@ -86,19 +86,20 @@ class StringFormatter(Formatter):
elif isinstance(slot, (dict, set)): elif isinstance(slot, (dict, set)):
elements.append(slot) elements.append(slot)
else: else:
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}") raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")
return elements return elements
@dataclass @dataclass
class FunctionFormatter(Formatter): class FunctionFormatter(StringFormatter):
def __post_init__(self): def __post_init__(self):
super().__post_init__()
self.tool_utils = get_tool_utils(self.tool_format) self.tool_utils = get_tool_utils(self.tool_format)
@override @override
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content") content: str = kwargs.pop("content")
regex = re.compile(r"<think>(.*)</think>", re.DOTALL) regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
thought = re.search(regex, content) thought = re.search(regex, content)
if thought: if thought:
@ -116,19 +117,13 @@ class FunctionFormatter(Formatter):
) )
except json.JSONDecodeError: except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string
elements = [] function_str = self.tool_utils.function_formatter(functions)
for slot in self.slots: if thought:
if slot == "{{content}}": function_str = thought.group(1) + function_str
if thought:
elements.append(thought.group(1))
elements += self.tool_utils.function_formatter(functions) return super().apply(content=function_str)
else:
elements.append(slot)
return elements
@dataclass @dataclass
@ -143,7 +138,7 @@ class ToolFormatter(Formatter):
tools = json.loads(content) tools = json.loads(content)
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""] return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError: except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override @override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]: def extract(self, content: str) -> Union[str, List["FunctionCall"]]:

View File

@ -21,8 +21,6 @@ from typing import Any, Dict, List, NamedTuple, Tuple, Union
from typing_extensions import override from typing_extensions import override
from .data_utils import SLOTS
class FunctionCall(NamedTuple): class FunctionCall(NamedTuple):
name: str name: str
@ -76,7 +74,7 @@ class ToolUtils(ABC):
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS: def function_formatter(functions: List["FunctionCall"]) -> str:
r""" r"""
Generates the assistant message including all the tool calls. Generates the assistant message including all the tool calls.
""" """
@ -134,12 +132,12 @@ class DefaultToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS: def function_formatter(functions: List["FunctionCall"]) -> str:
function_text = "" function_text = ""
for name, arguments in functions: for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n" function_text += f"Action: {name}\nAction Input: {arguments}\n"
return [function_text] return function_text
@override @override
@staticmethod @staticmethod
@ -180,11 +178,11 @@ class GLM4ToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS: def function_formatter(functions: List["FunctionCall"]) -> str:
if len(functions) > 1: if len(functions) > 1:
raise ValueError("GLM-4 does not support parallel functions.") raise ValueError("GLM-4 does not support parallel functions.")
return [f"{functions[0].name}\n{functions[0].arguments}"] return f"{functions[0].name}\n{functions[0].arguments}"
@override @override
@staticmethod @staticmethod
@ -221,11 +219,11 @@ class Llama3ToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS: def function_formatter(functions: List["FunctionCall"]) -> str:
if len(functions) > 1: if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.") raise ValueError("Llama-3 does not support parallel functions.")
return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'] return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
@override @override
@staticmethod @staticmethod
@ -257,12 +255,12 @@ class MistralToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS: def function_formatter(functions: List["FunctionCall"]) -> str:
function_texts = [] function_texts = []
for name, arguments in functions: for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}') function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
return ["[" + ", ".join(function_texts) + "]"] return "[" + ", ".join(function_texts) + "]"
@override @override
@staticmethod @staticmethod
@ -302,14 +300,14 @@ class QwenToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS: def function_formatter(functions: List["FunctionCall"]) -> str:
function_texts = [] function_texts = []
for name, arguments in functions: for name, arguments in functions:
function_texts.append( function_texts.append(
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>" "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
) )
return ["\n".join(function_texts)] return "\n".join(function_texts)
@override @override
@staticmethod @staticmethod

View File

@ -123,11 +123,10 @@ def test_glm4_tool_extractor():
def test_llama3_function_formatter(): def test_llama3_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3") formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}) tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
assert formatter.apply(content=tool_calls) == [ assert formatter.apply(content=tool_calls) == [
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""", """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""
"<|eot_id|>",
] ]
@ -150,20 +149,19 @@ def test_llama3_tool_extractor():
def test_mistral_function_formatter(): def test_mistral_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral") formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps(FUNCTION) tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [ assert formatter.apply(content=tool_calls) == [
"[TOOL_CALLS] ", "[TOOL_CALLS] " """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"</s>", "</s>",
] ]
def test_mistral_multi_function_formatter(): def test_mistral_multi_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral") formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps([FUNCTION] * 2) tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [ assert formatter.apply(content=tool_calls) == [
"[TOOL_CALLS] ", "[TOOL_CALLS] "
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """ """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
"""{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""", """{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"</s>", "</s>",
@ -197,21 +195,20 @@ def test_mistral_multi_tool_extractor():
def test_qwen_function_formatter(): def test_qwen_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen") formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
tool_calls = json.dumps(FUNCTION) tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [ assert formatter.apply(content=tool_calls) == [
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""", """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call><|im_end|>\n"""
"<|im_end|>",
] ]
def test_qwen_multi_function_formatter(): def test_qwen_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen") formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
tool_calls = json.dumps([FUNCTION] * 2) tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [ assert formatter.apply(content=tool_calls) == [
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n""" """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""", """<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
"<|im_end|>", "<|im_end|>\n"
] ]