mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-19 05:38:56 +08:00
[model] gemma4 (#10346)
This commit is contained in:
@@ -209,6 +209,164 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
return results
|
||||
|
||||
class Gemma4ToolUtils(ToolUtils):
|
||||
r"""Gemma-4 tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
def _format_parameters(properties: dict[str, Any]) -> str:
|
||||
parts: list[str] = []
|
||||
for name, schema in properties.items():
|
||||
item_parts: list[str] = []
|
||||
if schema.get("description"):
|
||||
item_parts.append(f'description:<|"|>{schema["description"]}<|"|>')
|
||||
if schema.get("type"):
|
||||
item_parts.append(f'type:<|"|>{str(schema["type"]).upper()}<|"|>')
|
||||
parts.append(f"{name}:{{{','.join(item_parts)}}}")
|
||||
|
||||
return ",".join(parts)
|
||||
|
||||
declarations: list[str] = []
|
||||
for tool in tools:
|
||||
function_data = tool.get("function", tool) if tool.get("type") == "function" else tool
|
||||
declaration = (
|
||||
f"declaration:{function_data['name']}"
|
||||
+ "{"
|
||||
+ f'description:<|"|>{function_data.get("description", "")}<|"|>'
|
||||
)
|
||||
|
||||
params = function_data.get("parameters")
|
||||
if params:
|
||||
param_parts: list[str] = []
|
||||
if params.get("properties"):
|
||||
param_parts.append(f"properties:{{{_format_parameters(params['properties'])}}}")
|
||||
|
||||
if params.get("required"):
|
||||
required_text = ",".join(f'<|"|>{item}<|"|>' for item in params["required"])
|
||||
param_parts.append(f"required:[{required_text}]")
|
||||
|
||||
if params.get("type"):
|
||||
param_parts.append(f'type:<|"|>{str(params["type"]).upper()}<|"|>')
|
||||
|
||||
declaration += f",parameters:{{{','.join(param_parts)}}}"
|
||||
|
||||
response_declaration = function_data.get("response")
|
||||
if response_declaration:
|
||||
response_parts: list[str] = []
|
||||
if response_declaration.get("description"):
|
||||
response_parts.append(f'description:<|"|>{response_declaration["description"]}<|"|>')
|
||||
|
||||
response_type = str(response_declaration.get("type", "")).upper()
|
||||
|
||||
if response_type == "OBJECT":
|
||||
response_parts.append(f'type:<|"|>{response_type}<|"|>')
|
||||
|
||||
declaration += f",response:{{{','.join(response_parts)}}}"
|
||||
|
||||
declarations.append(declaration + "}")
|
||||
|
||||
return "\n".join(declarations)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
regex = re.compile(r"<\|tool_call\>call:([^{\s]+)\{(.*?)\}<tool_call\|>", re.DOTALL)
|
||||
matches = re.findall(regex, content)
|
||||
if not matches:
|
||||
return content
|
||||
|
||||
def _parse_arguments(arg_text: str) -> Any:
|
||||
text = arg_text.strip()
|
||||
if not text:
|
||||
return {}
|
||||
|
||||
# `function_formatter` writes dict arguments as `k:v,...` inside `{...}`.
|
||||
# The extractor captures only the inner text, so re-wrap it to parse as JSON object.
|
||||
object_like_text = "{" + text + "}"
|
||||
# Convert Gemma string markers (<|"|>value<|"|>) to valid JSON strings.
|
||||
normalized = re.sub(
|
||||
r"<\|\"\|\>(.*?)<\|\"\|\>",
|
||||
lambda m: json.dumps(m.group(1), ensure_ascii=False),
|
||||
object_like_text,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
# Quote unquoted object keys so the payload can be parsed by json.loads.
|
||||
normalized = re.sub(r'(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)', r'\1"\2"\3', normalized)
|
||||
try:
|
||||
return json.loads(normalized)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return text
|
||||
|
||||
results: list[FunctionCall] = []
|
||||
for name, arg_block in matches:
|
||||
parsed_arguments = _parse_arguments(arg_block)
|
||||
if isinstance(parsed_arguments, str):
|
||||
arguments = parsed_arguments
|
||||
else:
|
||||
arguments = json.dumps(parsed_arguments, ensure_ascii=False)
|
||||
results.append(FunctionCall(name.strip(), arguments))
|
||||
|
||||
return results
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
def _format_argument(argument: Any, escape_keys: bool = True) -> str:
|
||||
if isinstance(argument, str):
|
||||
return f'<|"|>{argument}<|"|>'
|
||||
|
||||
if isinstance(argument, bool):
|
||||
return "true" if argument else "false"
|
||||
|
||||
if isinstance(argument, dict):
|
||||
items: list[str] = []
|
||||
for key in sorted(argument.keys()):
|
||||
formatted_key = f'<|"|>{key}<|"|>' if escape_keys else str(key)
|
||||
formatted_value = _format_argument(argument[key], escape_keys=escape_keys)
|
||||
items.append(f"{formatted_key}:{formatted_value}")
|
||||
return "{" + ",".join(items) + "}"
|
||||
|
||||
if isinstance(argument, (list, tuple)):
|
||||
return "[" + ",".join(_format_argument(item, escape_keys=escape_keys) for item in argument) + "]"
|
||||
|
||||
if argument is None:
|
||||
return "null"
|
||||
|
||||
return str(argument)
|
||||
|
||||
function_texts: list[str] = []
|
||||
for function in functions:
|
||||
name = function.name
|
||||
raw_arguments = function.arguments
|
||||
|
||||
try:
|
||||
parsed_arguments = json.loads(raw_arguments)
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
parsed_arguments = raw_arguments
|
||||
|
||||
call_text = f"<|tool_call>call:{name}" + "{"
|
||||
if isinstance(parsed_arguments, dict):
|
||||
args_text = []
|
||||
for key in sorted(parsed_arguments.keys()):
|
||||
value_text = _format_argument(parsed_arguments[key], escape_keys=False)
|
||||
args_text.append(f"{key}:{value_text}")
|
||||
|
||||
call_text += ",".join(args_text)
|
||||
elif isinstance(parsed_arguments, str):
|
||||
call_text += parsed_arguments
|
||||
else:
|
||||
call_text += _format_argument(parsed_arguments, escape_keys=False)
|
||||
|
||||
call_text += "}<tool_call|>"
|
||||
function_texts.append(call_text)
|
||||
|
||||
return "".join(function_texts)
|
||||
|
||||
class GLM4ToolUtils(ToolUtils):
|
||||
r"""GLM-4 tool using template."""
|
||||
@@ -723,6 +881,7 @@ class LFM2ToolUtils(ToolUtils):
|
||||
|
||||
TOOLS = {
|
||||
"default": DefaultToolUtils(),
|
||||
"gemma4": Gemma4ToolUtils(),
|
||||
"glm4": GLM4ToolUtils(),
|
||||
"llama3": Llama3ToolUtils(),
|
||||
"lfm2": LFM2ToolUtils(),
|
||||
|
||||
Reference in New Issue
Block a user