diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index 0a101ac6..68a7e309 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -97,8 +97,11 @@ class FunctionFormatter(StringFormatter): @override def apply(self, **kwargs) -> SLOTS: content: str = kwargs.pop("content") - regex = re.compile(r"(.*)", re.DOTALL) - thought = re.search(regex, content) + thought_words, thought = kwargs.pop("thought_words", None), None + if thought_words and len(thought_words)== 2: + regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL) + thought = re.search(regex, content) + if thought: content = content.replace(thought.group(0), "") diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index f524e84f..f2741446 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -156,7 +156,7 @@ class Template: elif message["role"] == Role.OBSERVATION: elements += self.format_observation.apply(content=message["content"]) elif message["role"] == Role.FUNCTION: - elements += self.format_function.apply(content=message["content"]) + elements += self.format_function.apply(content=message["content"], thought_words=self.thought_words) else: raise NotImplementedError("Unexpected role: {}".format(message["role"])) @@ -1855,6 +1855,22 @@ register_template( ) +#copied from seed_coder +register_template( + name="seed_oss", + format_user=StringFormatter( + slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"] + ), + format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]), + format_function=FunctionFormatter( + slots=[{"bos_token"}, "\n{{content}}", {"eos_token"}], tool_format="seed_oss" + ), + format_tools=ToolFormatter(tool_format="seed_oss"), + template_class=ReasoningTemplate, + thought_words=("", "") +) + + # copied from llama3 template register_template( name="skywork_o1", diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py index f90f47c1..6a6fd1dd 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -69,6 +69,15 @@ QWEN_TOOL_PROMPT = ( """"arguments": }}\n""" ) +SEED_TOOL_PROMPT = ( + "system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query." + "Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing " + "any task, you must decide how to call them based on the descriptions and parameters of these tools.{tool_text}\n" + "工具调用请遵循如下格式:\n\n\nvalue_1" + "\nThis is the value for the second parameter\nthat can span\nmultiple " + "lines\n\n\n" +) + @dataclass class ToolUtils(ABC): @@ -346,6 +355,55 @@ class GLM4MOEToolUtils(QwenToolUtils): return "\n".join(function_texts) +class SeedToolUtils(ToolUtils): + r"""Seed tool using template.""" + + @override + @staticmethod + def tool_formatter(tools: list[dict[str, Any]]) -> str: + return SEED_TOOL_PROMPT.format(tool_text="\n" + json.dumps(tools, ensure_ascii=False)) + + @override + @staticmethod + def function_formatter(functions: list["FunctionCall"]) -> str: + function_json = [ + {"func_name": name, "func_key_values": json.loads(arguments)} for name, arguments in functions + ] + function_texts = [] + for func in function_json: + prompt = "\n\n" + if not isinstance(value, str): + value = json.dumps(value, ensure_ascii=False) + prompt += value + "" + prompt += "\n\n" + function_texts.append(prompt) + + return "\n".join(function_texts) + + @override + @staticmethod + def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]: + results = [] + regex = re.compile( + r"\s*\s*", + re.DOTALL + ) + for func_name, params_block in re.findall(regex, content): + args_dict = {} + param_pattern = re.compile(r"(.*?)", re.DOTALL) + for key, raw_value in re.findall(param_pattern, params_block.strip()): + value = raw_value.strip() + try: + parsed_value = json.loads(value) + except json.JSONDecodeError: + parsed_value = raw_value + args_dict[key] = parsed_value + + results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False))) + + return results TOOLS = { "default": DefaultToolUtils(), @@ -354,6 +412,7 @@ TOOLS = { "mistral": MistralToolUtils(), "qwen": QwenToolUtils(), "glm4_moe": GLM4MOEToolUtils(), + "seed_oss": SeedToolUtils(), }