[misc] fix accelerator (#9661)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Yaowei Zheng
2025-12-25 02:11:04 +08:00
committed by GitHub
parent 6a2eafbae3
commit a754604c11
44 changed files with 396 additions and 448 deletions

View File

@@ -38,19 +38,19 @@ TOOLS = [
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_empty_formatter():
formatter = EmptyFormatter(slots=["\n"])
assert formatter.apply() == ["\n"]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_string_formatter():
formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps(FUNCTION)
@@ -60,7 +60,7 @@ def test_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps([FUNCTION] * 2)
@@ -71,7 +71,7 @@ def test_multi_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_default_tool_formatter():
formatter = ToolFormatter(tool_format="default")
assert formatter.apply(content=json.dumps(TOOLS)) == [
@@ -90,14 +90,14 @@ def test_default_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_default_tool_extractor():
formatter = ToolFormatter(tool_format="default")
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_default_multi_tool_extractor():
formatter = ToolFormatter(tool_format="default")
result = (
@@ -110,14 +110,14 @@ def test_default_multi_tool_extractor():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_glm4_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_glm4_tool_formatter():
formatter = ToolFormatter(tool_format="glm4")
assert formatter.apply(content=json.dumps(TOOLS)) == [
@@ -128,14 +128,14 @@ def test_glm4_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_glm4_tool_extractor():
formatter = ToolFormatter(tool_format="glm4")
result = """test_tool\n{"foo": "bar", "size": 10}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps(FUNCTION)
@@ -144,7 +144,7 @@ def test_llama3_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps([FUNCTION] * 2)
@@ -155,7 +155,7 @@ def test_llama3_multi_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_tool_formatter():
formatter = ToolFormatter(tool_format="llama3")
date = datetime.now().strftime("%d %b %Y")
@@ -169,14 +169,14 @@ def test_llama3_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_tool_extractor():
formatter = ToolFormatter(tool_format="llama3")
result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_llama3_multi_tool_extractor():
formatter = ToolFormatter(tool_format="llama3")
result = (
@@ -189,7 +189,7 @@ def test_llama3_multi_tool_extractor():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps(FUNCTION)
@@ -199,7 +199,7 @@ def test_mistral_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_multi_function_formatter():
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
tool_calls = json.dumps([FUNCTION] * 2)
@@ -211,7 +211,7 @@ def test_mistral_multi_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_tool_formatter():
formatter = ToolFormatter(tool_format="mistral")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
@@ -220,14 +220,14 @@ def test_mistral_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_tool_extractor():
formatter = ToolFormatter(tool_format="mistral")
result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_mistral_multi_tool_extractor():
formatter = ToolFormatter(tool_format="mistral")
result = (
@@ -240,7 +240,7 @@ def test_mistral_multi_tool_extractor():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
tool_calls = json.dumps(FUNCTION)
@@ -249,7 +249,7 @@ def test_qwen_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_multi_function_formatter():
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
tool_calls = json.dumps([FUNCTION] * 2)
@@ -260,7 +260,7 @@ def test_qwen_multi_function_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_tool_formatter():
formatter = ToolFormatter(tool_format="qwen")
wrapped_tool = {"type": "function", "function": TOOLS[0]}
@@ -274,14 +274,14 @@ def test_qwen_tool_formatter():
]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_tool_extractor():
formatter = ToolFormatter(tool_format="qwen")
result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
@pytest.mark.runs_on(["cpu"])
@pytest.mark.runs_on(["cpu", "mps"])
def test_qwen_multi_tool_extractor():
formatter = ToolFormatter(tool_format="qwen")
result = (