fix mixed mm inputs and rlhf-v

Former-commit-id: 7c248fac20bf85d57a91132ce7a793c7f84e9218
This commit is contained in:
hiyouga
2024-09-01 20:52:47 +08:00
parent 1d8e9c7897
commit 7e4c5d4bb3
20 changed files with 306 additions and 277 deletions

View File

@@ -16,16 +16,16 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
from .data_utils import SLOTS
from .tool_utils import DefaultToolUtils, GLM4ToolUtils
from .tool_utils import get_tool_utils
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Optional[Literal["default", "glm4"]] = None
tool_format: Optional[str] = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS: ...
@@ -81,12 +81,7 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
if self.tool_format == "default":
self.slots = DefaultToolUtils.get_function_slots() + self.slots
elif self.tool_format == "glm4":
self.slots = GLM4ToolUtils.get_function_slots() + self.slots
else:
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
@@ -119,22 +114,15 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
if self.tool_format == "default":
self._tool_formatter = DefaultToolUtils.tool_formatter
self._tool_extractor = DefaultToolUtils.tool_extractor
elif self.tool_format == "glm4":
self._tool_formatter = GLM4ToolUtils.tool_formatter
self._tool_extractor = GLM4ToolUtils.tool_extractor
else:
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
self.tool_utils = get_tool_utils(self.tool_format)
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
tools = json.loads(content)
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
return [""]
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
return self._tool_extractor(content)
return self.tool_utils.tool_extractor(content)