Merge pull request #6368 from hiyouga/hiyouga/fix_llama_template

[template] fix llama3 tool template

Former-commit-id: 8974a0a185daf7744b4d3a0b2776f9bd72e24426
This commit is contained in:
hoshi-hiyouga 2024-12-18 01:10:48 +08:00 committed by GitHub
commit ad00c793ce
6 changed files with 40 additions and 33 deletions

View File

@ -189,7 +189,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 | | [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 | | [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama | | [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava | | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next | | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |

View File

@ -190,7 +190,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 | | [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 | | [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama | | [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava | | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next | | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |

View File

@ -116,17 +116,21 @@ class FunctionFormatter(Formatter):
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
elements = [] elements = []
for name, arguments in functions: for slot in self.slots:
for slot in self.function_slots: if slot == "{{content}}":
if isinstance(slot, str): for name, arguments in functions:
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments) for slot in self.function_slots:
elements.append(slot) if isinstance(slot, str):
elif isinstance(slot, (dict, set)): slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
elements.append(slot) elements.append(slot)
else: elif isinstance(slot, (dict, set)):
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}") elements.append(slot)
else:
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
else:
elements.append(slot)
return elements + self.slots return elements
@dataclass @dataclass

View File

@ -244,11 +244,11 @@ def _register_template(
) )
``` ```
""" """
eos_slots = [] if efficient_eos else [{"eos_token"}]
template_class = Llama2Template if name.startswith("llama2") else Template template_class = Llama2Template if name.startswith("llama2") else Template
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
default_user_formatter = StringFormatter(slots=["{{content}}"]) default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots) default_assistant_formatter = StringFormatter(slots=default_slots)
default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default") default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default")
default_tool_formatter = ToolFormatter(tool_format="default") default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter() default_separator_formatter = EmptyFormatter()
default_prefix_formatter = EmptyFormatter() default_prefix_formatter = EmptyFormatter()
@ -371,8 +371,8 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
if data_args.tool_format is not None: if data_args.tool_format is not None:
logger.info_rank0(f"Using tool format: {data_args.tool_format}.") logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
eos_slots = [] if template.efficient_eos else [{"eos_token"}] default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}]
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format) template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format) template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
stop_words = template.stop_words stop_words = template.stop_words
@ -490,7 +490,7 @@ _register_template(
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
format_assistant=StringFormatter(slots=["\n", "{{content}}"]), format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]), format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
format_function=FunctionFormatter(slots=[], tool_format="glm4"), format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
), ),
@ -535,7 +535,7 @@ _register_template(
name="codegeex4", name="codegeex4",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=[], tool_format="glm4"), format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]), format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
format_tools=ToolFormatter(tool_format="glm4"), format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]), format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
@ -684,7 +684,7 @@ _register_template(
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
format_assistant=StringFormatter(slots=["\n{{content}}"]), format_assistant=StringFormatter(slots=["\n{{content}}"]),
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
format_function=FunctionFormatter(slots=[], tool_format="glm4"), format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
format_tools=ToolFormatter(tool_format="glm4"), format_tools=ToolFormatter(tool_format="glm4"),
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]), format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
@ -750,7 +750,7 @@ _register_template(
] ]
), ),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]), format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"), format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=[ slots=[
( (
@ -779,7 +779,7 @@ _register_template(
] ]
), ),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]), format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"), format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=[ slots=[
( (
@ -833,7 +833,7 @@ _register_template(
] ]
), ),
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]), format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"), format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
format_observation=StringFormatter( format_observation=StringFormatter(
slots=[ slots=[
( (

View File

@ -46,7 +46,7 @@ GLM4_TOOL_PROMPT = (
LLAMA3_TOOL_PROMPT = ( LLAMA3_TOOL_PROMPT = (
"Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n" "Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
"You have access to the following functions. To call a function, please respond with JSON for a function call. " "You have access to the following functions. To call a function, please respond with JSON for a function call. "
"""Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """ """Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """
"Do not use variables.\n\n{tool_text}" "Do not use variables.\n\n{tool_text}"
@ -180,6 +180,8 @@ class GLM4ToolUtils(ToolUtils):
class Llama3ToolUtils(ToolUtils): class Llama3ToolUtils(ToolUtils):
r""" r"""
Llama 3.x tool using template with `tools_in_user_message=False`. Llama 3.x tool using template with `tools_in_user_message=False`.
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
""" """
@override @override
@ -190,13 +192,13 @@ class Llama3ToolUtils(ToolUtils):
@override @override
@staticmethod @staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: List[Dict[str, Any]]) -> str:
cur_time = datetime.now().strftime("%d %b %Y") date = datetime.now().strftime("%d %b %Y")
tool_text = "" tool_text = ""
for tool in tools: for tool in tools:
wrapped_tool = {"type": "function", "function": tool} wrapped_tool = {"type": "function", "function": tool}
tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n" tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
return LLAMA3_TOOL_PROMPT.format(cur_time=cur_time, tool_text=tool_text) return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
@override @override
@staticmethod @staticmethod

View File

@ -47,7 +47,7 @@ def test_string_formatter():
def test_function_formatter(): def test_function_formatter():
formatter = FunctionFormatter(slots=["</s>"], tool_format="default") formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps(FUNCTION) tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == [ assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""", """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
@ -56,7 +56,7 @@ def test_function_formatter():
def test_multi_function_formatter(): def test_multi_function_formatter():
formatter = FunctionFormatter(slots=["</s>"], tool_format="default") formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
tool_calls = json.dumps([FUNCTION] * 2) tool_calls = json.dumps([FUNCTION] * 2)
assert formatter.apply(content=tool_calls) == [ assert formatter.apply(content=tool_calls) == [
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""", """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
@ -102,7 +102,7 @@ def test_default_multi_tool_extractor():
def test_glm4_function_formatter(): def test_glm4_function_formatter():
formatter = FunctionFormatter(tool_format="glm4") formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
tool_calls = json.dumps(FUNCTION) tool_calls = json.dumps(FUNCTION)
assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""] assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
@ -123,19 +123,20 @@ def test_glm4_tool_extractor():
def test_llama3_function_formatter(): def test_llama3_function_formatter():
formatter = FunctionFormatter(tool_format="llama3") formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}) tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
assert formatter.apply(content=tool_calls) == [ assert formatter.apply(content=tool_calls) == [
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""" """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""",
"<|eot_id|>",
] ]
def test_llama3_tool_formatter(): def test_llama3_tool_formatter():
formatter = ToolFormatter(tool_format="llama3") formatter = ToolFormatter(tool_format="llama3")
cur_time = datetime.now().strftime("%d %b %Y") date = datetime.now().strftime("%d %b %Y")
wrapped_tool = {"type": "function", "function": TOOLS[0]} wrapped_tool = {"type": "function", "function": TOOLS[0]}
assert formatter.apply(content=json.dumps(TOOLS)) == [ assert formatter.apply(content=json.dumps(TOOLS)) == [
f"Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n" f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
"You have access to the following functions. To call a function, please respond with JSON for a function call. " "You have access to the following functions. To call a function, please respond with JSON for a function call. "
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """ """Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4)}\n\n" f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4)}\n\n"