mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

[data] llama3 multi tool support (#8124)

parent c2f6f2fa77
commit 56926d76f9

@@ -53,9 +53,7 @@ Choose your path:
 - **Documentation**: https://llamafactory.readthedocs.io/en/latest/
 - **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing
 - **Local machine**: Please refer to [usage](#getting-started)
-- **PAI-DSW (free trial)**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) | [DeepSeek-R1-Distill Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)
-- **Amazon SageMaker**: [Blog](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)
-- **Easy Dataset**: [Fine-tune on Synthetic Data](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g)
+- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
 
 > [!NOTE]
 > Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them.
@@ -107,7 +105,7 @@ Choose your path:
 
 - [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English)
 - [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English)
-- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model For News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
+- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
 
 <details><summary>All Blogs</summary>
 
@@ -918,6 +916,7 @@ If you have a project that should be incorporated, please contact via email or c
 1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
 1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified library that supports long sequence SFT & DPO using ring attention.
 1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**: An o1-like model fine-tuned by NovaSky AI with very small cost.
+1. **[WeClone](https://github.com/xming521/WeClone)**: One-stop solution for creating your digital avatar from chat logs.
 
 </details>
 
@@ -55,9 +55,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 - **框架文档(昇腾 NPU)**:https://ascend.github.io/docs/sources/llamafactory/
 - **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing
 - **本地机器**:请见[如何使用](#如何使用)
-- **PAI-DSW(免费试用)**:[Llama3 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) | [DeepSeek-R1-Distill 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)
-- **Amazon SageMaker**:[博客](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)
-- **Easy Dataset**:[数据蒸馏微调](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)
+- **PAI-DSW(免费试用)**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory
 
 > [!NOTE]
 > 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。
@@ -109,12 +107,12 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 
 - [通过亚马逊 SageMaker HyperPod 上的 LLaMA-Factory 增强多模态模型银行文档的视觉信息提取](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)(英文)
 - [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)(中文)
-- [LLaMA Factory:微调DeepSeek-R1-Distill-Qwen-7B模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)(中文)
+- [LLaMA Factory:微调 DeepSeek-R1-Distill-Qwen-7B 模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)(中文)
 
 <details><summary>全部博客</summary>
 
 - [基于 Amazon SageMaker 和 LLaMA-Factory 打造一站式无代码模型微调部署平台 Model Hub](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)(中文)
-- [LLaMA Factory多模态微调实践:微调Qwen2-VL构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)(中文)
+- [LLaMA Factory 多模态微调实践:微调 Qwen2-VL 构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)(中文)
 - [LLaMA Factory:微调LLaMA3模型实现角色扮演](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)(中文)
 
 </details>
@@ -920,6 +918,7 @@ swanlab_run_name: test_run # 可选
 1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**:一个全链路 RAG 检索模型微调、推理和蒸馏代码库。[[blog]](https://zhuanlan.zhihu.com/p/987727357)
 1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**:一个魔改后的代码库,通过 Ring Attention 支持长序列的 SFT 和 DPO 训练。
 1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**:由 NovaSky AI 微调的低成本类 o1 长推理模型。
+1. **[WeClone](https://github.com/xming521/WeClone)**:从聊天记录创造数字分身的一站式解决方案。
 
 </details>
 
@@ -125,11 +125,7 @@ class DefaultToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_text = ""
-        for name, arguments in functions:
-            function_text += f"Action: {name}\nAction Input: {arguments}\n"
-
-        return function_text
+        return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])
 
     @override
     @staticmethod
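The practical effect of the hunk above is that the formatted text no longer ends with a newline: `"\n"` now only separates consecutive calls. A minimal sketch of the difference, assuming `FunctionCall` is a `(name, arguments)` named tuple (the stand-in class and the `format_old` / `format_new` names below are hypothetical, used only for illustration):

```python
from typing import NamedTuple


class FunctionCall(NamedTuple):
    # Stand-in for the project's FunctionCall; assumed to be a (name, arguments) pair.
    name: str
    arguments: str


def format_old(functions: list[FunctionCall]) -> str:
    # Pre-change behaviour: every call is followed by "\n", including the last one.
    text = ""
    for name, arguments in functions:
        text += f"Action: {name}\nAction Input: {arguments}\n"
    return text


def format_new(functions: list[FunctionCall]) -> str:
    # Post-change behaviour: "\n" only separates calls, so there is no trailing newline.
    return "\n".join(f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions)


calls = [FunctionCall("tool_name", '{"foo": "bar", "size": 10}')] * 2
# The two outputs differ only by the final newline.
assert format_old(calls) == format_new(calls) + "\n"
```

This matches the updated expectations in the test hunks further down, where the trailing `\n` is dropped from the last `Action Input` line.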
@@ -210,24 +206,23 @@ class Llama3ToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        if len(functions) > 1:
-            raise ValueError("Llama-3 does not support parallel functions.")
-
-        return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}'
+        function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
+        return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False)
 
     @override
     @staticmethod
     def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
         try:
-            tool = json.loads(content.strip())
+            tools = json.loads(content.strip())
         except json.JSONDecodeError:
             return content
 
-        if "name" not in tool or "parameters" not in tool:
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
             return content
 
-        return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
-
 
 class MistralToolUtils(ToolUtils):
     r"""Mistral v0.3 tool using template."""
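The Llama-3 hunk above makes the formatter and the extractor symmetric: one call is serialized as a single JSON object, several calls as a JSON array, and the extractor accepts both shapes. The following is a self-contained sketch of that round trip, not the project's code. `FunctionCall` is an assumed stand-in, and `format_calls` / `extract_calls` are hypothetical free functions that mirror the two static methods.

```python
import json
from typing import NamedTuple, Union


class FunctionCall(NamedTuple):
    # Stand-in for the project's FunctionCall (assumption).
    name: str
    arguments: str  # JSON-encoded argument object


def format_calls(functions: list[FunctionCall]) -> str:
    # One call -> a JSON object; several calls -> a JSON array of objects.
    objs = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions]
    return json.dumps(objs[0] if len(objs) == 1 else objs, ensure_ascii=False)


def extract_calls(content: str) -> Union[str, list[FunctionCall]]:
    # Non-JSON text is returned unchanged so the caller can treat it as a plain reply.
    try:
        tools = json.loads(content.strip())
    except json.JSONDecodeError:
        return content

    tools = tools if isinstance(tools, list) else [tools]
    try:
        return [FunctionCall(t["name"], json.dumps(t["parameters"], ensure_ascii=False)) for t in tools]
    except KeyError:
        # A missing "name" or "parameters" key also falls back to the raw text.
        return content


single = [FunctionCall("test_tool", '{"foo": "bar", "size": 10}')]
multi = single + [FunctionCall("another_tool", '{"foo": "job", "size": 2}')]

assert extract_calls(format_calls(single)) == single
assert extract_calls(format_calls(multi)) == multi
assert extract_calls("plain text") == "plain text"
```

One property of this sketch worth knowing: only `KeyError` is caught, so content that parses as a bare JSON scalar (for example `123`) would raise `TypeError` when indexed rather than fall back to the raw text.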
@@ -244,11 +239,9 @@ class MistralToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')
-
-        return "[" + ", ".join(function_texts) + "]"
+        return json.dumps(
+            [{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions], ensure_ascii=False
+        )
 
     @override
     @staticmethod
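One plausible motivation for replacing the f-string assembly with `json.dumps` in the hunk above is escaping: string interpolation copies the function name verbatim, whereas `json.dumps` escapes it. The sketch below compares the two approaches on plain `(name, arguments)` tuples. `mistral_old` and `mistral_new` are hypothetical names, and the tricky function name is a contrived input chosen only to show the difference.

```python
import json


def mistral_old(functions: list[tuple[str, str]]) -> str:
    # Pre-change approach: build the JSON text by string interpolation.
    return "[" + ", ".join(f'{{"name": "{name}", "arguments": {args}}}' for name, args in functions) + "]"


def mistral_new(functions: list[tuple[str, str]]) -> str:
    # Post-change approach: build Python objects and let json.dumps serialize them.
    return json.dumps([{"name": name, "arguments": json.loads(args)} for name, args in functions], ensure_ascii=False)


# For ordinary input the two approaches produce identical text.
calls = [("tool_name", '{"foo": "bar", "size": 10}')]
assert mistral_old(calls) == mistral_new(calls)

# A name containing a double quote breaks the interpolated version...
tricky = [('say "hi"', "{}")]
try:
    json.loads(mistral_old(tricky))
    old_is_valid = True
except json.JSONDecodeError:
    old_is_valid = False

# ...while json.dumps escapes it and the result still parses.
new_parsed = json.loads(mistral_new(tricky))
assert not old_is_valid
assert new_parsed[0]["name"] == 'say "hi"'
```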
@@ -258,17 +251,11 @@ class MistralToolUtils(ToolUtils):
         except json.JSONDecodeError:
             return content
 
-        if not isinstance(tools, list):
-            tools = [tools]
-
-        results = []
-        for tool in tools:
-            if "name" not in tool or "arguments" not in tool:
-                return content
-
-            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
-
-        return results
+        tools = [tools] if not isinstance(tools, list) else tools
+        try:
+            return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools]
+        except KeyError:
+            return content
 
 
 class QwenToolUtils(ToolUtils):
@@ -287,13 +274,11 @@ class QwenToolUtils(ToolUtils):
     @override
     @staticmethod
     def function_formatter(functions: list["FunctionCall"]) -> str:
-        function_texts = []
-        for name, arguments in functions:
-            function_texts.append(
-                "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
-            )
-
-        return "\n".join(function_texts)
+        function_texts = [
+            json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
+            for name, arguments in functions
+        ]
+        return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])
 
     @override
     @staticmethod
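A side effect of the Qwen hunk above is that `arguments` is now parsed with `json.loads` before being re-serialized. Two consequences follow in the sketch below: argument whitespace is normalized to `json.dumps` defaults, and malformed argument strings raise at format time instead of being copied into the output. `qwen_format` is a hypothetical free function that mirrors the new method body and takes plain `(name, arguments)` tuples.

```python
import json


def qwen_format(functions: list[tuple[str, str]]) -> str:
    # Serialize each call as one JSON object, then wrap it in <tool_call> tags.
    texts = [
        json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False)
        for name, arguments in functions
    ]
    return "\n".join(f"<tool_call>\n{text}\n</tool_call>" for text in texts)


# Irregular spacing in the input arguments is normalized on output.
out = qwen_format([("tool_name", '{"foo":"bar",   "size":10}')])
assert out == '<tool_call>\n{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}\n</tool_call>'

# Arguments that are not valid JSON now fail loudly while formatting.
try:
    qwen_format([("tool_name", "{not json}")])
    raised = False
except json.JSONDecodeError:
    raised = True
assert raised
```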
@@ -50,7 +50,7 @@ def test_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}", "</s>"], tool_format="default")
     tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
-        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
         "</s>",
     ]
 
@@ -60,7 +60,7 @@ def test_multi_function_formatter():
     tool_calls = json.dumps([FUNCTION] * 2)
     assert formatter.apply(content=tool_calls) == [
         """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n"""
-        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""",
+        """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""",
         "</s>",
     ]
 
@@ -85,7 +85,7 @@ def test_default_tool_formatter():
 
 def test_default_tool_extractor():
     formatter = ToolFormatter(tool_format="default")
-    result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
+    result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}"""
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
 
 
@@ -93,7 +93,7 @@ def test_default_multi_tool_extractor():
     formatter = ToolFormatter(tool_format="default")
     result = (
         """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n"""
-        """Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n"""
+        """Action: another_tool\nAction Input: {"foo": "job", "size": 2}"""
     )
     assert formatter.extract(result) == [
         ("test_tool", """{"foo": "bar", "size": 10}"""),
@@ -125,12 +125,22 @@ def test_glm4_tool_extractor():
 
 def test_llama3_function_formatter():
     formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
-    tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}})
+    tool_calls = json.dumps(FUNCTION)
     assert formatter.apply(content=tool_calls) == [
         """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>"""
     ]
 
 
+def test_llama3_multi_function_formatter():
+    formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3")
+    tool_calls = json.dumps([FUNCTION] * 2)
+    assert formatter.apply(content=tool_calls) == [
+        """[{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}, """
+        """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}]"""
+        """<|eot_id|>"""
+    ]
+
+
 def test_llama3_tool_formatter():
     formatter = ToolFormatter(tool_format="llama3")
     date = datetime.now().strftime("%d %b %Y")
@@ -150,6 +160,18 @@ def test_llama3_tool_extractor():
     assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
 
 
+def test_llama3_multi_tool_extractor():
+    formatter = ToolFormatter(tool_format="llama3")
+    result = (
+        """[{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}, """
+        """{"name": "another_tool", "parameters": {"foo": "job", "size": 2}}]"""
+    )
+    assert formatter.extract(result) == [
+        ("test_tool", """{"foo": "bar", "size": 10}"""),
+        ("another_tool", """{"foo": "job", "size": 2}"""),
+    ]
+
+
 def test_mistral_function_formatter():
     formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", "</s>"], tool_format="mistral")
     tool_calls = json.dumps(FUNCTION)