Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-01 11:12:50 +08:00)

Commit b3b2c9f1ee: [data] llama3 multi tool support (#8124)
Parent: f96c085857
README.md
@@ -53,9 +53,7 @@ Choose your path:
 - **Documentation**: https://llamafactory.readthedocs.io/en/latest/
 - **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
 - **Local machine**: Please refer to [usage](#getting-started)
-- **PAI-DSW (free trial)**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) | [DeepSeek-R1-Distill Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)
-- **Amazon SageMaker**: [Blog](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)
-- **Easy Dataset**: [Fine-tune on Synthetic Data](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g)
+- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
 
 > [!NOTE]
 > Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
@@ -107,7 +105,7 @@ Choose your path:
 
 - [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
 - [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
-- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model For News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
+- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
 
 <details><summary>All Blogs</summary>
 
@@ -918,6 +916,7 @@ If you have a project that should be incorporated, please contact via email or c
 1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
 1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified library that supports long sequence SFT & DPO using ring attention.
 1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**: An o1-like model fine-tuned by NovaSky AI with very small cost.
+1. **[WeClone](https://github.com/xming521/WeClone)**: One-stop solution for creating your digital avatar from chat logs.
 
 </details>
 
README_zh.md
@@ -55,9 +55,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 - **Documentation (Ascend NPU)**: https://ascend.github.io/docs/sources/llamafactory/
 - **Colab (free)**: https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
 - **Local machine**: Please refer to [usage](#如何使用)
-- **PAI-DSW (free trial)**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) | [DeepSeek-R1-Distill Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)
-- **Amazon SageMaker**: [Blog](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
-- **Easy Dataset**: [Fine-tune on Distilled Data](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)
+- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
 
 > [!NOTE]
 > Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
@@ -109,12 +107,12 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 
 - [通过亚马逊 SageMaker HyperPod 上的 LLaMA-Factory 增强多模态模型银行文档的视觉信息提取](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
 - [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9) (Chinese)
-- [LLaMA Factory:微调DeepSeek-R1-Distill-Qwen-7B模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
+- [LLaMA Factory:微调 DeepSeek-R1-Distill-Qwen-7B 模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
 
 <details><summary>All Blogs</summary>
 
 - [基于 Amazon SageMaker 和 LLaMA-Factory 打造一站式无代码模型微调部署平台 Model Hub](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) (Chinese)
-- [LLaMA Factory多模态微调实践:微调Qwen2-VL构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
+- [LLaMA Factory 多模态微调实践:微调 Qwen2-VL 构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) (Chinese)
 - [LLaMA Factory:微调LLaMA3模型实现角色扮演](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) (Chinese)
 
 </details>
@@ -920,6 +918,7 @@ swanlab_run_name: test_run # optional
 1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
 1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified library that supports long-sequence SFT and DPO training via Ring Attention.
 1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**: A low-cost o1-like long-reasoning model fine-tuned by NovaSky AI.
+1. **[WeClone](https://github.com/xming521/WeClone)**: A one-stop solution for creating a digital avatar from chat logs.
 
 </details>
 
src/llamafactory/data/tool_utils.py
@@ -125,11 +125,7 @@ class DefaultToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_text = ""
-        for name, arguments in functions:
-            function_text += f"Action: {name}\nAction Input: {arguments}\n"
-
-        return function_text
+        return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])
 
     @override
     @staticmethod
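For orientation, here is a minimal, self-contained sketch of the behavior this refactor standardizes on: one `Action:`/`Action Input:` block per call, newline-joined with no trailing newline. The `FunctionCall` tuple below is a stand-in for the named tuple defined in `tool_utils.py`, not the library import itself.

```python
import json
from collections import namedtuple

# Stand-in for the FunctionCall named tuple defined in tool_utils.py.
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])


def default_function_formatter(functions):
    # One "Action" block per call, newline-joined; no trailing newline.
    return "\n".join(f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions)


calls = [FunctionCall("tool_name", json.dumps({"foo": "bar", "size": 10}))] * 2
print(default_function_formatter(calls))
# Action: tool_name
# Action Input: {"foo": "bar", "size": 10}
# Action: tool_name
# Action Input: {"foo": "bar", "size": 10}
```

Dropping the trailing newline is what forces the matching test updates further down.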
@@ -210,24 +206,23 @@ class Llama3ToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        if len(functions) > 1:
-            raise ValueError("Llama-3 does not support parallel functions.")
-
-        return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
+        function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
+        return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)
 
     @override
     @staticmethod
     def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         try:
-            tool = json.loads(content.strip())
+            tools = json.loads(content.strip())
         except json.JSONDecodeError:
             return content
 
-        if "name" not in tool or "parameters" not in tool:
-            return content
-
-        return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
+            return content
 
 
 class MistralToolUtils(ToolUtils):
     r"""Mistral v0.3 tool using template."""
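This is the core of the commit: `Llama3ToolUtils` previously rejected parallel calls outright; it now serializes a single call as one JSON object and several calls as a JSON array, and the extractor accepts either shape. A minimal round-trip sketch under the same stand-in `FunctionCall` tuple as above (the helper names here are illustrative, not the library's API):

```python
import json
from collections import namedtuple

FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])  # stand-in


def llama3_format(functions):
    objs = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
    return json.dumps(objs[0] if len(objs) == 1 else objs, ensure_ascii=False)


def llama3_extract(content):
    try:
        tools = json.loads(content.strip())
    except json.JSONDecodeError:
        return content  # not JSON: hand the raw text back

    tools = [tools] if not isinstance(tools, list) else tools
    try:
        return [FunctionCall(t["name"], json.dumps(t["parameters"], ensure_ascii=False)) for t in tools]
    except KeyError:  # missing "name"/"parameters": hand the raw text back
        return content


call = FunctionCall("get_weather", '{"city": "Tokyo"}')
assert llama3_format([call]) == '{"name": "get_weather", "parameters": {"city": "Tokyo"}}'
assert llama3_format([call, call]).startswith("[")  # parallel calls become a JSON array
assert llama3_extract(llama3_format([call, call])) == [call, call]
```

Re-serializing arguments through `json.loads`/`json.dumps` also normalizes spacing, which is why the extractor can round-trip exactly.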
@@ -244,11 +239,9 @@ class MistralToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
-
-        return "[" + ", ".join(function_texts) + "]"
+        return json.dumps(
+            [{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions], ensure_ascii=False
+        )
 
     @override
     @staticmethod
@@ -258,17 +251,11 @@ class MistralToolUtils(ToolUtils):
         except json.JSONDecodeError:
             return content
 
-        if not isinstance(tools, list):
-            tools = [tools]
-
-        results = []
-        for tool in tools:
-            if "name" not in tool or "arguments" not in tool:
-                return content
-
-            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
-
-        return results
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
+            return content
 
 
 class QwenToolUtils(ToolUtils):
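`MistralToolUtils` gets the same treatment: the formatter builds the JSON array via `json.dumps` instead of string concatenation (so arguments are re-serialized rather than spliced in as raw text), and the extractor collapses to the same `[tools] if not isinstance(...)` plus `except KeyError` pattern, with `arguments` in place of Llama3's `parameters` key. A short sketch of the v0.3 wire format, same stand-in tuple as above:

```python
import json
from collections import namedtuple

FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])  # stand-in


def mistral_format(functions):
    # Always a JSON array, even for a single call.
    return json.dumps(
        [{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions],
        ensure_ascii=False,
    )


calls = [FunctionCall("search", '{"query": "llama"}'), FunctionCall("sum", '{"a": 1, "b": 2}')]
print(mistral_format(calls))
# [{"name": "search", "arguments": {"query": "llama"}}, {"name": "sum", "arguments": {"a": 1, "b": 2}}]
```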
@@ -287,13 +274,11 @@ class QwenToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(
-                "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
-            )
-
-        return "\n".join(function_texts)
+        function_texts = [
+            json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
+            for name, arguments in functions
+        ]
+        return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])
 
     @override
     @staticmethod
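`QwenToolUtils` keeps its `<tool_call>` envelope but now normalizes each payload through `json.loads`/`json.dumps`, so irregular spacing in the stored argument string no longer leaks into the prompt. Sketch, same stand-in tuple:

```python
import json
from collections import namedtuple

FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])  # stand-in


def qwen_format(functions):
    texts = [
        json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
        for name, arguments in functions
    ]
    # One <tool_call> block per call, newline-joined.
    return "\n".join(f"<tool_call>\n{text}\n</tool_call>" for text in texts)


print(qwen_format([FunctionCall("tool_name", '{"foo": "bar", "size": 10}')]))
# <tool_call>
# {"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}
# </tool_call>
```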
tests/data/test_formatter.py
@@ -50,7 +50,7 @@ def test_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
     tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
-        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
         "</s>",
     ]
 
@@ -60,7 +60,7 @@ def test_multi_function_formatter():
     tool_calls = json.dumps([FUNCTION] * 2)
     assert formatter.apply(content=tool_calls) == [
         """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
-        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
         "</s>",
     ]
 
@@ -85,7 +85,7 @@ def test_default_tool_formatter():
 
 def test_default_tool_extractor():
     formatter = ToolFormatter(tool_format="default")
-    result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
+    result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
 
 
@@ -93,7 +93,7 @@ def test_default_multi_tool_extractor():
     formatter = ToolFormatter(tool_format="default")
     result = (
         """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
-        """Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
+        """Action: another_tool\nAction Input: {"foo": "job", "size": 2}"""
     )
     assert formatter.extract(result) == [
         ("test_tool", """{"foo": "bar", "size": 10}"""),
@@ -125,12 +125,22 @@ def test_glm4_tool_extractor():
 
 def test_llama3_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
-    tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
+    tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
         """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""
     ]
 
 
+def test_llama3_multi_function_formatter():
+    formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
+    tool_calls = json.dumps([FUNCTION] * 2)
+    assert formatter.apply(content=tool_calls) == [
+        """[{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}, """
+        """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}]"""
+        """<|eot_id|>"""
+    ]
+
+
 def test_llama3_tool_formatter():
     formatter = ToolFormatter(tool_format="llama3")
     date = datetime.now().strftime("%d %b %Y")
@@ -150,6 +160,18 @@ def test_llama3_tool_extractor():
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
 
 
+def test_llama3_multi_tool_extractor():
+    formatter = ToolFormatter(tool_format="llama3")
+    result = (
+        """[{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}, """
+        """{"name": "another_tool", "parameters": {"foo": "job", "size": 2}}]"""
+    )
+    assert formatter.extract(result) == [
+        ("test_tool", """{"foo": "bar", "size": 10}"""),
+        ("another_tool", """{"foo": "job", "size": 2}"""),
+    ]
+
+
 def test_mistral_function_formatter():
     formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
     tool_calls = json.dumps(FUNCTION)
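To exercise the new cases locally, running `pytest tests/data/test_formatter.py -k llama3` from the repository root should cover the formatter and extractor tests above (path assumed from the test names shown in the hunk headers; adjust if the suite lives elsewhere).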