mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
fix mm inference
Former-commit-id: 60fc6b926ead923dbb487b595ed4aa4cbfd94805
This commit is contained in:
parent
f13e974930
commit
b2a5f49a24
@ -20,8 +20,6 @@ import re
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..data import Role as DataRole
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
||||
@ -114,7 +112,7 @@ def _process_request(
|
||||
else: # web uri
|
||||
image_stream = requests.get(image_url, stream=True).raw
|
||||
|
||||
image = np.array(Image.open(image_stream).convert("RGB"))
|
||||
image = Image.open(image_stream).convert("RGB")
|
||||
else:
|
||||
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
|
||||
|
||||
|
@ -27,7 +27,7 @@ from .vllm_engine import VllmEngine
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from PIL.Image import Image
|
||||
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
||||
@ -56,7 +56,7 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
image: Optional["Image"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
|
||||
@ -67,7 +67,7 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
image: Optional["Image"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
return await self.engine.chat(messages, system, tools, image, **input_kwargs)
|
||||
@ -77,7 +77,7 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
image: Optional["Image"] = None,
|
||||
**input_kwargs,
|
||||
) -> Generator[str, None, None]:
|
||||
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
|
||||
@ -93,7 +93,7 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
image: Optional["Image"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
|
||||
|
@ -30,7 +30,7 @@ from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from PIL.Image import Image
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
|
||||
from trl import PreTrainedModelWrapper
|
||||
|
||||
@ -78,7 +78,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
image: Optional["Image"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
if image is not None:
|
||||
@ -175,7 +175,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
image: Optional["Image"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> List["Response"]:
|
||||
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
||||
@ -210,7 +210,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
image: Optional["Image"] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Callable[[], str]:
|
||||
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
||||
@ -267,7 +267,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
image: Optional["Image"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
if not self.can_generate:
|
||||
@ -295,7 +295,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
image: Optional["Image"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if not self.can_generate:
|
||||
|
@ -39,7 +39,7 @@ if is_vllm_available():
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
from PIL.Image import Image
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
@ -111,7 +111,7 @@ class VllmEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
image: Optional["Image"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncIterator["RequestOutput"]:
|
||||
request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
|
||||
@ -195,7 +195,7 @@ class VllmEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
image: Optional["Image"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
final_output = None
|
||||
@ -221,7 +221,7 @@ class VllmEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
image: Optional["Image"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
generated_text = ""
|
||||
|
@ -14,9 +14,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..data import Role
|
||||
@ -134,7 +132,7 @@ class WebChatModel(ChatModel):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: str,
|
||||
tools: str,
|
||||
image: Optional[NDArray],
|
||||
image: Optional[Any],
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float,
|
||||
|
@ -44,7 +44,7 @@ def create_chat_box(
|
||||
tools = gr.Textbox(show_label=False, lines=3)
|
||||
|
||||
with gr.Column() as image_box:
|
||||
image = gr.Image(sources=["upload"], type="numpy")
|
||||
image = gr.Image(sources=["upload"], type="pil")
|
||||
|
||||
query = gr.Textbox(show_label=False, lines=8)
|
||||
submit_btn = gr.Button(variant="primary")
|
||||
|
Loading…
x
Reference in New Issue
Block a user