From f6a2bfc0e8d721bdf69deb90d9bfda1a1553273a Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 17 Dec 2024 17:04:02 +0000
Subject: [PATCH] fix llama3 tool template

Former-commit-id: df5655f61cb847dc2d9eb7b34266b20343ff90d6
---
 README.md                           |  2 +-
 README_zh.md                        |  2 +-
 src/llamafactory/data/formatter.py  | 24 ++++++++++++++----------
 src/llamafactory/data/template.py   | 22 +++++++++++-----------
 src/llamafactory/data/tool_utils.py |  8 +++++---
 tests/data/test_formatter.py        | 15 ++++++++-------
 6 files changed, 40 insertions(+), 33 deletions(-)

diff --git a/README.md b/README.md
index d0a94dc6..02db0319 100644
--- a/README.md
+++ b/README.md
@@ -189,7 +189,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
+| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
 | [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
diff --git a/README_zh.md b/README_zh.md
index 7e5b914b..e890a131 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -190,7 +190,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
+| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
 | [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py
index d5f4b385..28bf3fb1 100644
--- a/src/llamafactory/data/formatter.py
+++ b/src/llamafactory/data/formatter.py
@@ -116,17 +116,21 @@ class FunctionFormatter(Formatter):
             raise RuntimeError(f"Invalid JSON format in function message: {str([content])}")  # flat string
 
         elements = []
-        for name, arguments in functions:
-            for slot in self.function_slots:
-                if isinstance(slot, str):
-                    slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
-                    elements.append(slot)
-                elif isinstance(slot, (dict, set)):
-                    elements.append(slot)
-                else:
-                    raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
+        for slot in self.slots:
+            if slot == "{{content}}":
+                for name, arguments in functions:
+                    for slot in self.function_slots:
+                        if isinstance(slot, str):
+                            slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
+                            elements.append(slot)
+                        elif isinstance(slot, (dict, set)):
+                            elements.append(slot)
+                        else:
+                            raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
+            else:
+                elements.append(slot)
 
-        return elements + self.slots
+        return elements
 
 
 @dataclass
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index becfeaa8..cfc4a2cb 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -244,11 +244,11 @@ def _register_template(
     )
     ```
     """
-    eos_slots = [] if efficient_eos else [{"eos_token"}]
     template_class = Llama2Template if name.startswith("llama2") else Template
+    default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
     default_user_formatter = StringFormatter(slots=["{{content}}"])
-    default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
-    default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
+    default_assistant_formatter = StringFormatter(slots=default_slots)
+    default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default")
     default_tool_formatter = ToolFormatter(tool_format="default")
     default_separator_formatter = EmptyFormatter()
     default_prefix_formatter = EmptyFormatter()
@@ -371,8 +371,8 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
 
     if data_args.tool_format is not None:
         logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
-        eos_slots = [] if template.efficient_eos else [{"eos_token"}]
-        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
+        default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}]
+        template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
 
     stop_words = template.stop_words
@@ -490,7 +490,7 @@ _register_template(
     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
     format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
     format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
-    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
     format_observation=StringFormatter(
         slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
     ),
@@ -535,7 +535,7 @@ _register_template(
     name="codegeex4",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
-    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
     format_tools=ToolFormatter(tool_format="glm4"),
     format_prefix=EmptyFormatter(slots=["[gMASK]"]),
@@ -684,7 +684,7 @@ _register_template(
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
     format_assistant=StringFormatter(slots=["\n{{content}}"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
-    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
     format_tools=ToolFormatter(tool_format="glm4"),
     format_prefix=EmptyFormatter(slots=["[gMASK]"]),
@@ -750,7 +750,7 @@ _register_template(
         ]
     ),
     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
-    format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
+    format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
     format_observation=StringFormatter(
         slots=[
             (
@@ -779,7 +779,7 @@ _register_template(
         ]
     ),
     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
-    format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
+    format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
     format_observation=StringFormatter(
         slots=[
             (
@@ -833,7 +833,7 @@ _register_template(
         ]
     ),
     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
-    format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
+    format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
     format_observation=StringFormatter(
         slots=[
             (
diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py
index 2465191a..7c58dfc1 100644
--- a/src/llamafactory/data/tool_utils.py
+++ b/src/llamafactory/data/tool_utils.py
@@ -46,7 +46,7 @@ GLM4_TOOL_PROMPT = (
 
 
 LLAMA3_TOOL_PROMPT = (
-    "Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n"
+    "Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
     "You have access to the following functions. To call a function, please respond with JSON for a function call. "
     """Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """
     "Do not use variables.\n\n{tool_text}"
@@ -180,6 +180,8 @@ class GLM4ToolUtils(ToolUtils):
 class Llama3ToolUtils(ToolUtils):
     r"""
     Llama 3.x tool using template with `tools_in_user_message=False`.
+
+    Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
     """
 
     @override
@@ -190,13 +192,13 @@ class Llama3ToolUtils(ToolUtils):
     @override
     @staticmethod
     def tool_formatter(tools: List[Dict[str, Any]]) -> str:
-        cur_time = datetime.now().strftime("%d %b %Y")
+        date = datetime.now().strftime("%d %b %Y")
         tool_text = ""
         for tool in tools:
             wrapped_tool = {"type": "function", "function": tool}
             tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
 
-        return LLAMA3_TOOL_PROMPT.format(cur_time=cur_time, tool_text=tool_text)
+        return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
 
     @override
     @staticmethod
diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py
index de4b85a8..6543a71a 100644
--- a/tests/data/test_formatter.py
+++ b/tests/data/test_formatter.py
@@ -47,7 +47,7 @@ def test_string_formatter():
 
 
 def test_function_formatter():
-    formatter = FunctionFormatter(slots=[""], tool_format="default")
+    formatter = FunctionFormatter(slots=["{{content}}", ""], tool_format="default")
     tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
         """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
@@ -56,7 +56,7 @@ def test_function_formatter():
 
 
 def test_multi_function_formatter():
-    formatter = FunctionFormatter(slots=[""], tool_format="default")
+    formatter = FunctionFormatter(slots=["{{content}}", ""], tool_format="default")
     tool_calls = json.dumps([FUNCTION] * 2)
     assert formatter.apply(content=tool_calls) == [
         """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
@@ -102,7 +102,7 @@ def test_default_multi_tool_extractor():
 
 
 def test_glm4_function_formatter():
-    formatter = FunctionFormatter(tool_format="glm4")
+    formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
     tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
 
@@ -123,19 +123,20 @@ def test_glm4_tool_extractor():
 
 
 def test_llama3_function_formatter():
-    formatter = FunctionFormatter(tool_format="llama3")
+    formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
     tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
     assert formatter.apply(content=tool_calls) == [
-        """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}"""
+        """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""",
+        "<|eot_id|>",
     ]
 
 
 def test_llama3_tool_formatter():
     formatter = ToolFormatter(tool_format="llama3")
-    cur_time = datetime.now().strftime("%d %b %Y")
+    date = datetime.now().strftime("%d %b %Y")
     wrapped_tool = {"type": "function", "function": TOOLS[0]}
     assert formatter.apply(content=json.dumps(TOOLS)) == [
-        f"Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n"
+        f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
         "You have access to the following functions. To call a function, please respond with JSON for a function call. "
         """Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
         f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4)}\n\n"
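
Illustrative note (not part of the patch): a minimal sketch of how the reworked FunctionFormatter is exercised after this change, mirroring the updated test_llama3_function_formatter above. With this patch, the "{{content}}" slot expands into the rendered tool call and the remaining slots (such as the "<|eot_id|>" stop token) are emitted in place instead of being appended afterwards.

import json

from llamafactory.data.formatter import FunctionFormatter

# Same setup as the updated test: "{{content}}" becomes the Llama 3 JSON tool
# call, followed by the "<|eot_id|>" slot configured for the llama3 template.
formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
print(formatter.apply(content=tool_calls))
# ['{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}', '<|eot_id|>']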