fix mixed mm inputs and rlhf-v

This commit is contained in:
hiyouga
2024-09-01 20:52:47 +08:00
parent b5063b4144
commit 9967ccb3ae
20 changed files with 306 additions and 277 deletions

View File

@@ -16,16 +16,16 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
from .data_utils import SLOTS
from .tool_utils import DefaultToolUtils, GLM4ToolUtils
from .tool_utils import get_tool_utils
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Optional[Literal["default", "glm4"]] = None
tool_format: Optional[str] = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS: ...
@@ -81,12 +81,7 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
if self.tool_format == "default":
self.slots = DefaultToolUtils.get_function_slots() + self.slots
elif self.tool_format == "glm4":
self.slots = GLM4ToolUtils.get_function_slots() + self.slots
else:
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
@@ -119,22 +114,15 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
if self.tool_format == "default":
self._tool_formatter = DefaultToolUtils.tool_formatter
self._tool_extractor = DefaultToolUtils.tool_extractor
elif self.tool_format == "glm4":
self._tool_formatter = GLM4ToolUtils.tool_formatter
self._tool_extractor = GLM4ToolUtils.tool_extractor
else:
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
self.tool_utils = get_tool_utils(self.tool_format)
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
tools = json.loads(content)
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
except json.JSONDecodeError:
return [""]
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
return self._tool_extractor(content)
return self.tool_utils.tool_extractor(content)