diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py
index ed83a2d2..c68f0a52 100644
--- a/src/llamafactory/data/formatter.py
+++ b/src/llamafactory/data/formatter.py
@@ -86,19 +86,20 @@ class StringFormatter(Formatter):
elif isinstance(slot, (dict, set)):
elements.append(slot)
else:
- raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
+ raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")
return elements
@dataclass
-class FunctionFormatter(Formatter):
+class FunctionFormatter(StringFormatter):
def __post_init__(self):
+ super().__post_init__()
self.tool_utils = get_tool_utils(self.tool_format)
@override
def apply(self, **kwargs) -> SLOTS:
- content = kwargs.pop("content")
+ content: str = kwargs.pop("content")
regex = re.compile(r"(.*)", re.DOTALL)
thought = re.search(regex, content)
if thought:
@@ -116,19 +117,13 @@ class FunctionFormatter(Formatter):
)
except json.JSONDecodeError:
- raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
+ raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string
- elements = []
- for slot in self.slots:
- if slot == "{{content}}":
- if thought:
- elements.append(thought.group(1))
+ function_str = self.tool_utils.function_formatter(functions)
+ if thought:
+ function_str = thought.group(1) + function_str
- elements += self.tool_utils.function_formatter(functions)
- else:
- elements.append(slot)
-
- return elements
+ return super().apply(content=function_str)
@dataclass
@@ -143,7 +138,7 @@ class ToolFormatter(Formatter):
tools = json.loads(content)
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
- raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
+ raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
@override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py
index 6132e982..f3269413 100644
--- a/src/llamafactory/data/tool_utils.py
+++ b/src/llamafactory/data/tool_utils.py
@@ -21,8 +21,6 @@ from typing import Any, Dict, List, NamedTuple, Tuple, Union
from typing_extensions import override
-from .data_utils import SLOTS
-
class FunctionCall(NamedTuple):
name: str
@@ -76,7 +74,7 @@ class ToolUtils(ABC):
@staticmethod
@abstractmethod
- def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+ def function_formatter(functions: List["FunctionCall"]) -> str:
r"""
Generates the assistant message including all the tool calls.
"""
@@ -134,12 +132,12 @@ class DefaultToolUtils(ToolUtils):
@override
@staticmethod
- def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+ def function_formatter(functions: List["FunctionCall"]) -> str:
function_text = ""
for name, arguments in functions:
function_text += f"Action: {name}\nAction Input: {arguments}\n"
- return [function_text]
+ return function_text
@override
@staticmethod
@@ -180,11 +178,11 @@ class GLM4ToolUtils(ToolUtils):
@override
@staticmethod
- def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+ def function_formatter(functions: List["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("GLM-4 does not support parallel functions.")
- return [f"{functions[0].name}\n{functions[0].arguments}"]
+ return f"{functions[0].name}\n{functions[0].arguments}"
@override
@staticmethod
@@ -221,11 +219,11 @@ class Llama3ToolUtils(ToolUtils):
@override
@staticmethod
- def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+ def function_formatter(functions: List["FunctionCall"]) -> str:
if len(functions) > 1:
raise ValueError("Llama-3 does not support parallel functions.")
- return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
+ return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
@override
@staticmethod
@@ -257,12 +255,12 @@ class MistralToolUtils(ToolUtils):
@override
@staticmethod
- def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+ def function_formatter(functions: List["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
- return ["[" + ", ".join(function_texts) + "]"]
+ return "[" + ", ".join(function_texts) + "]"
@override
@staticmethod
@@ -302,14 +300,14 @@ class QwenToolUtils(ToolUtils):
@override
@staticmethod
- def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+ def function_formatter(functions: List["FunctionCall"]) -> str:
function_texts = []
for name, arguments in functions:
function_texts.append(
"\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n"
)
- return ["\n".join(function_texts)]
+ return "\n".join(function_texts)
@override
@staticmethod
diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py
index 2aaf48a2..542bafb9 100644
--- a/tests/data/test_formatter.py
+++ b/tests/data/test_formatter.py
@@ -123,11 +123,10 @@ def test_glm4_tool_extractor():
def test_llama3_function_formatter():
- formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
+ formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
assert formatter.apply(content=tool_calls) == [
- """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""",
- "<|eot_id|>",
+ """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""
]
@@ -150,20 +149,19 @@ def test_llama3_tool_extractor():
def test_mistral_function_formatter():
- formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", ""], tool_format="mistral")
+ formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", ""], tool_format="mistral")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
- "[TOOL_CALLS] ",
- """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
+ "[TOOL_CALLS] " """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"",
]
def test_mistral_multi_function_formatter():
- formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", ""], tool_format="mistral")
+ formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", ""], tool_format="mistral")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
- "[TOOL_CALLS] ",
+ "[TOOL_CALLS] "
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
"""{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
"",
@@ -197,21 +195,20 @@ def test_mistral_multi_tool_extractor():
def test_qwen_function_formatter():
- formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
+ formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [
- """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n""",
- "<|im_end|>",
+ """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n<|im_end|>\n"""
]
def test_qwen_multi_function_formatter():
- formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
+ formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [
"""\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n\n"""
- """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n""",
- "<|im_end|>",
+ """\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n"""
+ "<|im_end|>\n"
]