From 4ee9efbd989849fc63597a3a6d06eba6829a13a8 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Sat, 7 Sep 2024 01:21:14 +0800 Subject: [PATCH] fix #5384 Former-commit-id: 36665f3001647b8411ba0e256d1e64eb157abfaf --- src/llamafactory/chat/vllm_engine.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index 7d34965a..5ac26623 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -39,9 +39,9 @@ if is_vllm_available(): if TYPE_CHECKING: - from PIL.Image import Image from transformers.image_processing_utils import BaseImageProcessor + from ..data.mm_plugin import ImageInput, VideoInput from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments @@ -111,7 +111,8 @@ class VllmEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, **input_kwargs, ) -> AsyncIterator["RequestOutput"]: request_id = "chatcmpl-{}".format(uuid.uuid4().hex) @@ -195,11 +196,12 @@ class VllmEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, **input_kwargs, ) -> List["Response"]: final_output = None - generator = await self._generate(messages, system, tools, image, **input_kwargs) + generator = await self._generate(messages, system, tools, image, video, **input_kwargs) async for request_output in generator: final_output = request_output @@ -221,11 +223,12 @@ class VllmEngine(BaseEngine): messages: Sequence[Dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - image: Optional["Image"] = None, + image: Optional["ImageInput"] = None, + video: Optional["VideoInput"] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: generated_text = "" - generator = await self._generate(messages, system, tools, image, **input_kwargs) + generator = await self._generate(messages, system, tools, image, video, **input_kwargs) async for result in generator: delta_text = result.outputs[0].text[len(generated_text) :] generated_text = result.outputs[0].text