fix mm inference

Former-commit-id: 60fc6b926ead923dbb487b595ed4aa4cbfd94805

Author: hiyouga
Date: 2024-09-02 01:47:40 +08:00
Commit: b2a5f49a24 (parent: f13e974930)
6 changed files with 19 additions and 23 deletions
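Summary: multi-modal inference previously shuttled images around as NumPy arrays; this commit passes PIL images end to end instead. The OpenAI-style API stops wrapping decoded images in np.array, the `image` parameter of the chat engines is re-typed from NDArray to PIL.Image, and the Gradio web UI uploads images with type="pil" rather than type="numpy".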

src/llamafactory/api/chat.py

@@ -20,8 +20,6 @@ import re
 import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
 
-import numpy as np
-
 from ..data import Role as DataRole
 from ..extras.logging import get_logger
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
@@ -114,7 +112,7 @@ def _process_request(
             else:  # web uri
                 image_stream = requests.get(image_url, stream=True).raw
 
-            image = np.array(Image.open(image_stream).convert("RGB"))
+            image = Image.open(image_stream).convert("RGB")
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})

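Note: _process_request now hands the engine a PIL.Image.Image instead of an np.ndarray. A minimal sketch of the new web-URI branch (the URL is a placeholder; Pillow and requests assumed available):

    import requests
    from PIL import Image

    # Stream the remote file and decode it directly into a PIL image;
    # the former np.array(...) round-trip is gone.
    image_stream = requests.get("https://example.com/sample.png", stream=True).raw
    image = Image.open(image_stream).convert("RGB")
    print(type(image))  # <class 'PIL.Image.Image'>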
src/llamafactory/chat/chat_model.py

@@ -27,7 +27,7 @@ from .vllm_engine import VllmEngine
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
+    from PIL.Image import Image
 
     from .base_engine import BaseEngine, Response
@@ -56,7 +56,7 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
@@ -67,7 +67,7 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         return await self.engine.chat(messages, system, tools, image, **input_kwargs)
@@ -77,7 +77,7 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
         generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
@@ -93,7 +93,7 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):

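Every public ChatModel entry point (chat, achat, stream_chat, astream_chat) now takes an optional PIL image. A hedged usage sketch (the model path, template, and message content are placeholders, not part of this commit):

    from PIL import Image

    from llamafactory.chat import ChatModel

    chat_model = ChatModel({"model_name_or_path": "path/to/a-vision-language-model", "template": "default"})
    image = Image.open("example.jpg").convert("RGB")
    responses = chat_model.chat(
        messages=[{"role": "user", "content": "Describe this image."}],
        image=image,  # a PIL.Image.Image, no longer an NDArray
    )
    print(responses[0].response_text)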
src/llamafactory/chat/hf_engine.py

@@ -30,7 +30,7 @@ from .base_engine import BaseEngine, Response
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
+    from PIL.Image import Image
     from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
     from trl import PreTrainedModelWrapper
@@ -78,7 +78,7 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
         if image is not None:
@@ -175,7 +175,7 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
@@ -210,7 +210,7 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
@@ -267,7 +267,7 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         if not self.can_generate:
@@ -295,7 +295,7 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:

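Within the Hugging Face engine, the PIL image can be fed straight into a multi-modal processor, since transformers image processors accept PIL input natively. A sketch of that pattern with a LLaVA-style checkpoint (the model name and prompt format are examples only):

    from PIL import Image
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # example checkpoint
    image = Image.open("example.jpg").convert("RGB")
    # No NumPy conversion needed: the processor consumes the PIL image as-is.
    inputs = processor(images=image, text="USER: <image>\nWhat is shown? ASSISTANT:", return_tensors="pt")
    print(inputs["pixel_values"].shape)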
src/llamafactory/chat/vllm_engine.py

@@ -39,7 +39,7 @@ if is_vllm_available():
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
+    from PIL.Image import Image
     from transformers.image_processing_utils import BaseImageProcessor
 
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -111,7 +111,7 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
@@ -195,7 +195,7 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         final_output = None
@@ -221,7 +221,7 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""

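The vLLM engine forwards the PIL image as well. For reference, recent vLLM releases accept a PIL image through multi_modal_data; a hedged sketch (model name and prompt format are placeholders and vary by vLLM version):

    from PIL import Image
    from vllm import LLM, SamplingParams

    llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # example multi-modal checkpoint
    image = Image.open("example.jpg").convert("RGB")
    outputs = llm.generate(
        {"prompt": "USER: <image>\nWhat is shown? ASSISTANT:", "multi_modal_data": {"image": image}},
        SamplingParams(max_tokens=64),
    )
    print(outputs[0].outputs[0].text)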
src/llamafactory/webui/chatter.py

@@ -14,9 +14,7 @@
 import json
 import os
-from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
-
-from numpy.typing import NDArray
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
 
 from ..chat import ChatModel
 from ..data import Role
@@ -134,7 +132,7 @@ class WebChatModel(ChatModel):
         messages: Sequence[Dict[str, str]],
         system: str,
         tools: str,
-        image: Optional[NDArray],
+        image: Optional[Any],
         max_new_tokens: int,
         top_p: float,
         temperature: float,

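Note: the web UI types the image as Optional[Any] rather than Optional["Image"], presumably so this module avoids a PIL import; the object it actually receives from the gr.Image(type="pil") component below is still a PIL.Image.Image.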
src/llamafactory/webui/components/chatbot.py

@@ -44,7 +44,7 @@ def create_chat_box(
             tools = gr.Textbox(show_label=False, lines=3)
 
             with gr.Column() as image_box:
-                image = gr.Image(sources=["upload"], type="numpy")
+                image = gr.Image(sources=["upload"], type="pil")
 
             query = gr.Textbox(show_label=False, lines=8)
             submit_btn = gr.Button(variant="primary")
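With type="pil", Gradio delivers uploads as PIL images, matching the engines' new signatures. A standalone sketch of the component behavior (a hypothetical demo, not part of this repo):

    import gradio as gr

    def describe(img):
        # With type="pil", img arrives as a PIL.Image.Image (or None if nothing was uploaded).
        if img is None:
            return "no image"
        return f"{img.size[0]}x{img.size[1]}, mode={img.mode}"

    demo = gr.Interface(fn=describe, inputs=gr.Image(sources=["upload"], type="pil"), outputs="text")
    demo.launch()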