diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py
index 0a101ac6..68a7e309 100644
--- a/src/llamafactory/data/formatter.py
+++ b/src/llamafactory/data/formatter.py
@@ -97,8 +97,11 @@ class FunctionFormatter(StringFormatter):
@override
def apply(self, **kwargs) -> SLOTS:
content: str = kwargs.pop("content")
- regex = re.compile(r"(.*)", re.DOTALL)
- thought = re.search(regex, content)
+ thought_words, thought = kwargs.pop("thought_words", None), None
+ if thought_words and len(thought_words)== 2:
+ regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
+ thought = re.search(regex, content)
+
if thought:
content = content.replace(thought.group(0), "")
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index f524e84f..f2741446 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -156,7 +156,7 @@ class Template:
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION:
- elements += self.format_function.apply(content=message["content"])
+ elements += self.format_function.apply(content=message["content"], thought_words=self.thought_words)
else:
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
@@ -1855,6 +1855,22 @@ register_template(
)
+# copied from seed_coder template
+register_template(  # registers the "seed_oss" chat template
+ name="seed_oss",
+ format_user=StringFormatter(  # user turn wrapped in bos/eos, followed by the opening of the assistant turn
+ slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"]
+ ),
+ format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]),
+ format_function=FunctionFormatter(  # tool-call turn; "seed_oss" selects the tool format by name
+ slots=[{"bos_token"}, "\n{{content}}", {"eos_token"}], tool_format="seed_oss"
+ ),
+ format_tools=ToolFormatter(tool_format="seed_oss"),
+ template_class=ReasoningTemplate,
+ thought_words=("", "")  # NOTE(review): both delimiters are empty strings; confirm the thinking markers were not lost
+)
+
+
# copied from llama3 template
register_template(
name="skywork_o1",
diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py
index f90f47c1..6a6fd1dd 100644
--- a/src/llamafactory/data/tool_utils.py
+++ b/src/llamafactory/data/tool_utils.py
@@ -69,6 +69,15 @@ QWEN_TOOL_PROMPT = (
""""arguments": }}\n"""
)
+SEED_TOOL_PROMPT = (  # tool-use system prompt; the `{tool_text}` placeholder is filled in with str.format
+ "system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query."  # NOTE(review): no space or newline separates this from "Tool List:" in the next fragment; confirm intended
+ "Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing "
+ "any task, you must decide how to call them based on the descriptions and parameters of these tools.{tool_text}\n"
+ "工具调用请遵循如下格式:\n\n\nvalue_1"  # the Chinese text reads: "Please follow the format below for tool calls:"
+ "\nThis is the value for the second parameter\nthat can span\nmultiple "
+ "lines\n\n\n"  # NOTE(review): the example shows values but no call/parameter markup; confirm the literal is complete
+)
+
@dataclass
class ToolUtils(ABC):
@@ -346,6 +355,55 @@ class GLM4MOEToolUtils(QwenToolUtils):
return "\n".join(function_texts)
+class SeedToolUtils(ToolUtils):  # tool-call formatting and parsing helpers for the Seed tool format
+ r"""Seed tool using template."""
+
+ @override
+ @staticmethod
+ def tool_formatter(tools: list[dict[str, Any]]) -> str:  # renders the tool list as JSON inside SEED_TOOL_PROMPT
+ return SEED_TOOL_PROMPT.format(tool_text="\n" + json.dumps(tools, ensure_ascii=False))
+
+ @override
+ @staticmethod
+ def function_formatter(functions: list["FunctionCall"]) -> str:  # serialises each call; the per-call texts are joined with "\n"
+ function_json = [  # each item unpacks as (name, arguments); arguments is parsed with json.loads
+ {"func_name": name, "func_key_values": json.loads(arguments)} for name, arguments in functions
+ ]
+ function_texts = []
+ for func in function_json:
+ prompt = "\n\n"
+ if not isinstance(value, str):  # NOTE(review): `value` is not bound anywhere in this method as shown and func["func_key_values"] is never read; confirm a loop over the parsed arguments is not missing
+ value = json.dumps(value, ensure_ascii=False)  # non-string values are JSON-encoded before being appended
+ prompt += value + ""
+ prompt += "\n\n"
+ function_texts.append(prompt)
+
+ return "\n".join(function_texts)
+
+ @override
+ @staticmethod
+ def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:  # parses generated text into FunctionCall(name, arguments-as-JSON) items
+ results = []
+ regex = re.compile(
+ r"\s*\s*",  # NOTE(review): this pattern has no capture groups, yet the loop below unpacks two values per match; confirm the pattern is complete
+ re.DOTALL
+ )
+ for func_name, params_block in re.findall(regex, content):
+ args_dict = {}
+ param_pattern = re.compile(r"(.*?)", re.DOTALL)  # NOTE(review): one capture group but two values are unpacked per match; also loop-invariant, so it could be compiled once outside the loop
+ for key, raw_value in re.findall(param_pattern, params_block.strip()):
+ value = raw_value.strip()
+ try:
+ parsed_value = json.loads(value)  # values that parse as JSON are decoded to Python objects
+ except json.JSONDecodeError:
+ parsed_value = raw_value  # otherwise the original, unstripped text is kept
+ args_dict[key] = parsed_value
+
+ results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))
+
+ return results  # NOTE(review): always a list, even when nothing matched, although the annotation also allows str; confirm callers handle []
TOOLS = {
"default": DefaultToolUtils(),
@@ -354,6 +412,7 @@ TOOLS = {
"mistral": MistralToolUtils(),
"qwen": QwenToolUtils(),
"glm4_moe": GLM4MOEToolUtils(),
+ "seed_oss": SeedToolUtils(),
}