support vllm

This commit is contained in:
hiyouga
2024-03-07 20:26:31 +08:00
parent f74f804a71
commit d07ad5cc1c
32 changed files with 752 additions and 316 deletions

View File

@@ -2,7 +2,7 @@ import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Sequence, Set, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@@ -72,7 +72,7 @@ def default_tool_extractor(content: str) -> Union[str, Tuple[str, str]]:
@dataclass
class Formatter(ABC):
slots: SLOTS = field(default_factory=list)
tool_format: Literal["default"] = "default"
tool_format: Optional[Literal["default"]] = None
@abstractmethod
def apply(self, **kwargs) -> SLOTS: ...
@@ -83,12 +83,30 @@ class Formatter(ABC):
@dataclass
class EmptyFormatter(Formatter):
def __post_init__(self):
has_placeholder = False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
has_placeholder = True
if has_placeholder:
raise ValueError("Empty formatter should not contain any placeholder.")
def apply(self, **kwargs) -> SLOTS:
return self.slots
@dataclass
class StringFormatter(Formatter):
def __post_init__(self):
has_placeholder = False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
has_placeholder = True
if not has_placeholder:
raise ValueError("A placeholder is required in the string formatter.")
def apply(self, **kwargs) -> SLOTS:
elements = []
for slot in self.slots:
@@ -109,6 +127,17 @@ class StringFormatter(Formatter):
@dataclass
class FunctionFormatter(Formatter):
def __post_init__(self):
has_name, has_args = False, False
for slot in filter(lambda s: isinstance(s, str), self.slots):
if "{{name}}" in slot:
has_name = True
if "{{arguments}}" in slot:
has_args = True
if not has_name or not has_args:
raise ValueError("Name and arguments placeholders are required in the function formatter.")
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try:
@@ -133,6 +162,10 @@ class FunctionFormatter(Formatter):
@dataclass
class ToolFormatter(Formatter):
def __post_init__(self):
if self.tool_format is None:
raise ValueError("Tool format was not found.")
def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content")
try: