mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 11:12:50 +08:00
[test] align test cases (#6865)
* align test cases * fix function formatter Former-commit-id: f6f3f8d0fc79de6bbad0bf892fc2f6c98c27eb8e
This commit is contained in:
parent
fcd0f0480d
commit
b93333685b
@ -86,19 +86,20 @@ class StringFormatter(Formatter):
|
||||
elif isinstance(slot, (dict, set)):
|
||||
elements.append(slot)
|
||||
else:
|
||||
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")
|
||||
|
||||
return elements
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionFormatter(Formatter):
|
||||
class FunctionFormatter(StringFormatter):
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.tool_utils = get_tool_utils(self.tool_format)
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
content: str = kwargs.pop("content")
|
||||
regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
|
||||
thought = re.search(regex, content)
|
||||
if thought:
|
||||
@ -116,19 +117,13 @@ class FunctionFormatter(Formatter):
|
||||
)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
|
||||
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string
|
||||
|
||||
elements = []
|
||||
for slot in self.slots:
|
||||
if slot == "{{content}}":
|
||||
if thought:
|
||||
elements.append(thought.group(1))
|
||||
function_str = self.tool_utils.function_formatter(functions)
|
||||
if thought:
|
||||
function_str = thought.group(1) + function_str
|
||||
|
||||
elements += self.tool_utils.function_formatter(functions)
|
||||
else:
|
||||
elements.append(slot)
|
||||
|
||||
return elements
|
||||
return super().apply(content=function_str)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -143,7 +138,7 @@ class ToolFormatter(Formatter):
|
||||
tools = json.loads(content)
|
||||
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
|
||||
except json.JSONDecodeError:
|
||||
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}") # flat string
|
||||
raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string
|
||||
|
||||
@override
|
||||
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||
|
@ -21,8 +21,6 @@ from typing import Any, Dict, List, NamedTuple, Tuple, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .data_utils import SLOTS
|
||||
|
||||
|
||||
class FunctionCall(NamedTuple):
|
||||
name: str
|
||||
@ -76,7 +74,7 @@ class ToolUtils(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
r"""
|
||||
Generates the assistant message including all the tool calls.
|
||||
"""
|
||||
@ -134,12 +132,12 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
function_text = ""
|
||||
for name, arguments in functions:
|
||||
function_text += f"Action: {name}\nAction Input: {arguments}\n"
|
||||
|
||||
return [function_text]
|
||||
return function_text
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
@ -180,11 +178,11 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
if len(functions) > 1:
|
||||
raise ValueError("GLM-4 does not support parallel functions.")
|
||||
|
||||
return [f"{functions[0].name}\n{functions[0].arguments}"]
|
||||
return f"{functions[0].name}\n{functions[0].arguments}"
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
@ -221,11 +219,11 @@ class Llama3ToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
if len(functions) > 1:
|
||||
raise ValueError("Llama-3 does not support parallel functions.")
|
||||
|
||||
return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']
|
||||
return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
@ -257,12 +255,12 @@ class MistralToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
function_texts = []
|
||||
for name, arguments in functions:
|
||||
function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
|
||||
|
||||
return ["[" + ", ".join(function_texts) + "]"]
|
||||
return "[" + ", ".join(function_texts) + "]"
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
@ -302,14 +300,14 @@ class QwenToolUtils(ToolUtils):
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
|
||||
def function_formatter(functions: List["FunctionCall"]) -> str:
|
||||
function_texts = []
|
||||
for name, arguments in functions:
|
||||
function_texts.append(
|
||||
"<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
|
||||
)
|
||||
|
||||
return ["\n".join(function_texts)]
|
||||
return "\n".join(function_texts)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
|
@ -123,11 +123,10 @@ def test_glm4_tool_extractor():
|
||||
|
||||
|
||||
def test_llama3_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
|
||||
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""",
|
||||
"<|eot_id|>",
|
||||
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""
|
||||
]
|
||||
|
||||
|
||||
@ -150,20 +149,19 @@ def test_llama3_tool_extractor():
|
||||
|
||||
|
||||
def test_mistral_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
|
||||
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"[TOOL_CALLS] ",
|
||||
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
|
||||
"[TOOL_CALLS] " """[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
|
||||
"</s>",
|
||||
]
|
||||
|
||||
|
||||
def test_mistral_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", "</s>"], tool_format="mistral")
|
||||
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"[TOOL_CALLS] ",
|
||||
"[TOOL_CALLS] "
|
||||
"""[{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}, """
|
||||
"""{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}]""",
|
||||
"</s>",
|
||||
@ -197,21 +195,20 @@ def test_mistral_multi_tool_extractor():
|
||||
|
||||
|
||||
def test_qwen_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
|
||||
"<|im_end|>",
|
||||
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call><|im_end|>\n"""
|
||||
]
|
||||
|
||||
|
||||
def test_qwen_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen")
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
assert formatter.apply(content=tool_calls) == [
|
||||
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>\n"""
|
||||
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>""",
|
||||
"<|im_end|>",
|
||||
"""<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
|
||||
"<|im_end|>\n"
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user