Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2026-01-12 17:10:36 +08:00).
Commit: [model] support for LiquidAI's LFM2.5 (Liquid Foundation Models) (#9726).
This commit is contained in:
@@ -1330,6 +1330,26 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
# LFM 2.5 (LiquidAI) chat template: ChatML-style turn markers
# (<|im_start|>/<|im_end|>) plus dedicated tool-call/response tokens.
register_template(
    name="lfm",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm"),
    format_observation=StringFormatter(
        slots=[
            # tool output is wrapped in tool_response markers, then the
            # assistant turn is reopened so generation can continue
            "<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
            "<|im_start|>assistant\n"
        ]
    ),
    format_tools=ToolFormatter(tool_format="lfm"),
    default_system="You are a helpful AI assistant.",
    stop_words=["<|im_end|>"],
    tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
    replace_eos=True,
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="llama2",
|
||||
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -101,6 +102,8 @@ LING_TOOL_PROMPT = (
|
||||
""""arguments": <args-json-object>}}\n</tool_call>"""
|
||||
)
|
||||
|
||||
# System-prompt fragment advertising the available tools to LFM 2.5 models;
# {tool_text} is filled with the JSON-serialized tool list.
LFM_TOOL_PROMPT = "List of tools: <|tool_list_start|>{tool_text}<|tool_list_end|>"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolUtils(ABC):
|
||||
@@ -546,10 +549,115 @@ class LingToolUtils(QwenToolUtils):
|
||||
return LING_TOOL_PROMPT.format(tool_text=tool_text) + "\n" + "detailed thinking off"
|
||||
|
||||
|
||||
class LFMToolUtils(ToolUtils):
    r"""LFM 2.5 tool using template with Pythonic function call syntax.

    Tool calls are rendered as ``<|tool_call_start|>[func(a=1, b="x")]<|tool_call_end|>``
    and parsed back with the ``ast`` module, accepting both Python literals
    (``True``/``None``) and JSON-style ones (``true``/``null``).
    """

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        r"""Serialize the available tools as a JSON list between the tool-list markers."""
        tool_list = []
        for tool in tools:
            # unwrap the OpenAI-style {"type": "function", "function": {...}} envelope
            tool = tool.get("function", tool) if tool.get("type") == "function" else tool
            tool_list.append(tool)

        return LFM_TOOL_PROMPT.format(tool_text=json.dumps(tool_list, ensure_ascii=False))

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        r"""Render function calls as Pythonic call expressions inside tool-call markers."""
        calls = []
        for name, args_json in functions:
            args = json.loads(args_json)
            # Use json.dumps for EVERY value (strings included) so quotes and
            # backslashes are escaped correctly; a raw f'"{value}"' would emit
            # unparseable output for strings containing a double quote. For
            # plain strings the rendered text is unchanged ('"bar"'), and
            # booleans/None render JSON-style (true/false/null), which
            # `_ast_to_value` maps back during extraction.
            kwargs_parts = [f"{key}={json.dumps(value, ensure_ascii=False)}" for key, value in args.items()]
            calls.append(f"{name}({', '.join(kwargs_parts)})")

        return f"<|tool_call_start|>[{', '.join(calls)}]<|tool_call_end|>"

    @staticmethod
    def _ast_to_value(node: ast.AST) -> Any:
        """Convert an AST node to a Python value, handling JSON-style booleans/null.

        Raises:
            ValueError: if the node contains an unknown identifier or a
                non-literal construct.
        """
        # Handle JSON-style true/false/null as Name nodes
        if isinstance(node, ast.Name):
            if node.id == "true":
                return True
            elif node.id == "false":
                return False
            elif node.id == "null":
                return None
            else:
                raise ValueError(f"Unknown identifier: {node.id}")

        # Recurse into containers so JSON-style names nested inside dicts or
        # lists (e.g. {"enabled": true}) are converted too; a bare
        # ast.literal_eval would reject them with ValueError.
        if isinstance(node, ast.Dict):
            result = {}
            for key_node, value_node in zip(node.keys, node.values):
                if key_node is None:  # dict unpacking ({**x}) is not a literal
                    raise ValueError("Dict unpacking is not supported.")

                result[LFMToolUtils._ast_to_value(key_node)] = LFMToolUtils._ast_to_value(value_node)

            return result

        if isinstance(node, (ast.List, ast.Tuple)):
            values = [LFMToolUtils._ast_to_value(elt) for elt in node.elts]
            return values if isinstance(node, ast.List) else tuple(values)

        # Use literal_eval for other cases (strings, numbers, negative numbers)
        return ast.literal_eval(node)

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        r"""Parse tool calls out of model output; return the raw content when parsing fails."""
        # Extract content between tool call markers
        start_marker = "<|tool_call_start|>"
        end_marker = "<|tool_call_end|>"

        start_idx = content.find(start_marker)
        if start_idx == -1:
            return content

        end_idx = content.find(end_marker, start_idx)
        if end_idx == -1:
            return content

        tool_call_str = content[start_idx + len(start_marker) : end_idx].strip()

        # Parse Pythonic function call syntax using AST
        try:
            tree = ast.parse(tool_call_str, mode="eval")
        except SyntaxError:
            return content

        # Handle both single call and list of calls
        if isinstance(tree.body, ast.List):
            call_nodes = tree.body.elts
        elif isinstance(tree.body, ast.Call):
            call_nodes = [tree.body]
        else:
            return content

        results = []
        for node in call_nodes:
            if not isinstance(node, ast.Call):
                return content

            # Extract function name
            if isinstance(node.func, ast.Name):
                func_name = node.func.id
            else:
                return content

            # Extract keyword arguments
            args_dict = {}
            for keyword in node.keywords:
                if keyword.arg is None:  # a `**kwargs` splat cannot be represented
                    return content

                try:
                    args_dict[keyword.arg] = LFMToolUtils._ast_to_value(keyword.value)
                except (ValueError, SyntaxError):
                    return content

            results.append(FunctionCall(func_name, json.dumps(args_dict, ensure_ascii=False)))

        return results if results else content
|
||||
|
||||
|
||||
TOOLS = {
|
||||
"default": DefaultToolUtils(),
|
||||
"glm4": GLM4ToolUtils(),
|
||||
"llama3": Llama3ToolUtils(),
|
||||
"lfm": LFMToolUtils(),
|
||||
"minimax1": MiniMaxM1ToolUtils(),
|
||||
"minimax2": MiniMaxM2ToolUtils(),
|
||||
"mistral": MistralToolUtils(),
|
||||
|
||||
@@ -1493,6 +1493,19 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
# LiquidAI LFM 2.5 models (base and instruction-tuned variants), both using
# the "lfm" chat template registered in template.py.
register_model_group(
    models={
        "LFM2.5-1.2B": {
            DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Base",
        },
        "LFM2.5-1.2B-Instruct": {
            DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Instruct",
        },
    },
    template="lfm",
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Llama-7B": {
|
||||
|
||||
@@ -292,3 +292,91 @@ def test_qwen_multi_tool_extractor():
|
||||
("test_tool", """{"foo": "bar", "size": 10}"""),
|
||||
("another_tool", """{"foo": "job", "size": 2}"""),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_function_formatter():
    """A single function call is rendered as a Pythonic call inside tool-call markers."""
    func_formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm")
    rendered = func_formatter.apply(content=json.dumps(FUNCTION))
    expected = """<|tool_call_start|>[tool_name(foo="bar", size=10)]<|tool_call_end|><|im_end|>\n"""
    assert rendered == [expected]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_multi_function_formatter():
    """Two calls are rendered side by side within a single tool-call span."""
    func_formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm")
    rendered = func_formatter.apply(content=json.dumps([FUNCTION, FUNCTION]))
    call_text = """tool_name(foo="bar", size=10)"""
    assert rendered == [f"<|tool_call_start|>[{call_text}, {call_text}]<|tool_call_end|><|im_end|>\n"]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_tool_formatter():
    """The tool list is serialized as JSON between the tool-list markers."""
    tool_formatter = ToolFormatter(tool_format="lfm")
    tool_json = json.dumps(TOOLS, ensure_ascii=False)
    assert tool_formatter.apply(content=json.dumps(TOOLS)) == [
        f"List of tools: <|tool_list_start|>{tool_json}<|tool_list_end|>"
    ]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_tool_extractor():
    """A single Pythonic call is parsed back into a (name, json-args) pair."""
    tool_formatter = ToolFormatter(tool_format="lfm")
    model_output = """<|tool_call_start|>[test_tool(foo="bar", size=10)]<|tool_call_end|>"""
    assert tool_formatter.extract(model_output) == [("test_tool", """{"foo": "bar", "size": 10}""")]
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_multi_tool_extractor():
    """Multiple calls in one span are extracted in order."""
    tool_formatter = ToolFormatter(tool_format="lfm")
    model_output = (
        """<|tool_call_start|>[test_tool(foo="bar", size=10), another_tool(foo="job", size=2)]<|tool_call_end|>"""
    )
    expected = [
        ("test_tool", """{"foo": "bar", "size": 10}"""),
        ("another_tool", """{"foo": "job", "size": 2}"""),
    ]
    assert tool_formatter.extract(model_output) == expected
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_tool_extractor_with_nested_dict():
    """Nested dict arguments survive extraction."""
    tool_formatter = ToolFormatter(tool_format="lfm")
    model_output = """<|tool_call_start|>[search(query="test", options={"limit": 10, "offset": 0})]<|tool_call_end|>"""
    calls = tool_formatter.extract(model_output)
    assert len(calls) == 1
    name, args_json = calls[0]
    assert name == "search"
    parsed_args = json.loads(args_json)
    assert parsed_args["query"] == "test"
    assert parsed_args["options"] == {"limit": 10, "offset": 0}
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_tool_extractor_with_list_arg():
    """List arguments and Python-style booleans survive extraction."""
    tool_formatter = ToolFormatter(tool_format="lfm")
    model_output = """<|tool_call_start|>[batch_process(items=[1, 2, 3], enabled=True)]<|tool_call_end|>"""
    calls = tool_formatter.extract(model_output)
    assert len(calls) == 1
    name, args_json = calls[0]
    assert name == "batch_process"
    parsed_args = json.loads(args_json)
    assert parsed_args["items"] == [1, 2, 3]
    assert parsed_args["enabled"] is True
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_tool_extractor_no_match():
    """Plain text without tool-call markers is returned unchanged."""
    tool_formatter = ToolFormatter(tool_format="lfm")
    plain_text = "This is a regular response without tool calls."
    assert tool_formatter.extract(plain_text) == plain_text
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
def test_lfm_tool_round_trip():
    """format -> extract recovers the original function name and arguments."""
    func_formatter = FunctionFormatter(slots=["{{content}}"], tool_format="lfm")
    tool_formatter = ToolFormatter(tool_format="lfm")
    call = {"name": "my_func", "arguments": {"arg1": "hello", "arg2": 42, "arg3": True}}
    rendered = func_formatter.apply(content=json.dumps(call))
    recovered = tool_formatter.extract(rendered[0])
    assert len(recovered) == 1
    assert recovered[0][0] == call["name"]
    assert json.loads(recovered[0][1]) == call["arguments"]
|
||||
|
||||
Reference in New Issue
Block a user