diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index ac558770..ed83a2d2 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -99,6 +99,11 @@ class FunctionFormatter(Formatter): @override def apply(self, **kwargs) -> SLOTS: content = kwargs.pop("content") + regex = re.compile(r"(.*)", re.DOTALL) + thought = re.search(regex, content) + if thought: + content = content.replace(thought.group(0), "") + functions: List["FunctionCall"] = [] try: tool_calls = json.loads(content) @@ -116,6 +121,9 @@ class FunctionFormatter(Formatter): elements = [] for slot in self.slots: if slot == "{{content}}": + if thought: + elements.append(thought.group(1)) + elements += self.tool_utils.function_formatter(functions) else: elements.append(slot)