[test] align test cases (#6865)

* align test cases

* fix function formatter

Former-commit-id: f6f3f8d0fc79de6bbad0bf892fc2f6c98c27eb8e
This commit is contained in:
hoshi-hiyouga 2025-02-09 01:03:49 +08:00 committed by GitHub
parent fcd0f0480d
commit b93333685b
3 changed files with 32 additions and 42 deletions

View File

@ -86,19 +86,20 @@ class StringFormatter(Formatter):
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")
return elements
@dataclass
class FunctionFormatter(Formatter):
class FunctionFormatter(StringFormatter):
def __post_init__(self):
super().__post_init__()
self.tool_utils = get_tool_utils(self.tool_format)
@override
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
content: str = kwargs.pop("content")
regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
thought = re.search(regex, content)
if thought:
@ -116,19 +117,13 @@ class FunctionFormatter(Formatter):
)
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string
elements = []
for slot in self.slots:
if slot == "{{content}}":
if thought:
elements.append(thought.group(1))
function_str = self.tool_utils.function_formatter(functions)
if thought:
function_str = thought.group(1) + function_str
elements += self.tool_utils.function_formatter(functions)
else:
elements.append(slot)
return elements
return super().apply(content=function_str)
@dataclass
@ -143,7 +138,7 @@ class ToolFormatter(Formatter):
tools = json.loads(content)
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:

View File

@ -21,8 +21,6 @@ from typing import Any, Dict, List, NamedTuple, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS
class FunctionCall(NamedTuple):
name: str
@ -76,7 +74,7 @@ class ToolUtils(ABC):
@staticmethod
@abstractmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
def function_formatter(functions: List["FunctionCall"]) -> str:
r"""
Generates the assistant message including all the tool calls.
"""
@ -134,12 +132,12 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
def function_formatter(functions: List["FunctionCall"]) -> str:
function_text = ""
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
return [function_text]
return function_text
@override
@staticmethod
@ -180,11 +178,11 @@ class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
def function_formatter(functions: List["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("GLM-4 does not support parallel functions.")
return [f"{functions[0].name}\n{functions[0].arguments}"]
return f"{functions[0].name}\n{functions[0].arguments}"
@override
@staticmethod
@ -221,11 +219,11 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
def function_formatter(functions: List["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
@override
@staticmethod
@ -257,12 +255,12 @@ class MistralToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
def function_formatter(functions: List["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
return ["[" + ", ".join(function_texts) + "]"]
return "[" + ", ".join(function_texts) + "]"
@override
@staticmethod
@ -302,14 +300,14 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
def function_formatter(functions: List["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
)
return ["\n".join(function_texts)]
return "\n".join(function_texts)
@override
@staticmethod

View File

@ -123,11 +123,10 @@ def test_glm4_tool_extractor():
def test_llama3_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
assert formatter.apply(content=tool_calls) == [
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""",
"<|eot_id|>",
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""
]
@ -150,20 +149,19 @@ def test_llama3_tool_extractor():
def test_mistral_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"[TOOL_CALLS] ",
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"[TOOL_CALLS] " """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"</s>",
]
def test_mistral_multi_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"[TOOL_CALLS] ",
"[TOOL_CALLS] "
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
"""{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"</s>",
@ -197,21 +195,20 @@ def test_mistral_multi_tool_extractor():
def test_qwen_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
"<|im_end|>",
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call><|im_end|>\n"""
]
def test_qwen_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
"<|im_end|>",
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
"<|im_end|>\n"
]