From b2a5f49a24f8a1e96dc383280595ecf687f154d0 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Mon, 2 Sep 2024 01:47:40 +0800 Subject: [PATCH] fix mm inference Former-commit-id: 60fc6b926ead923dbb487b595ed4aa4cbfd94805 --- src/llamafactory/api/chat.py | 4 +--- src/llamafactory/chat/chat_model.py | 10 +++++----- src/llamafactory/chat/hf_engine.py | 12 ++++++------ src/llamafactory/chat/vllm_engine.py | 8 ++++---- src/llamafactory/webui/chatter.py | 6 ++---- src/llamafactory/webui/components/chatbot.py | 2 +- 6 files changed, 19 insertions(+), 23 deletions(-) diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py index da055ac9..ee5c7412 100644 --- a/src/llamafactory/api/chat.py +++ b/src/llamafactory/api/chat.py @@ -20,8 +20,6 @@ import re import uuid from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple -import numpy as np - from ..data import Role as DataRole from ..extras.logging import get_logger from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available @@ -114,7 +112,7 @@ def _process_request( else: # web uri image_stream = requests.get(image_url, stream=True).raw - image = np.array(Image.open(image_stream).convert("RGB")) + image = Image.open(image_stream).convert("RGB") else: input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content}) diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py index 3ea3b44f..6df83b57 100644 --- a/src/llamafactory/chat/chat_model.py +++ b/src/llamafactory/chat/chat_model.py @@ -27,7 +27,7 @@ from .vllm_engine import VllmEngine if TYPE_CHECKING: - from numpy.typing import NDArray + from PIL.Image import Image from .base_engine import BaseEngine, Response @@ -56,7 +56,7 @@ class ChatModel: messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["NDArray"] = None, + image: Optional["Image"] = None, **input_kwargs, ) -> List["Response"]: task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop) @@ -67,7 +67,7 @@ class ChatModel: messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["NDArray"] = None, + image: Optional["Image"] = None, **input_kwargs, ) -> List["Response"]: return await self.engine.chat(messages, system, tools, image, **input_kwargs) @@ -77,7 +77,7 @@ class ChatModel: messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["NDArray"] = None, + image: Optional["Image"] = None, **input_kwargs, ) -> Generator[str, None, None]: generator = self.astream_chat(messages, system, tools, image, **input_kwargs) @@ -93,7 +93,7 @@ class ChatModel: messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["NDArray"] = None, + image: Optional["Image"] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs): diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 7be6c2ef..42f8f200 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -30,7 +30,7 @@ from .base_engine import BaseEngine, Response if TYPE_CHECKING: - from numpy.typing import NDArray + from PIL.Image import Image from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin from trl import PreTrainedModelWrapper @@ -78,7 +78,7 @@ class HuggingfaceEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["NDArray"] = None, + image: Optional["Image"] = None, input_kwargs: Optional[Dict[str, Any]] = {}, ) -> Tuple[Dict[str, Any], int]: if image is not None: @@ -175,7 +175,7 @@ class HuggingfaceEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["NDArray"] = None, + image: Optional["Image"] = None, input_kwargs: Optional[Dict[str, Any]] = {}, ) -> List["Response"]: gen_kwargs, prompt_length = HuggingfaceEngine._process_args( @@ -210,7 +210,7 @@ class HuggingfaceEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["NDArray"] = None, + image: Optional["Image"] = None, input_kwargs: Optional[Dict[str, Any]] = {}, ) -> Callable[[], str]: gen_kwargs, _ = HuggingfaceEngine._process_args( @@ -267,7 +267,7 @@ class HuggingfaceEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["NDArray"] = None, + image: Optional["Image"] = None, **input_kwargs, ) -> List["Response"]: if not self.can_generate: @@ -295,7 +295,7 @@ class HuggingfaceEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["NDArray"] = None, + image: Optional["Image"] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: if not self.can_generate: diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index 9220230f..ed4613d0 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -39,7 +39,7 @@ if is_vllm_available(): if TYPE_CHECKING: - from numpy.typing import NDArray + from PIL.Image import Image from transformers.image_processing_utils import BaseImageProcessor from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments @@ -111,7 +111,7 @@ class VllmEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["NDArray"] = None, + image: Optional["Image"] = None, **input_kwargs, ) -> AsyncIterator["RequestOutput"]: request_id = "chatcmpl-{}".format(uuid.uuid4().hex) @@ -195,7 +195,7 @@ class VllmEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["NDArray"] = None, + image: Optional["Image"] = None, **input_kwargs, ) -> List["Response"]: final_output = None @@ -221,7 +221,7 @@ class VllmEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["NDArray"] = None, + image: Optional["Image"] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: generated_text = "" diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py index 0b8267ed..d4366982 100644 --- a/src/llamafactory/webui/chatter.py +++ b/src/llamafactory/webui/chatter.py @@ -14,9 +14,7 @@ import json import os -from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple - -from numpy.typing import NDArray +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple from ..chat import ChatModel from ..data import Role @@ -134,7 +132,7 @@ class WebChatModel(ChatModel): messages: Sequence[Dict[str, str]], system: str, tools: str, - image: Optional[NDArray], + image: Optional[Any], max_new_tokens: int, top_p: float, temperature: float, diff --git a/src/llamafactory/webui/components/chatbot.py b/src/llamafactory/webui/components/chatbot.py index ad74114b..d7fc58f5 100644 --- a/src/llamafactory/webui/components/chatbot.py +++ b/src/llamafactory/webui/components/chatbot.py @@ -44,7 +44,7 @@ def create_chat_box( tools = gr.Textbox(show_label=False, lines=3) with gr.Column() as image_box: - image = gr.Image(sources=["upload"], type="numpy") + image = gr.Image(sources=["upload"], type="pil") query = gr.Textbox(show_label=False, lines=8) submit_btn = gr.Button(variant="primary")