fix mm inference

Author: hiyouga
Date: 2024-09-02 01:47:40 +08:00
parent 3a6f19f017
commit 60fc6b926e
6 changed files with 19 additions and 23 deletions

@@ -39,7 +39,7 @@ if is_vllm_available():
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
+    from PIL.Image import Image
     from transformers.image_processing_utils import BaseImageProcessor
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -111,7 +111,7 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
@@ -195,7 +195,7 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         final_output = None
@@ -221,7 +221,7 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""
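
For reference, after this change callers pass a PIL.Image.Image (e.g. from Image.open(...)) instead of a NumPy array to these methods. Below is a minimal sketch, assuming the high-level ChatModel wrapper forwards the image keyword down to the VllmEngine methods shown above; the model path and template name are placeholders, not values from this commit.

```python
# A minimal sketch (not part of this commit): multimodal inference after the
# image argument changed from an NDArray to a PIL Image.
from PIL import Image

from llamafactory.chat import ChatModel

# Placeholder model path and template; infer_backend selects VllmEngine.
chat_model = ChatModel(dict(
    model_name_or_path="path/to/multimodal-model",
    template="qwen2_vl",
    infer_backend="vllm",
))

image = Image.open("example.jpg").convert("RGB")  # PIL Image, not NDArray
messages = [{"role": "user", "content": "<image>Describe this picture."}]

for response in chat_model.chat(messages, image=image):
    print(response.response_text)
```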