From b6d63b33248b2762f1e794ea301cbeb9c3ce9774 Mon Sep 17 00:00:00 2001 From: mMrBun <2015711377@qq.com> Date: Mon, 10 Jun 2024 02:00:14 +0800 Subject: [PATCH] Optimize the handling of QWEN2 in scenarios involving multiple tool calls. Former-commit-id: 950e360ca00c29febadc14d5995de7d57b5c43a7 --- src/llamafactory/api/chat.py | 13 +++++++----- src/llamafactory/data/formatter.py | 34 ++++++++++++++++++------------ 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py index 98957bc1..d4db1eea 100644 --- a/src/llamafactory/api/chat.py +++ b/src/llamafactory/api/chat.py @@ -150,11 +150,14 @@ async def create_chat_completion_response( else: result = response.response_text - if isinstance(result, tuple): - name, arguments = result - function = Function(name=name, arguments=arguments) - tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function) - response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call]) + if isinstance(result, list): + tool_calls = [] + for tool in result: + name, arguments = tool + function = Function(name=name, arguments=arguments) + tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function) + tool_calls.append(tool_call) + response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls) finish_reason = Finish.TOOL else: response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result) diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index 9f58915b..1d917887 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -72,23 +72,29 @@ def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str: return GLM4_TOOL_PROMPT.format(tool_text=tool_text) -def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]: - regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL) - action_match = re.search(regex, content) +def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: + regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*({.*?})(?=\nAction:|\Z)", re.DOTALL) + action_match = re.findall(regex, content) if not action_match: return content - tool_name = action_match.group(1).strip() - tool_input = action_match.group(2).strip().strip('"').strip("```") - try: - arguments = json.loads(tool_input) - except json.JSONDecodeError: - return content + results = [] + + for match in action_match: + tool_name, tool_input = match + tool_name = tool_name.strip() + tool_input = tool_input.strip().strip('"').strip("```") - return tool_name, json.dumps(arguments, ensure_ascii=False) + try: + arguments = json.loads(tool_input) + results.append((tool_name, json.dumps(arguments, ensure_ascii=False))) + except json.JSONDecodeError: + return content + + return results -def glm4_tool_extractor(content: str) -> Union[str, Tuple[str, str]]: +def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: lines = content.strip().split("\n") if len(lines) != 2: return content @@ -98,7 +104,7 @@ def glm4_tool_extractor(content: str) -> Union[str, Tuple[str, str]]: arguments = json.loads(tool_input) except json.JSONDecodeError: return content - return tool_name, json.dumps(arguments, ensure_ascii=False) + return [(tool_name, json.dumps(arguments, ensure_ascii=False))] @@ -110,7 +116,7 @@ class Formatter(ABC): @abstractmethod def apply(self, **kwargs) -> SLOTS: ... - def extract(self, content: str) -> Union[str, Tuple[str, str]]: + def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: raise NotImplementedError @@ -215,7 +221,7 @@ class ToolFormatter(Formatter): except Exception: return [""] - def extract(self, content: str) -> Union[str, Tuple[str, str]]: + def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: if self.tool_format == "default": return default_tool_extractor(content) elif self.tool_format == "glm4":