mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-09-12 16:12:48 +08:00
Optimize the handling of QWEN2 in scenarios involving multiple tool calls.
Former-commit-id: 950e360ca00c29febadc14d5995de7d57b5c43a7
This commit is contained in:
parent
3f11ab800f
commit
b6d63b3324
@ -150,11 +150,14 @@ async def create_chat_completion_response(
|
||||
else:
|
||||
result = response.response_text
|
||||
|
||||
if isinstance(result, tuple):
|
||||
name, arguments = result
|
||||
if isinstance(result, list):
|
||||
tool_calls = []
|
||||
for tool in result:
|
||||
name, arguments = tool
|
||||
function = Function(name=name, arguments=arguments)
|
||||
tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
|
||||
tool_calls.append(tool_call)
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
|
||||
finish_reason = Finish.TOOL
|
||||
else:
|
||||
response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
|
||||
|
@ -72,23 +72,29 @@ def glm4_tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
|
||||
def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+).*?Action Input:\s*(.*)", re.DOTALL)
|
||||
action_match = re.search(regex, content)
|
||||
def default_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*({.*?})(?=\nAction:|\Z)", re.DOTALL)
|
||||
action_match = re.findall(regex, content)
|
||||
if not action_match:
|
||||
return content
|
||||
|
||||
tool_name = action_match.group(1).strip()
|
||||
tool_input = action_match.group(2).strip().strip('"').strip("```")
|
||||
results = []
|
||||
|
||||
for match in action_match:
|
||||
tool_name, tool_input = match
|
||||
tool_name = tool_name.strip()
|
||||
tool_input = tool_input.strip().strip('"').strip("```")
|
||||
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
return tool_name, json.dumps(arguments, ensure_ascii=False)
|
||||
return results
|
||||
|
||||
|
||||
def glm4_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
|
||||
def glm4_tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
lines = content.strip().split("\n")
|
||||
if len(lines) != 2:
|
||||
return content
|
||||
@ -98,7 +104,7 @@ def glm4_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
|
||||
arguments = json.loads(tool_input)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
return tool_name, json.dumps(arguments, ensure_ascii=False)
|
||||
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
||||
|
||||
|
||||
|
||||
@ -110,7 +116,7 @@ class Formatter(ABC):
|
||||
@abstractmethod
|
||||
def apply(self, **kwargs) -> SLOTS: ...
|
||||
|
||||
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
|
||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -215,7 +221,7 @@ class ToolFormatter(Formatter):
|
||||
except Exception:
|
||||
return [""]
|
||||
|
||||
def extract(self, content: str) -> Union[str, Tuple[str, str]]:
|
||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
if self.tool_format == "default":
|
||||
return default_tool_extractor(content)
|
||||
elif self.tool_format == "glm4":
|
||||
|
Loading…
x
Reference in New Issue
Block a user