mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 20:00:36 +08:00
support mllm hf inference
This commit is contained in:
@@ -8,6 +8,8 @@ from .vllm_engine import VllmEngine
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
@@ -36,9 +38,10 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
|
||||
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
|
||||
return task.result()
|
||||
|
||||
async def achat(
|
||||
@@ -46,18 +49,20 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
return await self.engine.chat(messages, system, tools, **input_kwargs)
|
||||
return await self.engine.chat(messages, system, tools, image, **input_kwargs)
|
||||
|
||||
def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> Generator[str, None, None]:
|
||||
generator = self.astream_chat(messages, system, tools, **input_kwargs)
|
||||
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
|
||||
while True:
|
||||
try:
|
||||
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
|
||||
@@ -70,9 +75,10 @@ class ChatModel:
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
async for new_token in self.engine.stream_chat(messages, system, tools, **input_kwargs):
|
||||
async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
|
||||
yield new_token
|
||||
|
||||
def get_scores(
|
||||
|
||||
Reference in New Issue
Block a user