From 56926d76f94292f116fae2ea437e9aeeede380d8 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 21 May 2025 02:01:12 +0800 Subject: [PATCH] [data] llama3 multi tool support (#8124) --- README.md | 7 ++-- README_zh.md | 9 ++--- src/llamafactory/data/tool_utils.py | 57 +++++++++++------------------ tests/data/test_formatter.py | 32 +++++++++++++--- 4 files changed, 55 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 214b72b8..c02b7750 100644 --- a/README.md +++ b/README.md @@ -53,9 +53,7 @@ Choose your path: - **Documentation**: https://llamafactory.readthedocs.io/en/latest/ - **Colab (free)**: https://colab.research.google.com/drive/1eRTPn37ltBbYsISy9Aw2NuI2Aq5CQrD9?usp=sharing - **Local machine**: Please refer to [usage](#getting-started) -- **PAI-DSW (free trial)**: [Llama3 Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) | [DeepSeek-R1-Distill Example](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) -- **Amazon SageMaker**: [Blog](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) -- **Easy Dataset**: [Fine-tune on Synthetic Data](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) +- **PAI-DSW (free trial)**: https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory > [!NOTE] > Except for the above links, all other websites are unauthorized third-party websites. Please carefully use them. @@ -107,7 +105,7 @@ Choose your path: - [How Apoidea Group enhances visual information extraction from banking documents with multimodal models using LLaMA-Factory on Amazon SageMaker HyperPod](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/) (English) - [Easy Dataset × LLaMA Factory: Enabling LLMs to Efficiently Learn Domain Knowledge](https://buaa-act.feishu.cn/wiki/GVzlwYcRFiR8OLkHbL6cQpYin7g) (English) -- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model For News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese) +- [LLaMA Factory: Fine-tuning the DeepSeek-R1-Distill-Qwen-7B Model for News Classifier](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) (Chinese)
All Blogs @@ -918,6 +916,7 @@ If you have a project that should be incorporated, please contact via email or c 1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357) 1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**: A modified library that supports long sequence SFT & DPO using ring attention. 1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**: An o1-like model fine-tuned by NovaSky AI with very small cost. +1. **[WeClone](https://github.com/xming521/WeClone)**: One-stop solution for creating your digital avatar from chat logs.
diff --git a/README_zh.md b/README_zh.md index 342cdf84..f9029948 100644 --- a/README_zh.md +++ b/README_zh.md @@ -55,9 +55,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc - **框架文档(昇腾 NPU)**:https://ascend.github.io/docs/sources/llamafactory/ - **Colab(免费)**:https://colab.research.google.com/drive/1d5KQtbemerlSDSxZIfAaWXhKr30QypiK?usp=sharing - **本地机器**:请见[如何使用](#如何使用) -- **PAI-DSW(免费试用)**:[Llama3 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory) | [Qwen2-VL 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl) | [DeepSeek-R1-Distill 案例](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b) -- **Amazon SageMaker**:[博客](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/) -- **Easy Dataset**:[数据蒸馏微调](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9) +- **PAI-DSW(免费试用)**:https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory > [!NOTE] > 除上述链接以外的其他网站均为未经许可的第三方网站,请小心甄别。 @@ -109,12 +107,12 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc - [通过亚马逊 SageMaker HyperPod 上的 LLaMA-Factory 增强多模态模型银行文档的视觉信息提取](https://aws.amazon.com/cn/blogs/machine-learning/how-apoidea-group-enhances-visual-information-extraction-from-banking-documents-with-multimodal-models-using-llama-factory-on-amazon-sagemaker-hyperpod/)(英文) - [Easy Dataset × LLaMA Factory: 让大模型高效学习领域知识](https://buaa-act.feishu.cn/wiki/KY9xwTGs1iqHrRkjXBwcZP9WnL9)(中文) -- [LLaMA Factory:微调DeepSeek-R1-Distill-Qwen-7B模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)(中文) +- [LLaMA Factory:微调 DeepSeek-R1-Distill-Qwen-7B 模型实现新闻标题分类器](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_deepseek_r1_distill_7b)(中文)
全部博客 - [基于 Amazon SageMaker 和 LLaMA-Factory 打造一站式无代码模型微调部署平台 Model Hub](https://aws.amazon.com/cn/blogs/china/a-one-stop-code-free-model-fine-tuning-deployment-platform-based-on-sagemaker-and-llama-factory/)(中文) -- [LLaMA Factory多模态微调实践:微调Qwen2-VL构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)(中文) +- [LLaMA Factory 多模态微调实践:微调 Qwen2-VL 构建文旅大模型](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory_qwen2vl)(中文) - [LLaMA Factory:微调LLaMA3模型实现角色扮演](https://gallery.pai-ml.com/#/preview/deepLearning/nlp/llama_factory)(中文)
@@ -920,6 +918,7 @@ swanlab_run_name: test_run # 可选 1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**:一个全链路 RAG 检索模型微调、推理和蒸馏代码库。[[blog]](https://zhuanlan.zhihu.com/p/987727357) 1. **[360-LLaMA-Factory](https://github.com/Qihoo360/360-LLaMA-Factory)**:一个魔改后的代码库,通过 Ring Attention 支持长序列的 SFT 和 DPO 训练。 1. **[Sky-T1](https://novasky-ai.github.io/posts/sky-t1/)**:由 NovaSky AI 微调的低成本类 o1 长推理模型。 +1. **[WeClone](https://github.com/xming521/WeClone)**:从聊天记录创造数字分身的一站式解决方案。 diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py index b2f2798b..37856151 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -125,11 +125,7 @@ class DefaultToolUtils(ToolUtils): @override @staticmethod def function_formatter(functions: list["FunctionCall"]) -> str: - function_text = "" - for name, arguments in functions: - function_text += f"Action: {name}\nAction Input: {arguments}\n" - - return function_text + return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions]) @override @staticmethod @@ -210,24 +206,23 @@ class Llama3ToolUtils(ToolUtils): @override @staticmethod def function_formatter(functions: list["FunctionCall"]) -> str: - if len(functions) > 1: - raise ValueError("Llama-3 does not support parallel functions.") - - return f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}' + function_objects = [{"name": name, "parameters": json.loads(arguments)} for name, arguments in functions] + return json.dumps(function_objects[0] if len(function_objects) == 1 else function_objects, ensure_ascii=False) @override @staticmethod def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]: try: - tool = json.loads(content.strip()) + tools = json.loads(content.strip()) except json.JSONDecodeError: return content - if "name" not in tool or "parameters" not in tool: + tools = [tools] if not isinstance(tools, list) else tools + try: + return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False)) for tool in tools] + except KeyError: return content - return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))] - class MistralToolUtils(ToolUtils): r"""Mistral v0.3 tool using template.""" @@ -244,11 +239,9 @@ class MistralToolUtils(ToolUtils): @override @staticmethod def function_formatter(functions: list["FunctionCall"]) -> str: - function_texts = [] - for name, arguments in functions: - function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}') - - return "[" + ", ".join(function_texts) + "]" + return json.dumps( + [{"name": name, "arguments": json.loads(arguments)} for name, arguments in functions], ensure_ascii=False + ) @override @staticmethod @@ -258,17 +251,11 @@ class MistralToolUtils(ToolUtils): except json.JSONDecodeError: return content - if not isinstance(tools, list): - tools = [tools] - - results = [] - for tool in tools: - if "name" not in tool or "arguments" not in tool: - return content - - results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False))) - - return results + tools = [tools] if not isinstance(tools, list) else tools + try: + return [FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)) for tool in tools] + except KeyError: + return content class QwenToolUtils(ToolUtils): @@ -287,13 +274,11 @@ class QwenToolUtils(ToolUtils): @override @staticmethod def function_formatter(functions: list["FunctionCall"]) -> str: - function_texts = [] - for name, arguments in functions: - function_texts.append( - "\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n" - ) - - return "\n".join(function_texts) + function_texts = [ + json.dumps({"name": name, "arguments": json.loads(arguments)}, ensure_ascii=False) + for name, arguments in functions + ] + return "\n".join([f"\n{text}\n" for text in function_texts]) @override @staticmethod diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py index 7e08465a..00b8c649 100644 --- a/tests/data/test_formatter.py +++ b/tests/data/test_formatter.py @@ -50,7 +50,7 @@ def test_function_formatter(): formatter = FunctionFormatter(slots=["{{content}}", ""], tool_format="default") tool_calls = json.dumps(FUNCTION) assert formatter.apply(content=tool_calls) == [ - """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""", + """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""", "", ] @@ -60,7 +60,7 @@ def test_multi_function_formatter(): tool_calls = json.dumps([FUNCTION] * 2) assert formatter.apply(content=tool_calls) == [ """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""" - """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}\n""", + """Action: tool_name\nAction Input: {"foo": "bar", "size": 10}""", "", ] @@ -85,7 +85,7 @@ def test_default_tool_formatter(): def test_default_tool_extractor(): formatter = ToolFormatter(tool_format="default") - result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n""" + result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}""" assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] @@ -93,7 +93,7 @@ def test_default_multi_tool_extractor(): formatter = ToolFormatter(tool_format="default") result = ( """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n""" - """Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n""" + """Action: another_tool\nAction Input: {"foo": "job", "size": 2}""" ) assert formatter.extract(result) == [ ("test_tool", """{"foo": "bar", "size": 10}"""), @@ -125,12 +125,22 @@ def test_glm4_tool_extractor(): def test_llama3_function_formatter(): formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3") - tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}) + tool_calls = json.dumps(FUNCTION) assert formatter.apply(content=tool_calls) == [ """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}<|eot_id|>""" ] +def test_llama3_multi_function_formatter(): + formatter = FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3") + tool_calls = json.dumps([FUNCTION] * 2) + assert formatter.apply(content=tool_calls) == [ + """[{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}, """ + """{"name": "tool_name", "parameters": {"foo": "bar", "size": 10}}]""" + """<|eot_id|>""" + ] + + def test_llama3_tool_formatter(): formatter = ToolFormatter(tool_format="llama3") date = datetime.now().strftime("%d %b %Y") @@ -150,6 +160,18 @@ def test_llama3_tool_extractor(): assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] +def test_llama3_multi_tool_extractor(): + formatter = ToolFormatter(tool_format="llama3") + result = ( + """[{"name": "test_tool", "parameters": {"foo": "bar", "size": 10}}, """ + """{"name": "another_tool", "parameters": {"foo": "job", "size": 2}}]""" + ) + assert formatter.extract(result) == [ + ("test_tool", """{"foo": "bar", "size": 10}"""), + ("another_tool", """{"foo": "job", "size": 2}"""), + ] + + def test_mistral_function_formatter(): formatter = FunctionFormatter(slots=["[TOOL_CALLS] {{content}}", ""], tool_format="mistral") tool_calls = json.dumps(FUNCTION)