From 2aadc90c2d78758768e2a24851fce62de0691e5a Mon Sep 17 00:00:00 2001 From: Kingsley Date: Fri, 25 Jul 2025 19:53:45 +0800 Subject: [PATCH] [model] add glm4moe (#8689) --- src/llamafactory/data/template.py | 15 ++++++++ src/llamafactory/data/tool_utils.py | 45 +++++++++++++++++++++++ src/llamafactory/model/model_utils/moe.py | 5 +++ 3 files changed, 65 insertions(+) diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index a07f159d..e49db8b4 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -1014,6 +1014,21 @@ register_template( ) +register_template( + name="glm4_moe", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4_moe"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, + template_class=ReasoningTemplate, +) + + # copied from glm4 template register_template( name="glm4v", diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py index ee0245c4..f90f47c1 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -42,6 +42,18 @@ GLM4_TOOL_PROMPT = ( "你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{tool_text}" ) +GLM4_MOE_TOOL_PROMPT = ( + "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n{tool_text}" + "\n\n\nFor each function call, output the function name and arguments within the following XML format:" + "\n{{function-name}}" + "\n{{arg-key-1}}" + "\n{{arg-value-1}}" + "\n{{arg-key-2}}" + "\n{{arg-value-2}}" + "\n...\n\n" +) + LLAMA3_TOOL_PROMPT = ( "Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n" "You have access to the following functions. To call a function, please respond with JSON for a function call. " @@ -303,12 +315,45 @@ class QwenToolUtils(ToolUtils): return results +class GLM4MOEToolUtils(QwenToolUtils): + r"""GLM-4-MOE tool using template.""" + + @override + @staticmethod + def tool_formatter(tools: list[dict[str, Any]]) -> str: + tool_text = "" + for tool in tools: + wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool} + tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False) + + return GLM4_MOE_TOOL_PROMPT.format(tool_text=tool_text) + + @override + @staticmethod + def function_formatter(functions: list["FunctionCall"]) -> str: + function_json = [ + {"func_name": name, "func_key_values": json.loads(arguments)} for name, arguments in functions + ] + function_texts = [] + for func in function_json: + prompt = "\n" + func["func_name"] + for key, value in func["func_key_values"].items(): + prompt += "\n" + key + "" + if not isinstance(value, str): + value = json.dumps(value, ensure_ascii=False) + prompt += "\n" + value + "" + function_texts.append(prompt) + + return "\n".join(function_texts) + + TOOLS = { "default": DefaultToolUtils(), "glm4": GLM4ToolUtils(), "llama3": Llama3ToolUtils(), "mistral": MistralToolUtils(), "qwen": QwenToolUtils(), + "glm4_moe": GLM4MOEToolUtils(), } diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py index bc517cd6..db7adbee 100644 --- a/src/llamafactory/model/model_utils/moe.py +++ b/src/llamafactory/model/model_utils/moe.py @@ -57,6 +57,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None: _set_z3_leaf_modules(model, [GraniteMoeMoE]) + if model_type == "glm4_moe": + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE + + _set_z3_leaf_modules(model, [Glm4MoeMoE]) + if model_type == "jamba": from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock