mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-27 09:10:35 +08:00
[misc] fix accelerator (#9661)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -38,19 +38,19 @@ TOOLS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_empty_formatter():
|
||||
formatter = EmptyFormatter(slots=["\n"])
|
||||
assert formatter.apply() == ["\n"]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_string_formatter():
|
||||
formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
|
||||
assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
@@ -60,7 +60,7 @@ def test_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
@@ -71,7 +71,7 @@ def test_multi_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_default_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="default")
|
||||
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||
@@ -90,14 +90,14 @@ def test_default_tool_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_default_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="default")
|
||||
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_default_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="default")
|
||||
result = (
|
||||
@@ -110,14 +110,14 @@ def test_default_multi_tool_extractor():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_glm4_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_glm4_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="glm4")
|
||||
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||
@@ -128,14 +128,14 @@ def test_glm4_tool_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_glm4_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="glm4")
|
||||
result = """test_tool\n{"foo": "bar", "size": 10}\n"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_llama3_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
@@ -144,7 +144,7 @@ def test_llama3_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_llama3_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
@@ -155,7 +155,7 @@ def test_llama3_multi_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_llama3_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="llama3")
|
||||
date = datetime.now().strftime("%d %b %Y")
|
||||
@@ -169,14 +169,14 @@ def test_llama3_tool_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_llama3_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="llama3")
|
||||
result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_llama3_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="llama3")
|
||||
result = (
|
||||
@@ -189,7 +189,7 @@ def test_llama3_multi_tool_extractor():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_mistral_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
@@ -199,7 +199,7 @@ def test_mistral_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_mistral_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
@@ -211,7 +211,7 @@ def test_mistral_multi_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_mistral_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="mistral")
|
||||
wrapped_tool = {"type": "function", "function": TOOLS[0]}
|
||||
@@ -220,14 +220,14 @@ def test_mistral_tool_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_mistral_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="mistral")
|
||||
result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_mistral_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="mistral")
|
||||
result = (
|
||||
@@ -240,7 +240,7 @@ def test_mistral_multi_tool_extractor():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_qwen_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
@@ -249,7 +249,7 @@ def test_qwen_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_qwen_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
@@ -260,7 +260,7 @@ def test_qwen_multi_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_qwen_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="qwen")
|
||||
wrapped_tool = {"type": "function", "function": TOOLS[0]}
|
||||
@@ -274,14 +274,14 @@ def test_qwen_tool_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_qwen_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="qwen")
|
||||
result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_qwen_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="qwen")
|
||||
result = (
|
||||
|
||||
Reference in New Issue
Block a user