From fdd24276edcdbf3fa0cafad56ede25e98f9dd022 Mon Sep 17 00:00:00 2001 From: Xunpeng Xiao <124695565+tangefly@users.noreply.github.com> Date: Sun, 14 Dec 2025 00:20:33 +0800 Subject: [PATCH] [feat] support new function call value (#9610) Co-authored-by: Yaowei Zheng --- src/llamafactory/data/formatter.py | 55 +++++++++++++++++++----------- src/llamafactory/data/template.py | 8 ++++- 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index 9b527a7e..d13bb858 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -97,31 +97,46 @@ class FunctionFormatter(StringFormatter): @override def apply(self, **kwargs) -> SLOTS: content: str = kwargs.pop("content") - thought_words, thought = kwargs.pop("thought_words", None), None - if thought_words and len(thought_words) == 2: - regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL) - thought = re.search(regex, content) + thought_words = kwargs.pop("thought_words", None) + tool_call_words = kwargs.pop("tool_call_words", None) - if thought: - content = content.replace(thought.group(0), "") + def _parse_functions(json_content: str) -> list["FunctionCall"]: + try: + tool_calls = json.loads(json_content) + if not isinstance(tool_calls, list): # parallel function call + tool_calls = [tool_calls] - functions: list[FunctionCall] = [] - try: - tool_calls = json.loads(content) - if not isinstance(tool_calls, list): # parallel function call - tool_calls = [tool_calls] + return [FunctionCall(tc["name"], json.dumps(tc["arguments"], ensure_ascii=False)) for tc in tool_calls] + except json.JSONDecodeError: + raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") - for tool_call in tool_calls: - functions.append( - FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)) - ) + tool_call_match = None + if tool_call_words and len(tool_call_words) == 2: + tool_call_regex = re.compile( + rf"{re.escape(tool_call_words[0])}(.*?){re.escape(tool_call_words[1])}", re.DOTALL + ) + tool_call_match = re.search(tool_call_regex, content) - except json.JSONDecodeError: - raise RuntimeError(f"Invalid JSON format in function message: {str([content])}.") # flat string + if tool_call_match is None: + thought_match = None + if thought_words and len(thought_words) == 2: + regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL) + thought_match = re.search(regex, content) - function_str = self.tool_utils.function_formatter(functions) - if thought: - function_str = thought.group(0) + function_str + if thought_match: + json_part = content.replace(thought_match.group(0), "") + else: + json_part = content + + functions = _parse_functions(json_part) + function_str = self.tool_utils.function_formatter(functions) + if thought_match: + function_str = thought_match.group(0) + function_str + else: + thought_content = content.replace(tool_call_match.group(0), "") + functions = _parse_functions(tool_call_match.group(1)) + function_str = self.tool_utils.function_formatter(functions) + function_str = thought_content + function_str return super().apply(content=function_str) diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 36a4d43b..c611632f 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -49,6 +49,7 @@ class Template: default_system: str stop_words: list[str] thought_words: tuple[str, str] + tool_call_words: tuple[str, str] efficient_eos: bool replace_eos: bool replace_jinja_template: bool @@ -156,7 +157,9 @@ class Template: elif message["role"] == Role.OBSERVATION: elements += self.format_observation.apply(content=message["content"]) elif message["role"] == Role.FUNCTION: - elements += self.format_function.apply(content=message["content"], thought_words=self.thought_words) + elements += self.format_function.apply( + content=message["content"], thought_words=self.thought_words, tool_call_words=self.tool_call_words + ) else: raise NotImplementedError("Unexpected role: {}".format(message["role"])) @@ -471,6 +474,7 @@ def register_template( default_system: str = "", stop_words: Optional[list[str]] = None, thought_words: Optional[tuple[str, str]] = None, + tool_call_words: Optional[tuple[str, str]] = None, efficient_eos: bool = False, replace_eos: bool = False, replace_jinja_template: bool = False, @@ -522,6 +526,7 @@ def register_template( default_system=default_system, stop_words=stop_words or [], thought_words=thought_words or ("\n", "\n\n\n"), + tool_call_words=tool_call_words or ("", ""), efficient_eos=efficient_eos, replace_eos=replace_eos, replace_jinja_template=replace_jinja_template, @@ -583,6 +588,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template": default_system=default_system, stop_words=[], thought_words=("\n", "\n\n\n"), + tool_call_words=("", ""), efficient_eos=False, replace_eos=False, replace_jinja_template=False,