mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-09-01 18:52:50 +08:00
[model] support Seed-OSS (#8992)
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
This commit is contained in:
parent
c5976f9b53
commit
652e6e92da
@ -97,8 +97,11 @@ class FunctionFormatter(StringFormatter):
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content: str = kwargs.pop("content")
|
||||
regex = re.compile(r"<think>(.*)</think>", re.DOTALL)
|
||||
thought = re.search(regex, content)
|
||||
thought_words, thought = kwargs.pop("thought_words", None), None
|
||||
if thought_words and len(thought_words)== 2:
|
||||
regex = re.compile(rf"{re.escape(thought_words[0])}(.*?){re.escape(thought_words[1])}", re.DOTALL)
|
||||
thought = re.search(regex, content)
|
||||
|
||||
if thought:
|
||||
content = content.replace(thought.group(0), "")
|
||||
|
||||
|
@ -156,7 +156,7 @@ class Template:
|
||||
elif message["role"] == Role.OBSERVATION:
|
||||
elements += self.format_observation.apply(content=message["content"])
|
||||
elif message["role"] == Role.FUNCTION:
|
||||
elements += self.format_function.apply(content=message["content"])
|
||||
elements += self.format_function.apply(content=message["content"], thought_words=self.thought_words)
|
||||
else:
|
||||
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
|
||||
|
||||
@ -1855,6 +1855,22 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
# Chat template for Seed-OSS; copied from seed_coder.
# Set-valued slots such as {"bos_token"} are presumably resolved to the tokenizer's
# special tokens by the formatter -- TODO confirm against StringFormatter.
|
||||
register_template(
|
||||
name="seed_oss",
|
||||
# User turn: BOS + "user\n" + content + EOS, then BOS + "assistant\n" to open the reply.
format_user=StringFormatter(
|
||||
slots=[{"bos_token"}, "user\n{{content}}", {"eos_token"}, {"bos_token"}, "assistant\n"]
|
||||
),
|
||||
# System turn: BOS + "system\n" + content + EOS.
format_system=StringFormatter(slots=[{"bos_token"}, "system\n{{content}}", {"eos_token"}]),
|
||||
# Function-call turn: BOS + "\n" + content + EOS, rendered with the "seed_oss" tool format.
format_function=FunctionFormatter(
|
||||
slots=[{"bos_token"}, "\n{{content}}", {"eos_token"}], tool_format="seed_oss"
|
||||
),
|
||||
# Tool descriptions also use the "seed_oss" tool format.
format_tools=ToolFormatter(tool_format="seed_oss"),
|
||||
template_class=ReasoningTemplate,
|
||||
# Start/end markers of the thought span in assistant output.
thought_words=("<seed:think>", "</seed:think>")
|
||||
)
|
||||
|
||||
|
||||
# copied from llama3 template
|
||||
register_template(
|
||||
name="skywork_o1",
|
||||
|
@ -69,6 +69,15 @@ QWEN_TOOL_PROMPT = (
|
||||
""""arguments": <args-json-object>}}\n</tool_call>"""
|
||||
)
|
||||
|
||||
# Tool-description prompt for the Seed tool format. It is a str.format template whose only
# placeholder is {tool_text}; SeedToolUtils.tool_formatter fills it with a newline followed by
# the JSON dump of the tool list.
# NOTE(review): the first two literals are joined without any separator, so the rendered text
# reads "...user query.Tool List:" -- confirm this matches the reference chat template.
SEED_TOOL_PROMPT = (
|
||||
"system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query."
|
||||
"Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing "
|
||||
"any task, you must decide how to call them based on the descriptions and parameters of these tools.{tool_text}\n"
|
||||
# The Chinese sentence below means "Tool calls must follow this format:" and is followed by
# a worked example of one <seed:tool_call> block.
"工具调用请遵循如下格式:\n<seed:tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>value_1"
|
||||
"</parameter>\n<parameter=example_parameter_2>This is the value for the second parameter\nthat can span\nmultiple "
|
||||
"lines</parameter>\n</function>\n</seed:tool_call>\n"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolUtils(ABC):
|
||||
@ -346,6 +355,55 @@ class GLM4MOEToolUtils(QwenToolUtils):
|
||||
|
||||
return "\n".join(function_texts)
|
||||
|
||||
class SeedToolUtils(ToolUtils):
    r"""Seed tool using template.

    Tool calls are written as ``<seed:tool_call>`` blocks, each holding one
    ``<function=NAME>`` element with ``<parameter=KEY>VALUE</parameter>`` children,
    i.e. the layout shown in ``SEED_TOOL_PROMPT``.
    """

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        r"""Render the tool definitions into the Seed tool prompt.

        Args:
            tools: Tool definitions; they are dumped as a single JSON array.

        Returns:
            ``SEED_TOOL_PROMPT`` with ``{tool_text}`` replaced by a newline plus the JSON dump.
        """
        return SEED_TOOL_PROMPT.format(tool_text="\n" + json.dumps(tools, ensure_ascii=False))

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        r"""Serialize function calls into ``<seed:tool_call>`` blocks.

        Args:
            functions: ``(name, arguments)`` pairs where ``arguments`` is a JSON object string.

        Returns:
            One block per call, joined by newlines. String argument values are written
            verbatim; every other value is JSON-encoded.

        Raises:
            json.JSONDecodeError: If an ``arguments`` string is not valid JSON.
        """
        # Decode every call up front so malformed JSON fails before any text is built.
        decoded_calls = [(name, json.loads(arguments)) for name, arguments in functions]

        function_texts = []
        for name, key_values in decoded_calls:
            # The opening tag is closed with ">" to match the format advertised in SEED_TOOL_PROMPT.
            lines = ["\n<seed:tool_call>", f"<function={name}>"]
            for key, value in key_values.items():
                if not isinstance(value, str):
                    value = json.dumps(value, ensure_ascii=False)

                lines.append(f"<parameter={key}>{value}</parameter>")

            lines.extend(["</function>", "</seed:tool_call>"])
            function_texts.append("\n".join(lines))

        return "\n".join(function_texts)

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        r"""Parse ``<seed:tool_call>`` blocks out of model output.

        Args:
            content: Raw text that may contain tool calls.

        Returns:
            A list of ``FunctionCall`` objects (arguments re-encoded as a JSON object string),
            or ``content`` unchanged when no tool call is found.
        """
        # The ">" after the function name is optional so that text produced without it
        # (the previous serialization) still parses; the name itself never includes ">".
        call_pattern = re.compile(
            r"<seed:tool_call>\s*<function=\s*([^\s<>]+)\s*>?\s*(.*?)\s*</function>\s*</seed:tool_call>",
            re.DOTALL,
        )
        param_pattern = re.compile(r"<parameter=(.*?)>(.*?)</parameter>", re.DOTALL)

        results = []
        for func_name, params_block in call_pattern.findall(content):
            args_dict = {}
            for key, raw_value in param_pattern.findall(params_block.strip()):
                try:
                    # Numbers, booleans, lists and objects round-trip through JSON.
                    args_dict[key] = json.loads(raw_value.strip())
                except json.JSONDecodeError:
                    # Anything that is not JSON is a plain string and is kept verbatim.
                    args_dict[key] = raw_value

            results.append(FunctionCall(func_name, json.dumps(args_dict, ensure_ascii=False)))

        # Plain text must come back as a string; an empty list would read as "tool calls, none of them".
        return results or content
|
||||
|
||||
TOOLS = {
|
||||
"default": DefaultToolUtils(),
|
||||
@ -354,6 +412,7 @@ TOOLS = {
|
||||
"mistral": MistralToolUtils(),
|
||||
"qwen": QwenToolUtils(),
|
||||
"glm4_moe": GLM4MOEToolUtils(),
|
||||
"seed_oss": SeedToolUtils(),
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user