mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 19:30:36 +08:00
add docstrings, refactor logger
This commit is contained in:
@@ -16,21 +16,36 @@ import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .data_utils import SLOTS
|
||||
from .tool_utils import get_tool_utils
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .tool_utils import FunctionCall
|
||||
|
||||
|
||||
@dataclass
|
||||
class Formatter(ABC):
|
||||
slots: SLOTS = field(default_factory=list)
|
||||
tool_format: Optional[str] = None
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, **kwargs) -> SLOTS: ...
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
r"""
|
||||
Forms a list of slots according to the inputs to encode.
|
||||
"""
|
||||
...
|
||||
|
||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||
r"""
|
||||
Extract a list of tuples from the response message if using tools.
|
||||
|
||||
Each tuple consists of function name and function arguments.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -45,6 +60,7 @@ class EmptyFormatter(Formatter):
|
||||
if has_placeholder:
|
||||
raise ValueError("Empty formatter should not contain any placeholder.")
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
return self.slots
|
||||
|
||||
@@ -60,6 +76,7 @@ class StringFormatter(Formatter):
|
||||
if not has_placeholder:
|
||||
raise ValueError("A placeholder is required in the string formatter.")
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
elements = []
|
||||
for slot in self.slots:
|
||||
@@ -83,6 +100,7 @@ class FunctionFormatter(Formatter):
|
||||
def __post_init__(self):
|
||||
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
functions: List[Tuple[str, str]] = []
|
||||
@@ -116,6 +134,7 @@ class ToolFormatter(Formatter):
|
||||
def __post_init__(self):
|
||||
self.tool_utils = get_tool_utils(self.tool_format)
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
try:
|
||||
@@ -124,5 +143,6 @@ class ToolFormatter(Formatter):
|
||||
except json.JSONDecodeError:
|
||||
return [""]
|
||||
|
||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
@override
|
||||
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||
return self.tool_utils.tool_extractor(content)
|
||||
|
||||
Reference in New Issue
Block a user