mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 20:00:36 +08:00
[test] add npu test yaml and add ascend a3 docker file (#9547)
Co-authored-by: jiaqiw09 <jiaqiw960714@gmail.com>
This commit is contained in:
@@ -15,6 +15,8 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||
|
||||
|
||||
@@ -36,16 +38,19 @@ TOOLS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_empty_formatter():
|
||||
formatter = EmptyFormatter(slots=["\n"])
|
||||
assert formatter.apply() == ["\n"]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_string_formatter():
|
||||
formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
|
||||
assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
@@ -55,6 +60,7 @@ def test_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
@@ -65,6 +71,7 @@ def test_multi_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_default_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="default")
|
||||
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||
@@ -83,12 +90,14 @@ def test_default_tool_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_default_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="default")
|
||||
result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_default_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="default")
|
||||
result = (
|
||||
@@ -101,12 +110,14 @@ def test_default_multi_tool_extractor():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_glm4_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_glm4_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="glm4")
|
||||
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||
@@ -117,12 +128,14 @@ def test_glm4_tool_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_glm4_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="glm4")
|
||||
result = """test_tool\n{"foo": "bar", "size": 10}\n"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_llama3_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
@@ -131,6 +144,7 @@ def test_llama3_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_llama3_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
@@ -141,6 +155,7 @@ def test_llama3_multi_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_llama3_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="llama3")
|
||||
date = datetime.now().strftime("%d %b %Y")
|
||||
@@ -154,12 +169,14 @@ def test_llama3_tool_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_llama3_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="llama3")
|
||||
result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_llama3_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="llama3")
|
||||
result = (
|
||||
@@ -172,6 +189,7 @@ def test_llama3_multi_tool_extractor():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_mistral_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
@@ -181,6 +199,7 @@ def test_mistral_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_mistral_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
@@ -192,6 +211,7 @@ def test_mistral_multi_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_mistral_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="mistral")
|
||||
wrapped_tool = {"type": "function", "function": TOOLS[0]}
|
||||
@@ -200,12 +220,14 @@ def test_mistral_tool_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_mistral_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="mistral")
|
||||
result = """{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_mistral_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="mistral")
|
||||
result = (
|
||||
@@ -218,6 +240,7 @@ def test_mistral_multi_tool_extractor():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_qwen_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
|
||||
tool_calls = json.dumps(FUNCTION)
|
||||
@@ -226,6 +249,7 @@ def test_qwen_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_qwen_multi_function_formatter():
|
||||
formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen")
|
||||
tool_calls = json.dumps([FUNCTION] * 2)
|
||||
@@ -236,6 +260,7 @@ def test_qwen_multi_function_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_qwen_tool_formatter():
|
||||
formatter = ToolFormatter(tool_format="qwen")
|
||||
wrapped_tool = {"type": "function", "function": TOOLS[0]}
|
||||
@@ -249,12 +274,14 @@ def test_qwen_tool_formatter():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_qwen_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="qwen")
|
||||
result = """<tool_call>\n{"name": "test_tool", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>"""
|
||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu"])
|
||||
def test_qwen_multi_tool_extractor():
|
||||
formatter = ToolFormatter(tool_format="qwen")
|
||||
result = (
|
||||
|
||||
Reference in New Issue
Block a user