Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-23 22:32:54 +08:00
Merge pull request #6368 from hiyouga/hiyouga/fix_llama_template
[template] fix llama3 tool template

Former-commit-id: 8974a0a185daf7744b4d3a0b2776f9bd72e24426
Commit ad00c793ce
@@ -189,7 +189,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
+| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
 | [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
@@ -190,7 +190,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
+| [Llama 3-3.3](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
 | [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
@@ -116,6 +116,8 @@ class FunctionFormatter(Formatter):
             raise RuntimeError(f"Invalid JSON format in function message: {str([content])}")  # flat string
 
         elements = []
+        for slot in self.slots:
+            if slot == "{{content}}":
                 for name, arguments in functions:
                     for slot in self.function_slots:
                         if isinstance(slot, str):
@@ -125,8 +127,10 @@ class FunctionFormatter(Formatter):
                             elements.append(slot)
                         else:
                             raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
+            else:
+                elements.append(slot)
 
-        return elements + self.slots
+        return elements
 
 
 @dataclass
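For reference, a minimal standalone sketch of the slot-expansion rule introduced above: a "{{content}}" entry in `slots` is replaced by the rendered function calls, while every other slot (such as an EOS token) passes through unchanged. This is not the project's actual class; the helper name is made up, and the Action-style function slot mirrors the `default` tool format exercised by the tests further down.

```python
import json
from typing import Dict, List, Set, Tuple, Union

Slot = Union[str, Set[str], Dict[str, str]]


# Hypothetical helper mirroring the patched FunctionFormatter.apply():
# "{{content}}" in `slots` expands to the rendered tool calls; every other
# slot (e.g. {"eos_token"} or "<|eot_id|>") is appended verbatim.
def expand_function_slots(
    slots: List[Slot],
    function_slots: List[str],
    functions: List[Tuple[str, str]],  # (name, JSON-encoded arguments)
) -> List[Slot]:
    elements: List[Slot] = []
    for slot in slots:
        if slot == "{{content}}":
            for name, arguments in functions:
                for fn_slot in function_slots:
                    elements.append(fn_slot.replace("{{name}}", name).replace("{{arguments}}", arguments))
        else:
            elements.append(slot)
    return elements


functions = [("tool_name", json.dumps({"foo": "bar", "size": 10}))]
print(expand_function_slots(
    slots=["{{content}}", "</s>"],
    function_slots=["Action: {{name}}\nAction Input: {{arguments}}\n"],
    functions=functions,
))
# ['Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n', '</s>']
```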
@@ -244,11 +244,11 @@ def _register_template(
     )
     ```
     """
-    eos_slots = [] if efficient_eos else [{"eos_token"}]
     template_class = Llama2Template if name.startswith("llama2") else Template
+    default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
     default_user_formatter = StringFormatter(slots=["{{content}}"])
-    default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
-    default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
+    default_assistant_formatter = StringFormatter(slots=default_slots)
+    default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default")
     default_tool_formatter = ToolFormatter(tool_format="default")
     default_separator_formatter = EmptyFormatter()
     default_prefix_formatter = EmptyFormatter()
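The `default_slots` expression above now drives both the assistant and function formatters; a quick sketch of its two possible values, copied from the hunk with the `eos_token` slot left symbolic:

```python
# When efficient_eos is enabled the EOS slot is omitted; otherwise it follows
# the "{{content}}" placeholder that anchors where the reply text goes.
for efficient_eos in (True, False):
    default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
    print(efficient_eos, default_slots)
# True ['{{content}}']
# False ['{{content}}', {'eos_token'}]
```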
@@ -371,8 +371,8 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
 
     if data_args.tool_format is not None:
         logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
-        eos_slots = [] if template.efficient_eos else [{"eos_token"}]
-        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
+        default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}]
+        template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
 
     stop_words = template.stop_words
@@ -490,7 +490,7 @@ _register_template(
     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
     format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
     format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
-    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
     format_observation=StringFormatter(
         slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
     ),
@@ -535,7 +535,7 @@ _register_template(
     name="codegeex4",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
-    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
     format_tools=ToolFormatter(tool_format="glm4"),
     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
@@ -684,7 +684,7 @@ _register_template(
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
     format_assistant=StringFormatter(slots=["\n{{content}}"]),
     format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
-    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
     format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
     format_tools=ToolFormatter(tool_format="glm4"),
     format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
@@ -750,7 +750,7 @@ _register_template(
         ]
     ),
     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
-    format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
+    format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
     format_observation=StringFormatter(
         slots=[
             (
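With the new slots, an assistant turn that calls a tool ends in `<|eot_id|>` rather than `<|eom_id|>`, and the JSON payload fills the `{{content}}` placeholder. A hand-assembled sketch follows; the header tokens are Meta's Llama 3 chat format and the tool call itself is illustrative, not taken from the repository:

```python
# Illustrative assistant tool-call turn under the patched llama3 template.
tool_call = '{"name": "get_weather", "parameters": {"city": "Paris"}}'  # fills "{{content}}"
assistant_turn = (
    "<|start_header_id|>assistant<|end_header_id|>\n\n"  # emitted by the surrounding chat template
    + tool_call
    + "<|eot_id|>"  # previously the formatter appended "<|eom_id|>" here
)
print(assistant_turn)
```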
@@ -779,7 +779,7 @@ _register_template(
         ]
     ),
     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
-    format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
+    format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
     format_observation=StringFormatter(
         slots=[
             (
@@ -833,7 +833,7 @@ _register_template(
         ]
     ),
     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
-    format_function=FunctionFormatter(slots=["<|eom_id|>"], tool_format="llama3"),
+    format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
     format_observation=StringFormatter(
         slots=[
             (
@@ -46,7 +46,7 @@ GLM4_TOOL_PROMPT = (
 
 
 LLAMA3_TOOL_PROMPT = (
-    "Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n"
+    "Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
     "You have access to the following functions. To call a function, please respond with JSON for a function call. "
     """Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """
     "Do not use variables.\n\n{tool_text}"
@@ -180,6 +180,8 @@ class GLM4ToolUtils(ToolUtils):
 class Llama3ToolUtils(ToolUtils):
     r"""
+    Llama 3.x tool using template with `tools_in_user_message=False`.
+
     Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
     """
 
     @override
@@ -190,13 +192,13 @@ class Llama3ToolUtils(ToolUtils):
     @override
     @staticmethod
     def tool_formatter(tools: List[Dict[str, Any]]) -> str:
-        cur_time = datetime.now().strftime("%d %b %Y")
+        date = datetime.now().strftime("%d %b %Y")
         tool_text = ""
         for tool in tools:
             wrapped_tool = {"type": "function", "function": tool}
             tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"
 
-        return LLAMA3_TOOL_PROMPT.format(cur_time=cur_time, tool_text=tool_text)
+        return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)
 
     @override
     @staticmethod
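Taken together, the two tool_utils.py hunks above produce a system prompt like the one sketched here. The prompt string and the wrapping loop are copied from the diff; the example tool definition is made up, and in the library the tools would come from the dataset rather than being hard-coded:

```python
import json
from datetime import datetime

LLAMA3_TOOL_PROMPT = (
    "Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
    "You have access to the following functions. To call a function, please respond with JSON for a function call. "
    """Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """
    "Do not use variables.\n\n{tool_text}"
)

# Made-up tool schema for illustration only.
tools = [
    {
        "name": "get_weather",
        "description": "Look up the current weather for a city",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    }
]

date = datetime.now().strftime("%d %b %Y")
tool_text = ""
for tool in tools:
    wrapped_tool = {"type": "function", "function": tool}  # tools_in_user_message=False style wrapping
    tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"

print(LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text))
```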
@@ -47,7 +47,7 @@ def test_string_formatter():
 
 
 def test_function_formatter():
-    formatter = FunctionFormatter(slots=["</s>"], tool_format="default")
+    formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
     tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
         """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
@@ -56,7 +56,7 @@ def test_function_formatter():
 
 
 def test_multi_function_formatter():
-    formatter = FunctionFormatter(slots=["</s>"], tool_format="default")
+    formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
     tool_calls = json.dumps([FUNCTION] * 2)
     assert formatter.apply(content=tool_calls) == [
         """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
@@ -102,7 +102,7 @@ def test_default_multi_tool_extractor():
 
 
 def test_glm4_function_formatter():
-    formatter = FunctionFormatter(tool_format="glm4")
+    formatter = FunctionFormatter(slots=["{{content}}"], tool_format="glm4")
     tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == ["""tool_name\n{"foo": "bar", "size": 10}"""]
 
@@ -123,19 +123,20 @@ def test_glm4_tool_extractor():
 
 
 def test_llama3_function_formatter():
-    formatter = FunctionFormatter(tool_format="llama3")
+    formatter = FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3")
     tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
     assert formatter.apply(content=tool_calls) == [
-        """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}"""
+        """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}""",
+        "<|eot_id|>",
     ]
 
 
 def test_llama3_tool_formatter():
     formatter = ToolFormatter(tool_format="llama3")
-    cur_time = datetime.now().strftime("%d %b %Y")
+    date = datetime.now().strftime("%d %b %Y")
     wrapped_tool = {"type": "function", "function": TOOLS[0]}
     assert formatter.apply(content=json.dumps(TOOLS)) == [
-        f"Environment: ipython\nCutting Knowledge Date: December 2023\nToday Date: {cur_time}\n\n"
+        f"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
         "You have access to the following functions. To call a function, please respond with JSON for a function call. "
         """Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. """
         f"Do not use variables.\n\n{json.dumps(wrapped_tool, indent=4)}\n\n"
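To check the change locally, the updated formatter tests can presumably be run with pytest; the module path below assumes the repository's usual tests/data layout and may need adjusting:

```python
# Assumed invocation of the test module touched by this commit.
import pytest

raise SystemExit(pytest.main(["-q", "tests/data/test_formatter.py", "-k", "function_formatter or llama3"]))
```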