mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-07-30 18:22:51 +08:00
[model] add glm4moe (#8689)
This commit is contained in:
parent
2353e16e20
commit
2aadc90c2d
@ -1014,6 +1014,21 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="glm4_moe",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
|
||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
||||
format_tools=ToolFormatter(tool_format="glm4_moe"),
|
||||
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
efficient_eos=True,
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
# copied from glm4 template
|
||||
register_template(
|
||||
name="glm4v",
|
||||
|
@ -42,6 +42,18 @@ GLM4_TOOL_PROMPT = (
|
||||
"你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{tool_text}"
|
||||
)
|
||||
|
||||
GLM4_MOE_TOOL_PROMPT = (
|
||||
"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
|
||||
"\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:"
|
||||
"\n<tool_call>{{function-name}}"
|
||||
"\n<arg_key>{{arg-key-1}}</arg_key>"
|
||||
"\n<arg_value>{{arg-value-1}}</arg_value>"
|
||||
"\n<arg_key>{{arg-key-2}}</arg_key>"
|
||||
"\n<arg_value>{{arg-value-2}}</arg_value>"
|
||||
"\n...\n</tool_call>\n"
|
||||
)
|
||||
|
||||
LLAMA3_TOOL_PROMPT = (
|
||||
"Cutting Knowledge Date: December 2023\nToday Date: {date}\n\n"
|
||||
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
|
||||
@ -303,12 +315,45 @@ class QwenToolUtils(ToolUtils):
|
||||
return results
|
||||
|
||||
|
||||
class GLM4MOEToolUtils(QwenToolUtils):
|
||||
r"""GLM-4-MOE tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
|
||||
tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
|
||||
|
||||
return GLM4_MOE_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
function_json = [
|
||||
{"func_name": name, "func_key_values": json.loads(arguments)} for name, arguments in functions
|
||||
]
|
||||
function_texts = []
|
||||
for func in function_json:
|
||||
prompt = "\n<tool_call>" + func["func_name"]
|
||||
for key, value in func["func_key_values"].items():
|
||||
prompt += "\n<arg_key>" + key + "</arg_key>"
|
||||
if not isinstance(value, str):
|
||||
value = json.dumps(value, ensure_ascii=False)
|
||||
prompt += "\n<arg_value>" + value + "</arg_value>"
|
||||
function_texts.append(prompt)
|
||||
|
||||
return "\n".join(function_texts)
|
||||
|
||||
|
||||
TOOLS = {
|
||||
"default": DefaultToolUtils(),
|
||||
"glm4": GLM4ToolUtils(),
|
||||
"llama3": Llama3ToolUtils(),
|
||||
"mistral": MistralToolUtils(),
|
||||
"qwen": QwenToolUtils(),
|
||||
"glm4_moe": GLM4MOEToolUtils(),
|
||||
}
|
||||
|
||||
|
||||
|
@ -57,6 +57,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
||||
|
||||
_set_z3_leaf_modules(model, [GraniteMoeMoE])
|
||||
|
||||
if model_type == "glm4_moe":
|
||||
from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE
|
||||
|
||||
_set_z3_leaf_modules(model, [Glm4MoeMoE])
|
||||
|
||||
if model_type == "jamba":
|
||||
from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user