mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
support llama3 tool prompt
Former-commit-id: b24ae55ebf548db904a9fe1876192024d8a96108
This commit is contained in:
parent
fc18db6290
commit
a935933bed
@ -98,7 +98,7 @@ class StringFormatter(Formatter):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class FunctionFormatter(Formatter):
|
class FunctionFormatter(Formatter):
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
|
self.function_slots = get_tool_utils(self.tool_format).get_function_slots()
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def apply(self, **kwargs) -> SLOTS:
|
def apply(self, **kwargs) -> SLOTS:
|
||||||
@ -117,7 +117,7 @@ class FunctionFormatter(Formatter):
|
|||||||
|
|
||||||
elements = []
|
elements = []
|
||||||
for name, arguments in functions:
|
for name, arguments in functions:
|
||||||
for slot in self.slots:
|
for slot in self.function_slots:
|
||||||
if isinstance(slot, str):
|
if isinstance(slot, str):
|
||||||
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
||||||
elements.append(slot)
|
elements.append(slot)
|
||||||
@ -126,7 +126,7 @@ class FunctionFormatter(Formatter):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||||
|
|
||||||
return elements
|
return elements + self.slots
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -750,16 +750,18 @@ _register_template(
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||||
|
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
(
|
(
|
||||||
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
format_tools=ToolFormatter(tool_format="llama3"),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
stop_words=["<|eot_id|>"],
|
stop_words=["<|eot_id|>", "<|eom_id|>"],
|
||||||
replace_eos=True,
|
replace_eos=True,
|
||||||
replace_jinja_template=False,
|
replace_jinja_template=False,
|
||||||
)
|
)
|
||||||
@ -777,16 +779,18 @@ _register_template(
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||||
|
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
(
|
(
|
||||||
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
format_tools=ToolFormatter(tool_format="llama3"),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
stop_words=["<|eot_id|>"],
|
stop_words=["<|eot_id|>", "<|eom_id|>"],
|
||||||
replace_eos=True,
|
replace_eos=True,
|
||||||
replace_jinja_template=False,
|
replace_jinja_template=False,
|
||||||
mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
|
mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
|
||||||
@ -829,16 +833,18 @@ _register_template(
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||||
|
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
(
|
(
|
||||||
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
format_tools=ToolFormatter(tool_format="llama3"),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
stop_words=["<|eot_id|>"],
|
stop_words=["<|eot_id|>", "<|eom_id|>"],
|
||||||
replace_eos=True,
|
replace_eos=True,
|
||||||
replace_jinja_template=False,
|
replace_jinja_template=False,
|
||||||
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
|
||||||
|
@ -17,6 +17,7 @@ import re
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
@ -24,6 +25,9 @@ from typing_extensions import override
|
|||||||
from .data_utils import SLOTS
|
from .data_utils import SLOTS
|
||||||
|
|
||||||
|
|
||||||
|
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_TOOL_PROMPT = (
|
DEFAULT_TOOL_PROMPT = (
|
||||||
"You have access to the following tools:\n{tool_text}"
|
"You have access to the following tools:\n{tool_text}"
|
||||||
"Use the following format if using a tool:\n"
|
"Use the following format if using a tool:\n"
|
||||||
@ -41,7 +45,12 @@ GLM4_TOOL_PROMPT = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
|
LLAMA3_TOOL_PROMPT = (
|
||||||
|
"Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n"
|
||||||
|
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
|
||||||
|
"""Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """
|
||||||
|
"Do not use variables.\n\n{tool_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -161,16 +170,52 @@ class GLM4ToolUtils(ToolUtils):
|
|||||||
|
|
||||||
tool_name, tool_input = content.split("\n", maxsplit=1)
|
tool_name, tool_input = content.split("\n", maxsplit=1)
|
||||||
try:
|
try:
|
||||||
arguments = json.loads(tool_input)
|
arguments = json.loads(tool_input.strip())
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
return content
|
return content
|
||||||
|
|
||||||
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
||||||
|
|
||||||
|
|
||||||
|
class Llama3ToolUtils(ToolUtils):
|
||||||
|
r"""
|
||||||
|
Llama 3.x tool using template with `tools_in_user_message=False`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def get_function_slots() -> SLOTS:
|
||||||
|
return ["""{"name": "{{name}}", "parameters": {{arguments}}}"""]
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||||
|
cur_time = datetime.now().strftime("%d %b %Y")
|
||||||
|
tool_text = ""
|
||||||
|
for tool in tools:
|
||||||
|
wrapped_tool = {"type": "function", "function": tool}
|
||||||
|
tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
|
||||||
|
|
||||||
|
return LLAMA3_TOOL_PROMPT.format(cur_time=cur_time, tool_text=tool_text)
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||||
|
try:
|
||||||
|
tool = json.loads(content.strip())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return content
|
||||||
|
|
||||||
|
if "name" not in tool or "parameters" not in tool:
|
||||||
|
return content
|
||||||
|
|
||||||
|
return [(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
|
||||||
|
|
||||||
|
|
||||||
TOOLS = {
|
TOOLS = {
|
||||||
"default": DefaultToolUtils(),
|
"default": DefaultToolUtils(),
|
||||||
"glm4": GLM4ToolUtils(),
|
"glm4": GLM4ToolUtils(),
|
||||||
|
"llama3": Llama3ToolUtils(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -837,6 +837,10 @@ register_model_group(
|
|||||||
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B-Instruct",
|
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B-Instruct",
|
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B-Instruct",
|
||||||
},
|
},
|
||||||
|
"Llama-3.3-70B-Instruct": {
|
||||||
|
DownloadSource.DEFAULT: "meta-llama/Llama-3.3-70B-Instruct",
|
||||||
|
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.3-70B-Instruct",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
template="llama3",
|
template="llama3",
|
||||||
)
|
)
|
||||||
|
@ -13,40 +13,14 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
||||||
|
|
||||||
|
|
||||||
def test_empty_formatter():
|
FUNCTION = {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}
|
||||||
formatter = EmptyFormatter(slots=["\n"])
|
|
||||||
assert formatter.apply() == ["\n"]
|
|
||||||
|
|
||||||
|
TOOLS = [
|
||||||
def test_string_formatter():
|
|
||||||
formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
|
|
||||||
assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_function_formatter():
|
|
||||||
formatter = FunctionFormatter(slots=[], tool_format="default")
|
|
||||||
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
|
|
||||||
assert formatter.apply(content=tool_calls) == [
|
|
||||||
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n"""
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_multi_function_formatter():
|
|
||||||
formatter = FunctionFormatter(slots=[], tool_format="default")
|
|
||||||
tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2)
|
|
||||||
assert formatter.apply(content=tool_calls) == [
|
|
||||||
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
|
|
||||||
"""Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def test_default_tool_formatter():
|
|
||||||
formatter = ToolFormatter(tool_format="default")
|
|
||||||
tools = [
|
|
||||||
{
|
{
|
||||||
"name": "test_tool",
|
"name": "test_tool",
|
||||||
"description": "tool_desc",
|
"description": "tool_desc",
|
||||||
@ -60,7 +34,40 @@ def test_default_tool_formatter():
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
assert formatter.apply(content=json.dumps(tools)) == [
|
|
||||||
|
|
||||||
|
def test_empty_formatter():
|
||||||
|
formatter = EmptyFormatter(slots=["\n"])
|
||||||
|
assert formatter.apply() == ["\n"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_string_formatter():
|
||||||
|
formatter = StringFormatter(slots=["<s>", "Human: {{content}}\nAssistant:"])
|
||||||
|
assert formatter.apply(content="Hi") == ["<s>", "Human: Hi\nAssistant:"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_formatter():
|
||||||
|
formatter = FunctionFormatter(slots=["</s>"], tool_format="default")
|
||||||
|
tool_calls = json.dumps(FUNCTION)
|
||||||
|
assert formatter.apply(content=tool_calls) == [
|
||||||
|
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
|
||||||
|
"</s>",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_function_formatter():
|
||||||
|
formatter = FunctionFormatter(slots=["</s>"], tool_format="default")
|
||||||
|
tool_calls = json.dumps([FUNCTION] * 2)
|
||||||
|
assert formatter.apply(content=tool_calls) == [
|
||||||
|
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
|
||||||
|
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
|
||||||
|
"</s>",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_tool_formatter():
|
||||||
|
formatter = ToolFormatter(tool_format="default")
|
||||||
|
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||||
"You have access to the following tools:\n"
|
"You have access to the following tools:\n"
|
||||||
"> Tool Name: test_tool\n"
|
"> Tool Name: test_tool\n"
|
||||||
"Tool Description: tool_desc\n"
|
"Tool Description: tool_desc\n"
|
||||||
@ -94,26 +101,18 @@ def test_default_multi_tool_extractor():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_glm4_function_formatter():
|
||||||
|
formatter = FunctionFormatter(tool_format="glm4")
|
||||||
|
tool_calls = json.dumps(FUNCTION)
|
||||||
|
assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
|
||||||
|
|
||||||
|
|
||||||
def test_glm4_tool_formatter():
|
def test_glm4_tool_formatter():
|
||||||
formatter = ToolFormatter(tool_format="glm4")
|
formatter = ToolFormatter(tool_format="glm4")
|
||||||
tools = [
|
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||||
{
|
|
||||||
"name": "test_tool",
|
|
||||||
"description": "tool_desc",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"foo": {"type": "string", "description": "foo_desc"},
|
|
||||||
"bar": {"type": "number", "description": "bar_desc"},
|
|
||||||
},
|
|
||||||
"required": ["foo"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
]
|
|
||||||
assert formatter.apply(content=json.dumps(tools)) == [
|
|
||||||
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
||||||
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
|
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n"
|
||||||
"## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(json.dumps(tools[0], indent=4))
|
f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -121,3 +120,29 @@ def test_glm4_tool_extractor():
|
|||||||
formatter = ToolFormatter(tool_format="glm4")
|
formatter = ToolFormatter(tool_format="glm4")
|
||||||
result = """test_tool\n{"foo": "bar", "size": 10}\n"""
|
result = """test_tool\n{"foo": "bar", "size": 10}\n"""
|
||||||
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_llama3_function_formatter():
|
||||||
|
formatter = FunctionFormatter(tool_format="llama3")
|
||||||
|
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
|
||||||
|
assert formatter.apply(content=tool_calls) == [
|
||||||
|
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}"""
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_llama3_tool_formatter():
|
||||||
|
formatter = ToolFormatter(tool_format="llama3")
|
||||||
|
cur_time = datetime.now().strftime("%d %b %Y")
|
||||||
|
wrapped_tool = {"type": "function", "function": TOOLS[0]}
|
||||||
|
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||||
|
f"Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n"
|
||||||
|
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
|
||||||
|
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
|
||||||
|
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4)}\n\n"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_llama3_tool_extractor():
|
||||||
|
formatter = ToolFormatter(tool_format="llama3")
|
||||||
|
result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n"""
|
||||||
|
assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user