fix mm inference

This commit is contained in:
hiyouga
2024-09-02 01:47:40 +08:00
parent 3a6f19f017
commit 60fc6b926e
6 changed files with 19 additions and 23 deletions

View File

@@ -30,7 +30,7 @@ from .base_engine import BaseEngine, Response
if TYPE_CHECKING:
from numpy.typing import NDArray
from PIL.Image import Image
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
from trl import PreTrainedModelWrapper
@@ -78,7 +78,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]:
if image is not None:
@@ -175,7 +175,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
@@ -210,7 +210,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args(
@@ -267,7 +267,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
**input_kwargs,
) -> List["Response"]:
if not self.can_generate:
@@ -295,7 +295,7 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
image: Optional["NDArray"] = None,
image: Optional["Image"] = None,
**input_kwargs,
) -> AsyncGenerator[str, None]:
if not self.can_generate: