Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-23 22:32:54 +08:00
fix mm inference
Former-commit-id: 60fc6b926ead923dbb487b595ed4aa4cbfd94805
This commit is contained in:
  parent f13e974930
  commit b2a5f49a24
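This change moves multimodal inference from numpy arrays to PIL images: the image parameter of ChatModel, HuggingfaceEngine, VllmEngine, the OpenAI-style API handler, and the web UI now carries a PIL.Image.Image instead of an NDArray. A minimal sketch of the resulting calling convention; the top-level package name, model path, and template are illustrative assumptions, not part of this commit:

    from PIL import Image

    from llmtuner.chat import ChatModel

    # hypothetical model path and template, for illustration only
    chat_model = ChatModel({"model_name_or_path": "path/to/mm-model", "template": "vicuna"})
    image = Image.open("example.png").convert("RGB")  # a PIL image; no np.array() wrapper
    messages = [{"role": "user", "content": "Describe this picture."}]
    for response in chat_model.chat(messages, image=image):
        print(response.response_text)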
@@ -20,8 +20,6 @@ import re
 import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
 
-import numpy as np
-
 from ..data import Role as DataRole
 from ..extras.logging import get_logger
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
@@ -114,7 +112,7 @@ def _process_request(
                     else:  # web uri
                         image_stream = requests.get(image_url, stream=True).raw
 
-                    image = np.array(Image.open(image_stream).convert("RGB"))
+                    image = Image.open(image_stream).convert("RGB")
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
 
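The visible fix in _process_request is dropping the np.array(...) wrapper, so the decoded upload stays a PIL image. A self-contained sketch of that decoding path; only the web-URI branch appears verbatim above, and the data-URI branch plus the helper name are hedged reconstructions:

    import base64
    import io

    import requests
    from PIL import Image

    def load_image(image_url: str) -> "Image.Image":
        # hypothetical helper; mirrors the branch structure of _process_request
        if image_url.startswith("data:image"):  # base64 data URI (assumed branch)
            _, _, encoded = image_url.partition(",")
            image_stream = io.BytesIO(base64.b64decode(encoded))
        else:  # web uri
            image_stream = requests.get(image_url, stream=True).raw
        return Image.open(image_stream).convert("RGB")  # no np.array() anymore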
@@ -27,7 +27,7 @@ from .vllm_engine import VllmEngine
 
 
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
+    from PIL.Image import Image
 
     from .base_engine import BaseEngine, Response
 
@@ -56,7 +56,7 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
@@ -67,7 +67,7 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         return await self.engine.chat(messages, system, tools, image, **input_kwargs)
@@ -77,7 +77,7 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
         generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
@@ -93,7 +93,7 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
 
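Only the image annotation changes in ChatModel, but the context lines show the sync-over-async bridge these wrappers depend on: chat() submits achat() to a background event loop and blocks on the returned future. A standalone sketch of that pattern; the loop and thread setup are assumed, only the run_coroutine_threadsafe call is visible above:

    import asyncio
    import threading

    # background loop, analogous to ChatModel._loop
    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def achat() -> str:
        return "hello"

    task = asyncio.run_coroutine_threadsafe(achat(), loop)
    print(task.result())  # blocks the caller until the coroutine finishes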
@@ -30,7 +30,7 @@ from .base_engine import BaseEngine, Response
 
 
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
+    from PIL.Image import Image
     from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
     from trl import PreTrainedModelWrapper
 
@@ -78,7 +78,7 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
         if image is not None:
@@ -175,7 +175,7 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
@@ -210,7 +210,7 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
@@ -267,7 +267,7 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         if not self.can_generate:
@@ -295,7 +295,7 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:
 
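One pattern left untouched in the HuggingfaceEngine hunks is worth flagging: input_kwargs: Optional[Dict[str, Any]] = {} is a mutable default, which Python evaluates once and shares across calls; it is harmless only as long as the dict is never mutated. A minimal illustration of the pitfall:

    def append_call(history=[]):  # one shared list across all calls
        history.append(len(history))
        return history

    print(append_call())  # [0]
    print(append_call())  # [0, 1]  <- the same list object is reused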
@@ -39,7 +39,7 @@ if is_vllm_available():
 
 
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
+    from PIL.Image import Image
     from transformers.image_processing_utils import BaseImageProcessor
 
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -111,7 +111,7 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
@@ -195,7 +195,7 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         final_output = None
@@ -221,7 +221,7 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["Image"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""
 
@@ -14,9 +14,7 @@
 
 import json
 import os
-from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
 
-from numpy.typing import NDArray
-
 from ..chat import ChatModel
 from ..data import Role
@@ -134,7 +132,7 @@ class WebChatModel(ChatModel):
         messages: Sequence[Dict[str, str]],
         system: str,
         tools: str,
-        image: Optional[NDArray],
+        image: Optional[Any],
         max_new_tokens: int,
         top_p: float,
         temperature: float,
 
@@ -44,7 +44,7 @@ def create_chat_box(
         tools = gr.Textbox(show_label=False, lines=3)
 
         with gr.Column() as image_box:
-            image = gr.Image(sources=["upload"], type="numpy")
+            image = gr.Image(sources=["upload"], type="pil")
 
         query = gr.Textbox(show_label=False, lines=8)
         submit_btn = gr.Button(variant="primary")
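On the web UI side, gr.Image(type="pil") makes Gradio hand the callback a PIL.Image.Image (or None) directly, removing the numpy round-trip before the image reaches the engines. A hedged sketch of that behavior in isolation; the callback and layout are illustrative, not from this commit:

    import gradio as gr

    def describe(image):  # receives a PIL.Image.Image or None with type="pil"
        return "no image" if image is None else "{}x{} {}".format(image.width, image.height, image.mode)

    with gr.Blocks() as demo:
        image = gr.Image(sources=["upload"], type="pil")
        output = gr.Textbox()
        image.change(describe, inputs=image, outputs=output)

    demo.launch()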