From 72d5b06b0834feb79870196a3189cb13c6606931 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Sun, 9 Feb 2025 01:03:49 +0800 Subject: [PATCH] [test] align test cases (#6865) * align test cases * fix function formatter Former-commit-id: a68f5e22d0391c80a9a826dc83967255be572032 --- src/llamafactory/data/formatter.py | 25 ++++++++++--------------- src/llamafactory/data/tool_utils.py | 24 +++++++++++------------- tests/data/test_formatter.py | 25 +++++++++++-------------- 3 files changed, 32 insertions(+), 42 deletions(-) diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index ed83a2d2..c68f0a52 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -86,19 +86,20 @@ class StringFormatter(Formatter): elif isinstance(slot, (dict, set)): elements.append(slot) else: - raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}") + raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.") return elements @dataclass -class FunctionFormatter(Formatter): +class FunctionFormatter(StringFormatter): def __post_init__(self): + super().__post_init__() self.tool_utils = get_tool_utils(self.tool_format) @override def apply(self, **kwargs) -> SLOTS: - content = kwargs.pop("content") + content: str = kwargs.pop("content") regex = re.compile(r"(.*)", re.DOTALL) thought = re.search(regex, content) if thought: @@ -116,19 +117,13 @@ class FunctionFormatter(Formatter): ) except json.JSONDecodeError: - raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string + raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string - elements = [] - for slot in self.slots: - if slot == "{{content}}": - if thought: - elements.append(thought.group(1)) + function_str = self.tool_utils.function_formatter(functions) + if thought: + function_str = thought.group(1) + function_str - elements += self.tool_utils.function_formatter(functions) - else: - elements.append(slot) - - return elements + return super().apply(content=function_str) @dataclass @@ -143,7 +138,7 @@ class ToolFormatter(Formatter): tools = json.loads(content) return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""] except json.JSONDecodeError: - raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string + raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string @override def extract(self, content: str) -> Union[str, List["FunctionCall"]]: diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py index 6132e982..f3269413 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -21,8 +21,6 @@ from typing import Any, Dict, List, NamedTuple, Tuple, Union from typing_extensions import override -from .data_utils import SLOTS - class FunctionCall(NamedTuple): name: str @@ -76,7 +74,7 @@ class ToolUtils(ABC): @staticmethod @abstractmethod - def function_formatter(functions: List["FunctionCall"]) -> SLOTS: + def function_formatter(functions: List["FunctionCall"]) -> str: r""" Generates the assistant message including all the tool calls. """ @@ -134,12 +132,12 @@ class DefaultToolUtils(ToolUtils): @override @staticmethod - def function_formatter(functions: List["FunctionCall"]) -> SLOTS: + def function_formatter(functions: List["FunctionCall"]) -> str: function_text = "" for name, arguments in functions: function_text += f"Action: {name}\nAction Input: {arguments}\n" - return [function_text] + return function_text @override @staticmethod @@ -180,11 +178,11 @@ class GLM4ToolUtils(ToolUtils): @override @staticmethod - def function_formatter(functions: List["FunctionCall"]) -> SLOTS: + def function_formatter(functions: List["FunctionCall"]) -> str: if len(functions) > 1: raise ValueError("GLM-4 does not support parallel functions.") - return [f"{functions[0].name}\n{functions[0].arguments}"] + return f"{functions[0].name}\n{functions[0].arguments}" @override @staticmethod @@ -221,11 +219,11 @@ class Llama3ToolUtils(ToolUtils): @override @staticmethod - def function_formatter(functions: List["FunctionCall"]) -> SLOTS: + def function_formatter(functions: List["FunctionCall"]) -> str: if len(functions) > 1: raise ValueError("Llama-3 does not support parallel functions.") - return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'] + return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}' @override @staticmethod @@ -257,12 +255,12 @@ class MistralToolUtils(ToolUtils): @override @staticmethod - def function_formatter(functions: List["FunctionCall"]) -> SLOTS: + def function_formatter(functions: List["FunctionCall"]) -> str: function_texts = [] for name, arguments in functions: function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}') - return ["[" + ", ".join(function_texts) + "]"] + return "[" + ", ".join(function_texts) + "]" @override @staticmethod @@ -302,14 +300,14 @@ class QwenToolUtils(ToolUtils): @override @staticmethod - def function_formatter(functions: List["FunctionCall"]) -> SLOTS: + def function_formatter(functions: List["FunctionCall"]) -> str: function_texts = [] for name, arguments in functions: function_texts.append( "\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n" ) - return ["\n".join(function_texts)] + return "\n".join(function_texts) @override @staticmethod diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py index 2aaf48a2..542bafb9 100644 --- a/tests/data/test_formatter.py +++ b/tests/data/test_formatter.py @@ -123,11 +123,10 @@ def test_glm4_tool_extractor(): def test_llama3_function_formatter(): - formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3") + formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3") tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}) assert formatter.apply(content=tool_calls) == [ - """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""", - "<|eot_id|>", + """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>""" ] @@ -150,20 +149,19 @@ def test_llama3_tool_extractor(): def test_mistral_function_formatter(): - formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", ""], tool_format="mistral") + formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", ""], tool_format="mistral") tool_calls = json.dumps(FUNCTION) assert formatter.apply(content=tool_calls) == [ - "[TOOL_CALLS] ", - """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""", + "[TOOL_CALLS] " """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""", "", ] def test_mistral_multi_function_formatter(): - formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", ""], tool_format="mistral") + formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", ""], tool_format="mistral") tool_calls = json.dumps([FUNCTION] * 2) assert formatter.apply(content=tool_calls) == [ - "[TOOL_CALLS] ", + "[TOOL_CALLS] " """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """ """{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""", "", @@ -197,21 +195,20 @@ def test_mistral_multi_tool_extractor(): def test_qwen_function_formatter(): - formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen") + formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen") tool_calls = json.dumps(FUNCTION) assert formatter.apply(content=tool_calls) == [ - """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n""", - "<|im_end|>", + """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n<|im_end|>\n""" ] def test_qwen_multi_function_formatter(): - formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen") + formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen") tool_calls = json.dumps([FUNCTION] * 2) assert formatter.apply(content=tool_calls) == [ """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n\n""" - """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n""", - "<|im_end|>", + """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n""" + "<|im_end|>\n" ]