diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 8e8e11e41..dd9d06291 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -1330,6 +1330,26 @@ register_template(
 )
 
 
+register_template(
+    name="lfm",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm"),
+    format_observation=StringFormatter(
+        slots=[
+            "<|im_start|>tool\n<|tool_response_start|>{{content}}<|tool_response_end|><|im_end|>\n"
+            "<|im_start|>assistant\n"
+        ]
+    ),
+    format_tools=ToolFormatter(tool_format="lfm"),
+    default_system="You are a helpful AI assistant.",
+    stop_words=["<|im_end|>"],
+    tool_call_words=("<|tool_call_start|>", "<|tool_call_end|>"),
+    replace_eos=True,
+)
+
+
 register_template(
     name="llama2",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py
index 48132af84..54cc8f5c3 100644
--- a/src/llamafactory/data/tool_utils.py
+++ b/src/llamafactory/data/tool_utils.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import ast
 import json
 import re
 from abc import ABC, abstractmethod
@@ -101,6 +102,8 @@ LING_TOOL_PROMPT = (
     """"arguments": }}\n"""
 )
 
+LFM_TOOL_PROMPT = "List of tools: <|tool_list_start|>{tool_text}<|tool_list_end|>"
+
 
 @dataclass
 class ToolUtils(ABC):
@@ -546,10 +549,115 @@
         return LING_TOOL_PROMPT.format(tool_text=tool_text) + "\n" + "detailed thinking off"
 
 
+class LFMToolUtils(ToolUtils):
+    r"""LFM 2.5 tool using template with Pythonic function call syntax."""
+
+    @override
+    @staticmethod
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
+        tool_list = []
+        for tool in tools:
+            tool = tool.get("function", tool) if tool.get("type") == "function" else tool
+            tool_list.append(tool)
+
+        return LFM_TOOL_PROMPT.format(tool_text=json.dumps(tool_list, ensure_ascii=False))
+
+    @override
+    @staticmethod
+    def function_formatter(functions: list["FunctionCall"]) -> str:
+        calls = []
+        for name, args_json in functions:
+            args = json.loads(args_json)
+            kwargs_parts = []
+            for key, value in args.items():
+                if isinstance(value, str):
+                    kwargs_parts.append(f'{key}="{value}"')
+                else:
+                    kwargs_parts.append(f"{key}={json.dumps(value, ensure_ascii=False)}")
+
+            calls.append(f"{name}({', '.join(kwargs_parts)})")
+
+        return f"<|tool_call_start|>[{', '.join(calls)}]<|tool_call_end|>"
+
+    @staticmethod
+    def _ast_to_value(node: ast.AST) -> Any:
+        """Convert an AST node to a Python value, handling JSON-style booleans/null."""
+        # Handle JSON-style true/false/null as Name nodes
+        if isinstance(node, ast.Name):
+            if node.id == "true":
+                return True
+            elif node.id == "false":
+                return False
+            elif node.id == "null":
+                return None
+            else:
+                raise ValueError(f"Unknown identifier: {node.id}")
+
+        # Use literal_eval for other cases (strings, numbers, lists, dicts)
+        return ast.literal_eval(node)
+
+    @override
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
+        # Extract content between tool call markers
+        start_marker = "<|tool_call_start|>"
+        end_marker = "<|tool_call_end|>"
+
+        start_idx = content.find(start_marker)
+        if start_idx == -1:
+            return content
+
+        end_idx = content.find(end_marker, start_idx)
+        if end_idx == -1:
+            return content
+
+        tool_call_str = content[start_idx + len(start_marker) : end_idx].strip()
+
+        # Parse Pythonic function call syntax using AST
+        try:
+            tree = ast.parse(tool_call_str, mode="eval")
+        except SyntaxError:
+            return content
+
+        # Handle both single call and list of calls
+        if isinstance(tree.body, ast.List):
+            call_nodes = tree.body.elts
+        elif isinstance(tree.body, ast.Call):
+            call_nodes = [tree.body]
+        else:
+            return content
+
+        results = []
+        for node in call_nodes:
+            if not isinstance(node, ast.Call):
+                return content
+
+            # Extract function name
+            if isinstance(node.func, ast.Name):
+                func_name = node.func.id
+            else:
+                return content
+
+            # Extract keyword arguments
+            args_dict = {}
+            for keyword in node.keywords:
+                key = keyword.arg
+                try:
+                    value = LFMToolUtils._ast_to_value(keyword.value)
+                except (ValueError, SyntaxError):
+                    return content
+                args_dict[key] = value
+
+            results.append(FunctionCall(func_name, json.dumps(args_dict, ensure_ascii=False)))
+
+        return results if results else content
+
+
 TOOLS = {
     "default": DefaultToolUtils(),
     "glm4": GLM4ToolUtils(),
     "llama3": Llama3ToolUtils(),
+    "lfm": LFMToolUtils(),
     "minimax1": MiniMaxM1ToolUtils(),
     "minimax2": MiniMaxM2ToolUtils(),
     "mistral": MistralToolUtils(),
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index cf50954e9..8920420ad 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -1493,6 +1493,19 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "LFM2.5-1.2B": {
+            DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Base",
+        },
+        "LFM2.5-1.2B-Instruct": {
+            DownloadSource.DEFAULT: "LiquidAI/LFM2.5-1.2B-Instruct",
+        },
+    },
+    template="lfm",
+)
+
+
 register_model_group(
     models={
         "Llama-7B": {
diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py
index 969b0be32..38266ac1a 100644
--- a/tests/data/test_formatter.py
+++ b/tests/data/test_formatter.py
@@ -292,3 +292,91 @@ def test_qwen_multi_tool_extractor():
         ("test_tool", """{"foo": "bar", "size": 10}"""),
         ("another_tool", """{"foo": "job", "size": 2}"""),
     ]
+
+
+@pytest.mark.runs_on(["cpu", "mps"])
+def test_lfm_function_formatter():
+    formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm")
+    tool_calls = json.dumps(FUNCTION)
+    assert formatter.apply(content=tool_calls) == [
+        """<|tool_call_start|>[tool_name(foo="bar", size=10)]<|tool_call_end|><|im_end|>\n"""
+    ]
+
+
+@pytest.mark.runs_on(["cpu", "mps"])
+def test_lfm_multi_function_formatter():
+    formatter = FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="lfm")
+    tool_calls = json.dumps([FUNCTION] * 2)
+    assert formatter.apply(content=tool_calls) == [
+        """<|tool_call_start|>[tool_name(foo="bar", size=10), tool_name(foo="bar", size=10)]<|tool_call_end|>"""
+        "<|im_end|>\n"
+    ]
+
+
+@pytest.mark.runs_on(["cpu", "mps"])
+def test_lfm_tool_formatter():
+    formatter = ToolFormatter(tool_format="lfm")
+    assert formatter.apply(content=json.dumps(TOOLS)) == [
+        "List of tools: <|tool_list_start|>" + json.dumps(TOOLS, ensure_ascii=False) + "<|tool_list_end|>"
+    ]
+
+
+@pytest.mark.runs_on(["cpu", "mps"])
+def test_lfm_tool_extractor():
+    formatter = ToolFormatter(tool_format="lfm")
+    result = """<|tool_call_start|>[test_tool(foo="bar", size=10)]<|tool_call_end|>"""
+    assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]
+
+
+@pytest.mark.runs_on(["cpu", "mps"])
+def test_lfm_multi_tool_extractor():
+    formatter = ToolFormatter(tool_format="lfm")
+    result = """<|tool_call_start|>[test_tool(foo="bar", size=10), another_tool(foo="job", size=2)]<|tool_call_end|>"""
+    assert formatter.extract(result) == [
+        ("test_tool", """{"foo": "bar", "size": 10}"""),
+        ("another_tool", """{"foo": "job", "size": 2}"""),
+    ]
+
+
+@pytest.mark.runs_on(["cpu", "mps"])
+def test_lfm_tool_extractor_with_nested_dict():
+    formatter = ToolFormatter(tool_format="lfm")
+    result = """<|tool_call_start|>[search(query="test", options={"limit": 10, "offset": 0})]<|tool_call_end|>"""
+    extracted = formatter.extract(result)
+    assert len(extracted) == 1
+    assert extracted[0][0] == "search"
+    args = json.loads(extracted[0][1])
+    assert args["query"] == "test"
+    assert args["options"] == {"limit": 10, "offset": 0}
+
+
+@pytest.mark.runs_on(["cpu", "mps"])
+def test_lfm_tool_extractor_with_list_arg():
+    formatter = ToolFormatter(tool_format="lfm")
+    result = """<|tool_call_start|>[batch_process(items=[1, 2, 3], enabled=True)]<|tool_call_end|>"""
+    extracted = formatter.extract(result)
+    assert len(extracted) == 1
+    assert extracted[0][0] == "batch_process"
+    args = json.loads(extracted[0][1])
+    assert args["items"] == [1, 2, 3]
+    assert args["enabled"] is True
+
+
+@pytest.mark.runs_on(["cpu", "mps"])
+def test_lfm_tool_extractor_no_match():
+    formatter = ToolFormatter(tool_format="lfm")
+    result = "This is a regular response without tool calls."
+    extracted = formatter.extract(result)
+    assert extracted == result
+
+
+@pytest.mark.runs_on(["cpu", "mps"])
+def test_lfm_tool_round_trip():
+    formatter = FunctionFormatter(slots=["{{content}}"], tool_format="lfm")
+    tool_formatter = ToolFormatter(tool_format="lfm")
+    original = {"name": "my_func", "arguments": {"arg1": "hello", "arg2": 42, "arg3": True}}
+    formatted = formatter.apply(content=json.dumps(original))
+    extracted = tool_formatter.extract(formatted[0])
+    assert len(extracted) == 1
+    assert extracted[0][0] == original["name"]
+    assert json.loads(extracted[0][1]) == original["arguments"]