diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py index 945856cb..2c7e11e2 100644 --- a/src/llamafactory/api/chat.py +++ b/src/llamafactory/api/chat.py @@ -92,9 +92,11 @@ def _process_request( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls): - name = message.tool_calls[0].function.name - arguments = message.tool_calls[0].function.arguments - content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False) + tool_calls = [ + {"name": tool_call.function.name, "argument": tool_call.function.arguments} + for tool_call in message.tool_calls + ] + content = json.dumps(tool_calls, ensure_ascii=False) input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content}) elif isinstance(message.content, list): for input_item in message.content: @@ -118,7 +120,7 @@ def _process_request( if isinstance(tool_list, list) and len(tool_list): try: tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False) - except Exception: + except json.JSONDecodeError: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") else: tools = None @@ -160,17 +162,16 @@ async def create_chat_completion_response( choices = [] for i, response in enumerate(responses): if tools: - result = chat_model.engine.template.format_tools.extract(response.response_text) + result = chat_model.engine.template.extract_tool(response.response_text) else: result = response.response_text if isinstance(result, list): tool_calls = [] for tool in result: - name, arguments = tool - function = Function(name=name, arguments=arguments) - tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function) - tool_calls.append(tool_call) + function = Function(name=tool[0], arguments=tool[1]) + tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)) + response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls) finish_reason = Finish.TOOL else: diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index fa35df5b..70be6a5a 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -22,29 +22,20 @@ from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Uni SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] -JSON_FORMAT_PROMPT = ( - """, in a JSON format representing the kwargs (e.g. ```{"input": "hello world", "num_beams": 5}```)""" -) - - -TOOL_SYSTEM_PROMPT = ( +DEFAULT_TOOL_PROMPT = ( "You have access to the following tools:\n{tool_text}" "Use the following format if using a tool:\n" "```\n" "Action: tool name (one of [{tool_names}]).\n" - "Action Input: the input to the tool{format_prompt}.\n" + "Action Input: the input to the tool, in a JSON format representing the kwargs " + """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n""" "```\n" ) -GLM4_TOOL_SUFFIX_PROMPT = ( - "在调用上述函数时,请使用 Json 格式表示调用的参数。" -) - GLM4_TOOL_PROMPT = ( - "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持," - "{tool_text}" - + "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," + "你的任务是针对用户的问题和要求提供适当的答复和支持。{tool_text}" ) @@ -73,32 +64,19 @@ def default_tool_formatter(tools: List[Dict[str, Any]]) -> str: ) tool_names.append(tool["name"]) - return TOOL_SYSTEM_PROMPT.format( - tool_text=tool_text, tool_names=", ".join(tool_names), format_prompt=JSON_FORMAT_PROMPT - ) + return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names)) -def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str: - tool_text = "" - for tool in tools: - tool_name = tool["name"] - tool_text += f"\n\n## {tool_name}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n{GLM4_TOOL_SUFFIX_PROMPT}" - return GLM4_TOOL_PROMPT.format(tool_text=tool_text) - - def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: - regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*({.*?})(?=\nAction:|\Z)", re.DOTALL) - action_match = re.findall(regex, content) + regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|$)", re.DOTALL) + action_match: List[Tuple[str, str]] = re.findall(regex, content) if not action_match: return content results = [] - for match in action_match: - tool_name, tool_input = match - tool_name = tool_name.strip() - tool_input = tool_input.strip().strip('"').strip("```") - + tool_name = match[0].strip() + tool_input = match[1].strip().strip('"').strip("```") try: arguments = json.loads(tool_input) results.append((tool_name, json.dumps(arguments, ensure_ascii=False))) @@ -108,19 +86,28 @@ def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: return results +def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str: + tool_text = "" + for tool in tools: + tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format( + name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False) + ) + + return GLM4_TOOL_PROMPT.format(tool_text=tool_text) + + def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: - lines = content.strip().split("\n") - if len(lines) != 2: + if "\n" not in content: return content - tool_name = lines[0].strip() - tool_input = lines[1].strip() + + tool_name, tool_input = content.split("\n", maxsplit=1) try: arguments = json.loads(tool_input) except json.JSONDecodeError: return content + return [(tool_name, json.dumps(arguments, ensure_ascii=False))] - @dataclass class Formatter(ABC): @@ -193,22 +180,28 @@ class FunctionFormatter(Formatter): def apply(self, **kwargs) -> SLOTS: content = kwargs.pop("content") + functions: List[Tuple[str, str]] = [] try: - function = json.loads(content) - name = function["name"] - arguments = json.dumps(function["arguments"], ensure_ascii=False) - except Exception: - name, arguments = "", "" + tool_calls = json.loads(content) + if not isinstance(tool_calls, list): # parallel function call + tool_calls = [tool_calls] + + for tool_call in tool_calls: + functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))) + + except json.JSONDecodeError: + functions = [] elements = [] - for slot in self.slots: - if isinstance(slot, str): - slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments) - elements.append(slot) - elif isinstance(slot, (dict, set)): - elements.append(slot) - else: - raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) + for name, arguments in functions: + for slot in self.slots: + if isinstance(slot, str): + slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments) + elements.append(slot) + elif isinstance(slot, (dict, set)): + elements.append(slot) + else: + raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) return elements @@ -216,29 +209,22 @@ class FunctionFormatter(Formatter): @dataclass class ToolFormatter(Formatter): def __post_init__(self): - if self.tool_format is None: + if self.tool_format == "default": + self._tool_formatter = default_tool_formatter + self._tool_extractor = default_tool_extractor + elif self.tool_format == "glm4": + self._tool_formatter = glm4_tool_formatter + self._tool_extractor = glm4_tool_extractor + else: raise ValueError("Tool format was not found.") def apply(self, **kwargs) -> SLOTS: content = kwargs.pop("content") try: tools = json.loads(content) - if not len(tools): - return [""] - - if self.tool_format == "default": - return [default_tool_formatter(tools)] - elif self.tool_format == "glm4": - return [glm4_tool_formatter(tools)] - else: - raise NotImplementedError - except Exception: + return [self._tool_formatter(tools) if len(tools) != 0 else ""] + except json.JSONDecodeError: return [""] def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: - if self.tool_format == "default": - return default_tool_extractor(content) - elif self.tool_format == "glm4": - return glm4_tool_extractor(content) - else: - raise NotImplementedError + return self._tool_extractor(content) diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index d97699b0..77694c59 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -79,6 +79,12 @@ class Template: """ return self._encode(tokenizer, messages, system, tools, cutoff_len, reserved_label_len) + def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]: + r""" + Extracts tool message. + """ + return self.format_tools.extract(content) + def _encode( self, tokenizer: "PreTrainedTokenizer", @@ -100,7 +106,8 @@ class Template: if i == 0 and (system or tools or self.force_system): tool_text = self.format_tools.apply(content=tools)[0] if tools else "" elements += self.format_system.apply(content=(system + tool_text)) - elif i > 0 and i % 2 == 0: + + if i > 0 and i % 2 == 0: elements += self.format_separator.apply() if message["role"] == Role.USER.value: @@ -191,7 +198,8 @@ class Llama2Template(Template): if i == 0 and (system or tools or self.force_system): tool_text = self.format_tools.apply(content=tools)[0] if tools else "" system_text = self.format_system.apply(content=(system + tool_text))[0] - elif i > 0 and i % 2 == 0: + + if i > 0 and i % 2 == 0: elements += self.format_separator.apply() if message["role"] == Role.USER.value: @@ -259,7 +267,9 @@ def _register_template( template_class = Llama2Template if name.startswith("llama2") else Template default_user_formatter = StringFormatter(slots=["{{content}}"]) default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots) - default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots) + default_function_formatter = FunctionFormatter( + slots=["Action: {{name}}\nAction Input: {{arguments}}\n"] + eos_slots + ) default_tool_formatter = ToolFormatter(tool_format="default") default_separator_formatter = EmptyFormatter() TEMPLATES[name] = template_class( diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py index 864c41c7..a2b54dce 100644 --- a/src/llamafactory/webui/chatter.py +++ b/src/llamafactory/webui/chatter.py @@ -140,16 +140,15 @@ class WebChatModel(ChatModel): ): response += new_text if tools: - result = self.engine.template.format_tools.extract(response) + result = self.engine.template.extract_tool(response) else: result = response - if isinstance(result, tuple): - name, arguments = result - arguments = json.loads(arguments) - tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False) - output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_call}] - bot_text = "```json\n" + tool_call + "\n```" + if isinstance(result, list): + tool_calls = [{"name": tool[0], "arguments": json.loads(tool[1])} for tool in result] + tool_calls = json.dumps(tool_calls, indent=4, ensure_ascii=False) + output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}] + bot_text = "```json\n" + tool_calls + "\n```" else: output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}] bot_text = result diff --git a/tests/data/test_formatter.py b/tests/data/test_formatter.py new file mode 100644 index 00000000..430eb0e6 --- /dev/null +++ b/tests/data/test_formatter.py @@ -0,0 +1,125 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from llamafactory.data.formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter + + +def test_empty_formatter(): + formatter = EmptyFormatter(slots=["\n"]) + assert formatter.apply() == ["\n"] + + +def test_string_formatter(): + formatter = StringFormatter(slots=["", "Human: {{content}}\nAssistant:"]) + assert formatter.apply(content="Hi") == ["", "Human: Hi\nAssistant:"] + + +def test_function_formatter(): + formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"]) + tool_calls = json.dumps({"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}) + assert formatter.apply(content=tool_calls) == [ + """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""" + ] + + +def test_multi_function_formatter(): + formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}\n"]) + tool_calls = json.dumps([{"name": "tool_name", "arguments": {"foo": "bar", "size": 10}}] * 2) + assert formatter.apply(content=tool_calls) == [ + """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""", + """Action: tool_name\nAction Input: {\"foo\": \"bar\", \"size\": 10}\n""", + ] + + +def test_default_tool_formatter(): + formatter = ToolFormatter(tool_format="default") + tools = [ + { + "name": "test_tool", + "description": "tool_desc", + "parameters": { + "type": "object", + "properties": { + "foo": {"type": "string", "description": "foo_desc"}, + "bar": {"type": "number", "description": "bar_desc"}, + }, + "required": ["foo"], + }, + } + ] + assert formatter.apply(content=json.dumps(tools)) == [ + "You have access to the following tools:\n" + "> Tool Name: test_tool\n" + "Tool Description: tool_desc\n" + "Tool Args:\n" + " - foo (string, required): foo_desc\n" + " - bar (number): bar_desc\n\n" + "Use the following format if using a tool:\n" + "```\n" + "Action: tool name (one of [test_tool]).\n" + "Action Input: the input to the tool, in a JSON format representing the kwargs " + """(e.g. ```{"input": "hello world", "num_beams": 5}```).\n""" + "```\n" + ] + + +def test_default_tool_extractor(): + formatter = ToolFormatter(tool_format="default") + result = """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n""" + assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")] + + +def test_default_multi_tool_extractor(): + formatter = ToolFormatter(tool_format="default") + result = ( + """Action: test_tool\nAction Input: {"foo": "bar", "size": 10}\n""" + """Action: another_tool\nAction Input: {"foo": "job", "size": 2}\n""" + ) + assert formatter.extract(result) == [ + ("test_tool", """{"foo": "bar", "size": 10}"""), + ("another_tool", """{"foo": "job", "size": 2}"""), + ] + + +def test_glm4_tool_formatter(): + formatter = ToolFormatter(tool_format="glm4") + tools = [ + { + "name": "test_tool", + "description": "tool_desc", + "parameters": { + "type": "object", + "properties": { + "foo": {"type": "string", "description": "foo_desc"}, + "bar": {"type": "number", "description": "bar_desc"}, + }, + "required": ["foo"], + }, + } + ] + assert formatter.apply(content=json.dumps(tools)) == [ + "你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," + "你的任务是针对用户的问题和要求提供适当的答复和支持。" + "\n\n## test_tool\n\n{}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format( + json.dumps(tools[0], indent=4) + ) + ] + + +def test_glm4_tool_extractor(): + formatter = ToolFormatter(tool_format="glm4") + result = """test_tool\n{"foo": "bar", "size": 10}\n""" + assert formatter.extract(result) == [("test_tool", """{"foo": "bar", "size": 10}""")]