mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
Merge pull request #6368 from hiyouga/hiyouga/fix_llama_template
[template] fix llama3 tool template Former-commit-id: 8974a0a185daf7744b4d3a0b2776f9bd72e24426
This commit is contained in:
commit
ad00c793ce
@ -189,7 +189,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
|
|||||||
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
||||||
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||||
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||||
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
|
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
|
||||||
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
|
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
|
||||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
||||||
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
||||||
|
@ -190,7 +190,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
|
|||||||
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
|
||||||
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
|
||||||
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
|
||||||
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
|
| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
|
||||||
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
|
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
|
||||||
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
|
||||||
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
|
||||||
|
@ -116,17 +116,21 @@ class FunctionFormatter(Formatter):
|
|||||||
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
|
raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string
|
||||||
|
|
||||||
elements = []
|
elements = []
|
||||||
for name, arguments in functions:
|
for slot in self.slots:
|
||||||
for slot in self.function_slots:
|
if slot == "{{content}}":
|
||||||
if isinstance(slot, str):
|
for name, arguments in functions:
|
||||||
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
for slot in self.function_slots:
|
||||||
elements.append(slot)
|
if isinstance(slot, str):
|
||||||
elif isinstance(slot, (dict, set)):
|
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
||||||
elements.append(slot)
|
elements.append(slot)
|
||||||
else:
|
elif isinstance(slot, (dict, set)):
|
||||||
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
elements.append(slot)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
|
||||||
|
else:
|
||||||
|
elements.append(slot)
|
||||||
|
|
||||||
return elements + self.slots
|
return elements
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -244,11 +244,11 @@ def _register_template(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
eos_slots = [] if efficient_eos else [{"eos_token"}]
|
|
||||||
template_class = Llama2Template if name.startswith("llama2") else Template
|
template_class = Llama2Template if name.startswith("llama2") else Template
|
||||||
|
default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
|
||||||
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
||||||
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
default_assistant_formatter = StringFormatter(slots=default_slots)
|
||||||
default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
|
default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default")
|
||||||
default_tool_formatter = ToolFormatter(tool_format="default")
|
default_tool_formatter = ToolFormatter(tool_format="default")
|
||||||
default_separator_formatter = EmptyFormatter()
|
default_separator_formatter = EmptyFormatter()
|
||||||
default_prefix_formatter = EmptyFormatter()
|
default_prefix_formatter = EmptyFormatter()
|
||||||
@ -371,8 +371,8 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
|||||||
|
|
||||||
if data_args.tool_format is not None:
|
if data_args.tool_format is not None:
|
||||||
logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
|
logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
|
||||||
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
|
default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}]
|
||||||
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
|
template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
|
||||||
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
|
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
|
||||||
|
|
||||||
stop_words = template.stop_words
|
stop_words = template.stop_words
|
||||||
@ -490,7 +490,7 @@ _register_template(
|
|||||||
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
||||||
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
||||||
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
|
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
|
||||||
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
|
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
||||||
),
|
),
|
||||||
@ -535,7 +535,7 @@ _register_template(
|
|||||||
name="codegeex4",
|
name="codegeex4",
|
||||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
|
||||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||||
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
|
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
|
||||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
|
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
|
||||||
format_tools=ToolFormatter(tool_format="glm4"),
|
format_tools=ToolFormatter(tool_format="glm4"),
|
||||||
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||||
@ -684,7 +684,7 @@ _register_template(
|
|||||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||||
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||||
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
|
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
|
||||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
||||||
format_tools=ToolFormatter(tool_format="glm4"),
|
format_tools=ToolFormatter(tool_format="glm4"),
|
||||||
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||||
@ -750,7 +750,7 @@ _register_template(
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||||
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
|
format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
(
|
(
|
||||||
@ -779,7 +779,7 @@ _register_template(
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||||
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
|
format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
(
|
(
|
||||||
@ -833,7 +833,7 @@ _register_template(
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||||
format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
|
format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
(
|
(
|
||||||
|
@ -46,7 +46,7 @@ GLM4_TOOL_PROMPT = (
|
|||||||
|
|
||||||
|
|
||||||
LLAMA3_TOOL_PROMPT = (
|
LLAMA3_TOOL_PROMPT = (
|
||||||
"Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n"
|
"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
|
||||||
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
|
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
|
||||||
"""Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """
|
"""Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """
|
||||||
"Do not use variables.\n\n{tool_text}"
|
"Do not use variables.\n\n{tool_text}"
|
||||||
@ -180,6 +180,8 @@ class GLM4ToolUtils(ToolUtils):
|
|||||||
class Llama3ToolUtils(ToolUtils):
|
class Llama3ToolUtils(ToolUtils):
|
||||||
r"""
|
r"""
|
||||||
Llama 3.x tool using template with `tools_in_user_message=False`.
|
Llama 3.x tool using template with `tools_in_user_message=False`.
|
||||||
|
|
||||||
|
Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@ -190,13 +192,13 @@ class Llama3ToolUtils(ToolUtils):
|
|||||||
@override
|
@override
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||||
cur_time = datetime.now().strftime("%d %b %Y")
|
date = datetime.now().strftime("%d %b %Y")
|
||||||
tool_text = ""
|
tool_text = ""
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
wrapped_tool = {"type": "function", "function": tool}
|
wrapped_tool = {"type": "function", "function": tool}
|
||||||
tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
|
tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
|
||||||
|
|
||||||
return LLAMA3_TOOL_PROMPT.format(cur_time=cur_time, tool_text=tool_text)
|
return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -47,7 +47,7 @@ def test_string_formatter():
|
|||||||
|
|
||||||
|
|
||||||
def test_function_formatter():
|
def test_function_formatter():
|
||||||
formatter = FunctionFormatter(slots=["</s>"], tool_format="default")
|
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
|
||||||
tool_calls = json.dumps(FUNCTION)
|
tool_calls = json.dumps(FUNCTION)
|
||||||
assert formatter.apply(content=tool_calls) == [
|
assert formatter.apply(content=tool_calls) == [
|
||||||
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
|
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
|
||||||
@ -56,7 +56,7 @@ def test_function_formatter():
|
|||||||
|
|
||||||
|
|
||||||
def test_multi_function_formatter():
|
def test_multi_function_formatter():
|
||||||
formatter = FunctionFormatter(slots=["</s>"], tool_format="default")
|
formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
|
||||||
tool_calls = json.dumps([FUNCTION] * 2)
|
tool_calls = json.dumps([FUNCTION] * 2)
|
||||||
assert formatter.apply(content=tool_calls) == [
|
assert formatter.apply(content=tool_calls) == [
|
||||||
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
|
"""Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
|
||||||
@ -102,7 +102,7 @@ def test_default_multi_tool_extractor():
|
|||||||
|
|
||||||
|
|
||||||
def test_glm4_function_formatter():
|
def test_glm4_function_formatter():
|
||||||
formatter = FunctionFormatter(tool_format="glm4")
|
formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
|
||||||
tool_calls = json.dumps(FUNCTION)
|
tool_calls = json.dumps(FUNCTION)
|
||||||
assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
|
assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
|
||||||
|
|
||||||
@ -123,19 +123,20 @@ def test_glm4_tool_extractor():
|
|||||||
|
|
||||||
|
|
||||||
def test_llama3_function_formatter():
|
def test_llama3_function_formatter():
|
||||||
formatter = FunctionFormatter(tool_format="llama3")
|
formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
|
||||||
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
|
tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
|
||||||
assert formatter.apply(content=tool_calls) == [
|
assert formatter.apply(content=tool_calls) == [
|
||||||
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}"""
|
"""{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""",
|
||||||
|
"<|eot_id|>",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_llama3_tool_formatter():
|
def test_llama3_tool_formatter():
|
||||||
formatter = ToolFormatter(tool_format="llama3")
|
formatter = ToolFormatter(tool_format="llama3")
|
||||||
cur_time = datetime.now().strftime("%d %b %Y")
|
date = datetime.now().strftime("%d %b %Y")
|
||||||
wrapped_tool = {"type": "function", "function": TOOLS[0]}
|
wrapped_tool = {"type": "function", "function": TOOLS[0]}
|
||||||
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
assert formatter.apply(content=json.dumps(TOOLS)) == [
|
||||||
f"Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n"
|
f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
|
||||||
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
|
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
|
||||||
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
|
"""Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
|
||||||
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4)}\n\n"
|
f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4)}\n\n"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user