diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py index 05db57a7..c467a3e6 100644 --- a/src/llamafactory/api/chat.py +++ b/src/llamafactory/api/chat.py @@ -168,7 +168,7 @@ async def create_chat_completion_response( if isinstance(result, list): tool_calls = [] for tool in result: - function = Function(name=tool[0], arguments=tool[1]) + function = Function(name=tool.name, arguments=tool.arguments) tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function)) response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls) diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index 28bf3fb1..f6c24468 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -16,16 +16,12 @@ import json import re from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import List, Optional, Union from typing_extensions import override from .data_utils import SLOTS -from .tool_utils import get_tool_utils - - -if TYPE_CHECKING: - from .tool_utils import FunctionCall +from .tool_utils import FunctionCall, get_tool_utils @dataclass @@ -98,19 +94,21 @@ class StringFormatter(Formatter): @dataclass class FunctionFormatter(Formatter): def __post_init__(self): - self.function_slots = get_tool_utils(self.tool_format).get_function_slots() + self.tool_utils = get_tool_utils(self.tool_format) @override def apply(self, **kwargs) -> SLOTS: content = kwargs.pop("content") - functions: List[Tuple[str, str]] = [] + functions: List["FunctionCall"] = [] try: tool_calls = json.loads(content) if not isinstance(tool_calls, list): # parallel function call tool_calls = [tool_calls] for tool_call in tool_calls: - functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))) + functions.append( + FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)) + ) except json.JSONDecodeError: raise RuntimeError(f"Invalid JSON format in function message: {str([content])}") # flat string @@ -118,15 +116,7 @@ class FunctionFormatter(Formatter): elements = [] for slot in self.slots: if slot == "{{content}}": - for name, arguments in functions: - for slot in self.function_slots: - if isinstance(slot, str): - slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments) - elements.append(slot) - elif isinstance(slot, (dict, set)): - elements.append(slot) - else: - raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}") + elements += self.tool_utils.function_formatter(functions) else: elements.append(slot) diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index cfc4a2cb..3297be39 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -30,6 +30,7 @@ if TYPE_CHECKING: from ..hparams import DataArguments from .formatter import SLOTS, Formatter from .mm_plugin import BasePlugin + from .tool_utils import FunctionCall logger = logging.get_logger(__name__) @@ -83,7 +84,7 @@ class Template: encoded_messages = self._encode(tokenizer, messages, system, tools) return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] - def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]: + def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]: r""" Extracts tool message. """ @@ -244,7 +245,7 @@ def _register_template( ) ``` """ - template_class = Llama2Template if name.startswith("llama2") else Template + template_class = Llama2Template if any(k in name for k in ("llama2", "mistral")) else Template default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}] default_user_formatter = StringFormatter(slots=["{{content}}"]) default_assistant_formatter = StringFormatter(slots=default_slots) @@ -854,7 +855,11 @@ _register_template( # copied from mistral template _register_template( name="llava_next_mistral", - format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), + format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]), + format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]), + format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"), + format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]), + format_tools=ToolFormatter(tool_format="mistral"), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), mm_plugin=get_mm_plugin(name="llava_next", image_token=""), ) @@ -902,7 +907,11 @@ _register_template( # copied from mistral template _register_template( name="llava_next_video_mistral", - format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), + format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]), + format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]), + format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"), + format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]), + format_tools=ToolFormatter(tool_format="mistral"), format_prefix=EmptyFormatter(slots=[{"bos_token"}]), mm_plugin=get_mm_plugin(name="llava_next_video", image_token="", video_token="