diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index ab86189c..d5f4b385 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -98,7 +98,7 @@ class StringFormatter(Formatter): @dataclass class FunctionFormatter(Formatter): def __post_init__(self): - self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots + self.function_slots = get_tool_utils(self.tool_format).get_function_slots() @override def apply(self, **kwargs) -> SLOTS: @@ -117,7 +117,7 @@ class FunctionFormatter(Formatter): elements = [] for name, arguments in functions: - for slot in self.slots: + for slot in self.function_slots: if isinstance(slot, str): slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments) elements.append(slot) @@ -126,7 +126,7 @@ class FunctionFormatter(Formatter): else: raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}") - return elements + return elements + self.slots @dataclass diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 33ba58a9..becfeaa8 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -750,16 +750,18 @@ _register_template( ] ), format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]), + format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"), format_observation=StringFormatter( slots=[ ( - "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>" "<|start_header_id|>assistant<|end_header_id|>\n\n" ) ] ), + format_tools=ToolFormatter(tool_format="llama3"), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), - stop_words=["<|eot_id|>"], + stop_words=["<|eot_id|>", "<|eom_id|>"], replace_eos=True, replace_jinja_template=False, ) @@ -777,16 +779,18 @@ _register_template( ] ), format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]), + format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"), format_observation=StringFormatter( slots=[ ( - "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>" "<|start_header_id|>assistant<|end_header_id|>\n\n" ) ] ), + format_tools=ToolFormatter(tool_format="llama3"), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), - stop_words=["<|eot_id|>"], + stop_words=["<|eot_id|>", "<|eom_id|>"], replace_eos=True, replace_jinja_template=False, mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"), @@ -829,16 +833,18 @@ _register_template( ] ), format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]), + format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"), format_observation=StringFormatter( slots=[ ( - "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>" "<|start_header_id|>assistant<|end_header_id|>\n\n" ) ] ), + format_tools=ToolFormatter(tool_format="llama3"), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), - stop_words=["<|eot_id|>"], + stop_words=["<|eot_id|>", "<|eom_id|>"], replace_eos=True, replace_jinja_template=False, mm_plugin=get_mm_plugin(name="llava_next", image_token=""), diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py index ca71d47d..2465191a 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -17,6 +17,7 @@ import re from abc import ABC, abstractmethod from collections import namedtuple from dataclasses import dataclass +from datetime import datetime from typing import Any, Dict, List, Tuple, Union from typing_extensions import override @@ -24,6 +25,9 @@ from typing_extensions import override from .data_utils import SLOTS +FunctionCall = namedtuple("FunctionCall", ["name", "arguments"]) + + DEFAULT_TOOL_PROMPT = ( "You have access to the following tools:\n{tool_text}" "Use the following format if using a tool:\n" @@ -41,7 +45,12 @@ GLM4_TOOL_PROMPT = ( ) -FunctionCall = namedtuple("FunctionCall", ["name", "arguments"]) +LLAMA3_TOOL_PROMPT = ( + "Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n" + "You have access to the following functions. To call a function, please respond with JSON for a function call. " + """Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """ + "Do not use variables.\n\n{tool_text}" +) @dataclass @@ -161,16 +170,52 @@ class GLM4ToolUtils(ToolUtils): tool_name, tool_input = content.split("\n", maxsplit=1) try: - arguments = json.loads(tool_input) + arguments = json.loads(tool_input.strip()) except json.JSONDecodeError: return content return [(tool_name, json.dumps(arguments, ensure_ascii=False))] +class Llama3ToolUtils(ToolUtils): + r""" + Llama 3.x tool using template with `tools_in_user_message=False`. + """ + + @override + @staticmethod + def get_function_slots() -> SLOTS: + return ["""{"name": "{{name}}", "parameters": {{arguments}}}"""] + + @override + @staticmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: + cur_time = datetime.now().strftime("%d %b %Y") + tool_text = "" + for tool in tools: + wrapped_tool = {"type": "function", "function": tool} + tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n" + + return LLAMA3_TOOL_PROMPT.format(cur_time=cur_time, tool_text=tool_text) + + @override + @staticmethod + def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]: + try: + tool = json.loads(content.strip()) + except json.JSONDecodeError: + return content + + if "name" not in tool or "parameters" not in tool: + return content + + return [(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))] + + TOOLS = { "default": DefaultToolUtils(), "glm4": GLM4ToolUtils(), + "llama3": Llama3ToolUtils(), } diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index c4d0fd84..ff7933e9 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -837,6 +837,10 @@ register_model_group( DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B-Instruct", DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B-Instruct", }, + "Llama-3.3-70B-Instruct": { + DownloadSource.DEFAULT: "meta-llama/Llama-3.3-70B-Instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.3-70B-Instruct", + }, }, template="llama3", ) diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py index 051bc120..de4b85a8 100644 --- a/tests/data/test_formatter.py +++ b/tests/data/test_formatter.py @@ -13,10 +13,29 @@ # limitations under the License. import json +from datetime import datetime from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter +FUNCTION = {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}} + +TOOLS = [ + { + "name": "test_tool", + "description": "tool_desc", + "parameters": { + "type": "object", + "properties": { + "foo": {"type": "string", "description": "foo_desc"}, + "bar": {"type": "number", "description": "bar_desc"}, + }, + "required": ["foo"], + }, + } +] + + def test_empty_formatter(): formatter = EmptyFormatter(slots=["\n"]) assert formatter.apply() == ["\n"] @@ -28,39 +47,27 @@ def test_string_formatter(): def test_function_formatter(): - formatter = FunctionFormatter(slots=[], tool_format="default") - tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}) + formatter = FunctionFormatter(slots=[""], tool_format="default") + tool_calls = json.dumps(FUNCTION) assert formatter.apply(content=tool_calls) == [ - """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""" + """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""", + "", ] def test_multi_function_formatter(): - formatter = FunctionFormatter(slots=[], tool_format="default") - tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2) + formatter = FunctionFormatter(slots=[""], tool_format="default") + tool_calls = json.dumps([FUNCTION] * 2) assert formatter.apply(content=tool_calls) == [ - """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""", - """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""", + """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""", + """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""", + "", ] def test_default_tool_formatter(): formatter = ToolFormatter(tool_format="default") - tools = [ - { - "name": "test_tool", - "description": "tool_desc", - "parameters": { - "type": "object", - "properties": { - "foo": {"type": "string", "description": "foo_desc"}, - "bar": {"type": "number", "description": "bar_desc"}, - }, - "required": ["foo"], - }, - } - ] - assert formatter.apply(content=json.dumps(tools)) == [ + assert formatter.apply(content=json.dumps(TOOLS)) == [ "You have access to the following tools:\n" "> Tool Name: test_tool\n" "Tool Description: tool_desc\n" @@ -94,26 +101,18 @@ def test_default_multi_tool_extractor(): ] +def test_glm4_function_formatter(): + formatter = FunctionFormatter(tool_format="glm4") + tool_calls = json.dumps(FUNCTION) + assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""] + + def test_glm4_tool_formatter(): formatter = ToolFormatter(tool_format="glm4") - tools = [ - { - "name": "test_tool", - "description": "tool_desc", - "parameters": { - "type": "object", - "properties": { - "foo": {"type": "string", "description": "foo_desc"}, - "bar": {"type": "number", "description": "bar_desc"}, - }, - "required": ["foo"], - }, - } - ] - assert formatter.apply(content=json.dumps(tools)) == [ + assert formatter.apply(content=json.dumps(TOOLS)) == [ "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具\n\n" - "## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(json.dumps(tools[0], indent=4)) + f"## test_tool\n\n{json.dumps(TOOLS[0], indent=4)}\n在调用上述函数时,请使用 Json 格式表示调用的参数。" ] @@ -121,3 +120,29 @@ def test_glm4_tool_extractor(): formatter = ToolFormatter(tool_format="glm4") result = """test_tool\n{"foo": "bar", "size": 10}\n""" assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] + + +def test_llama3_function_formatter(): + formatter = FunctionFormatter(tool_format="llama3") + tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}) + assert formatter.apply(content=tool_calls) == [ + """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""" + ] + + +def test_llama3_tool_formatter(): + formatter = ToolFormatter(tool_format="llama3") + cur_time = datetime.now().strftime("%d %b %Y") + wrapped_tool = {"type": "function", "function": TOOLS[0]} + assert formatter.apply(content=json.dumps(TOOLS)) == [ + f"Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n" + "You have access to the following functions. To call a function, please respond with JSON for a function call. " + """Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """ + f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4)}\n\n" + ] + + +def test_llama3_tool_extractor(): + formatter = ToolFormatter(tool_format="llama3") + result = """{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}\n""" + assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]