mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 19:30:36 +08:00
fix mixed mm inputs and rlhf-v
This commit is contained in:
@@ -16,16 +16,16 @@ import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from .data_utils import SLOTS
|
||||
from .tool_utils import DefaultToolUtils, GLM4ToolUtils
|
||||
from .tool_utils import get_tool_utils
|
||||
|
||||
|
||||
@dataclass
|
||||
class Formatter(ABC):
|
||||
slots: SLOTS = field(default_factory=list)
|
||||
tool_format: Optional[Literal["default", "glm4"]] = None
|
||||
tool_format: Optional[str] = None
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, **kwargs) -> SLOTS: ...
|
||||
@@ -81,12 +81,7 @@ class StringFormatter(Formatter):
|
||||
@dataclass
|
||||
class FunctionFormatter(Formatter):
|
||||
def __post_init__(self):
|
||||
if self.tool_format == "default":
|
||||
self.slots = DefaultToolUtils.get_function_slots() + self.slots
|
||||
elif self.tool_format == "glm4":
|
||||
self.slots = GLM4ToolUtils.get_function_slots() + self.slots
|
||||
else:
|
||||
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
|
||||
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
|
||||
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
@@ -119,22 +114,15 @@ class FunctionFormatter(Formatter):
|
||||
@dataclass
|
||||
class ToolFormatter(Formatter):
|
||||
def __post_init__(self):
|
||||
if self.tool_format == "default":
|
||||
self._tool_formatter = DefaultToolUtils.tool_formatter
|
||||
self._tool_extractor = DefaultToolUtils.tool_extractor
|
||||
elif self.tool_format == "glm4":
|
||||
self._tool_formatter = GLM4ToolUtils.tool_formatter
|
||||
self._tool_extractor = GLM4ToolUtils.tool_extractor
|
||||
else:
|
||||
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
|
||||
self.tool_utils = get_tool_utils(self.tool_format)
|
||||
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
try:
|
||||
tools = json.loads(content)
|
||||
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
|
||||
return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
|
||||
except json.JSONDecodeError:
|
||||
return [""]
|
||||
|
||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
return self._tool_extractor(content)
|
||||
return self.tool_utils.tool_extractor(content)
|
||||
|
||||
Reference in New Issue
Block a user