From c107cc22d00ecea740eb09e3765f87e2ef65eb4f Mon Sep 17 00:00:00 2001 From: Hertz <2267379130@qq.com> Date: Sun, 28 Dec 2025 19:02:05 +0800 Subject: [PATCH] [model] support MiniMax-M1&M2 series (#9680) Co-authored-by: Yaowei Zheng --- README.md | 1 + README_zh.md | 1 + src/llamafactory/data/template.py | 33 ++++++++ src/llamafactory/data/tool_utils.py | 122 +++++++++++++++++++++++++++ src/llamafactory/extras/constants.py | 34 ++++++++ 5 files changed, 191 insertions(+) diff --git a/README.md b/README.md index 865eb43ea..93f0c0211 100644 --- a/README.md +++ b/README.md @@ -309,6 +309,7 @@ Read technical notes: | [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 | | [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 | | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v | +| [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 | | [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [OLMo](https://huggingface.co/allenai) | 1B/7B | - | diff --git a/README_zh.md b/README_zh.md index 5521f440a..b1027f3fc 100644 --- a/README_zh.md +++ b/README_zh.md @@ -311,6 +311,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc | [MiMo](https://huggingface.co/XiaomiMiMo) | 7B/309B | mimo/mimo_v2 | | [MiniCPM 1-4.1](https://huggingface.co/openbmb) | 0.5B/1B/2B/4B/8B | cpm/cpm3/cpm4 | | [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_o/minicpm_v | +| [MiniMax-M1/MiniMax-M2](https://huggingface.co/MiniMaxAI/models) | 229B/456B | minimax1/minimax2 | | [Ministral 3](https://huggingface.co/mistralai) | 3B/8B/14B | ministral3 | | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral | | [OLMo](https://huggingface.co/allenai) | 1B/7B | - | diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index db9301063..a986388d6 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -1673,6 +1673,39 @@ register_template( ) +register_template( + name="minimax1", + format_user=StringFormatter( + slots=["user name=user\n{{content}}\nai name=assistant\n"] + ), + format_assistant=StringFormatter(slots=["{{content}}\n"]), + format_system=StringFormatter( + slots=["system ai_setting=assistant\n{{content}}\n"] + ), + format_function=FunctionFormatter(slots=["{{content}}\n"], tool_format="minimax1"), + format_observation=StringFormatter( + slots=["tool name=tools\n{{content}}\nai name=assistant\n"] + ), + format_tools=ToolFormatter(tool_format="minimax1"), + default_system="You are a helpful assistant.", + stop_words=[""], +) + + +register_template( + name="minimax2", + format_user=StringFormatter(slots=["]~b]user\n{{content}}[e~[\n]~b]ai\n"]), + format_assistant=StringFormatter(slots=["{{content}}[e~[\n"]), + format_system=StringFormatter(slots=["]~!b[]~b]system\n{{content}}[e~[\n"]), + format_function=FunctionFormatter(slots=["{{content}}[e~[\n"], tool_format="minimax2"), + format_observation=StringFormatter(slots=["]~b]tool\n{{content}}[e~[\n]~b]ai\n"]), + format_tools=ToolFormatter(tool_format="minimax2"), + default_system="You are a helpful assistant. Your name is MiniMax-M2.1 and is built by MiniMax.", + stop_words=["[e~["], + template_class=ReasoningTemplate, +) + + # mistral tokenizer v3 tekken register_template( name="ministral", diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py index 2f677f1eb..95fb0ac56 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -61,6 +61,21 @@ LLAMA3_TOOL_PROMPT = ( "Do not use variables.\n\n{tool_text}" ) +MINIMAX_M1_TOOL_PROMPT = ( + "You are provided with these tools:\n\n{tool_text}\n\n" + "If you need to call tools, please respond with XML tags, and provide tool-name and " + "json-object of arguments, following the format below:\n\n" + "{{\"name\": , \"arguments\": }}\n...\n" +) + +MINIMAX_M2_TOOL_PROMPT = ( + "\n\n# Tools\n\nYou may call one or more tools to assist with the user query.\n" + "Here are the tools available in JSONSchema format:\n\n\n{tool_text}\n\n" + "When making tool calls, use XML format to invoke tools and pass parameters:\n" + "\n\n\nparam-value-1\n" + "param-value-2\n...\n\n" +) + QWEN_TOOL_PROMPT = ( "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n" "You are provided with function signatures within XML tags:\n{tool_text}" @@ -253,6 +268,111 @@ class Llama3ToolUtils(ToolUtils): return content +class MiniMaxM1ToolUtils(ToolUtils): + r"""MiniMax-M1 tool using template.""" + + @override + @staticmethod + def tool_formatter(tools: list[dict[str, Any]]) -> str: + tool_text = "" + for tool in tools: + tool = tool.get("function", "") if tool.get("type") == "function" else tool + tool_text += json.dumps(tool, ensure_ascii=False) + "\n" + + return MINIMAX_M1_TOOL_PROMPT.format(tool_text=tool_text) + + @override + @staticmethod + def function_formatter(functions: list["FunctionCall"]) -> str: + function_texts = [] + for func in functions: + name, arguments = func.name, json.loads(func.arguments) + function_texts.append(json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)) + + return "\n" + "\n".join(function_texts) + "\n" + + @override + @staticmethod + def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]: + regex = re.compile(r"\s*(.+?)\s*", re.DOTALL) + tool_match = re.search(regex, content) + if not tool_match: + return content + + tool_calls_content = tool_match.group(1) + results = [] + for line in tool_calls_content.split("\n"): + line = line.strip() + if not line: + continue + + try: + tool_call = json.loads(line) + results.append(FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))) + except json.JSONDecodeError: + continue + + return results + + +class MiniMaxM2ToolUtils(ToolUtils): + r"""MiniMax-M2 tool using template.""" + + @override + @staticmethod + def tool_formatter(tools: list[dict[str, Any]]) -> str: + tool_text = "" + for tool in tools: + tool = tool.get("function", "") if tool.get("type") == "function" else tool + tool_text += "" + json.dumps(tool, ensure_ascii=False) + "\n" + + return MINIMAX_M2_TOOL_PROMPT.format(tool_text=tool_text) + + @override + @staticmethod + def function_formatter(functions: list["FunctionCall"]) -> str: + function_texts = [] + for func in functions: + name, arguments = func.name, json.loads(func.arguments) + prompt = "" + for key, value in arguments.items(): + prompt += "\n" + if not isinstance(value, str): + value = json.dumps(value, ensure_ascii=False) + prompt += value + "" + prompt += "\n" + function_texts.append(prompt) + + @override + @staticmethod + def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]: + regex = re.compile( + r"\s*(.+?)\s*", re.DOTALL + ) + tool_match = re.search(regex, content) + if not tool_match: + return content + + tool_calls_content = tool_match.group(1) + invoke_regex = re.compile(r"(.*?)", re.DOTALL) + results = [] + + for func_name, params_block in re.findall(invoke_regex, tool_calls_content): + args_dict = {} + param_pattern = re.compile(r"(.*?)", re.DOTALL) + for key, raw_value in re.findall(param_pattern, params_block): + value = raw_value.strip() + try: + parsed_value = json.loads(value) + except json.JSONDecodeError: + parsed_value = raw_value + args_dict[key] = parsed_value + + results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False))) + + return results + + class MistralToolUtils(ToolUtils): r"""Mistral v0.3 tool using template.""" @@ -432,6 +552,8 @@ TOOLS = { "default": DefaultToolUtils(), "glm4": GLM4ToolUtils(), "llama3": Llama3ToolUtils(), + "minimax1": MiniMaxM1ToolUtils(), + "minimax2": MiniMaxM2ToolUtils(), "mistral": MistralToolUtils(), "qwen": QwenToolUtils(), "glm4_moe": GLM4MOEToolUtils(), diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index eb053150c..e78c8b908 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -1071,6 +1071,40 @@ register_model_group( ) +register_model_group( + models={ + "MiniMax-Text-01-Instruct": { + DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-Text-01-hf", + DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-Text-01", + }, + "MiniMax-M1-40k-Thinking": { + DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M1-40k-hf", + DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M1-40k-hf", + }, + "MiniMax-M1-80k-Thinking": { + DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M1-80k-hf", + DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M1-80k-hf", + }, + }, + template="minimax1", +) + + +register_model_group( + models={ + "MiniMax-M2-Thinking": { + DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M2", + DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M2", + }, + "MiniMax-M2.1-Thinking": { + DownloadSource.DEFAULT: "MiniMaxAI/MiniMax-M2.1", + DownloadSource.MODELSCOPE: "MiniMaxAI/MiniMax-M2.1", + }, + }, + template="minimax2", +) + + register_model_group( models={ "Granite-3.0-1B-A400M-Base": {