From 8ecc12ee2af7c79b5dbe6f43d1ffc05fa900c8f3 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 1 Nov 2024 07:25:20 +0000
Subject: [PATCH] support multiimage inference

Former-commit-id: e80a4819274d46ac9e85db7469dc59d7c4e323c7
---
 src/llamafactory/api/chat.py         | 17 ++++----
 src/llamafactory/chat/base_engine.py |  8 ++--
 src/llamafactory/chat/chat_model.py  | 24 +++++------
 src/llamafactory/chat/hf_engine.py   | 64 ++++++++++++++++++----------
 src/llamafactory/chat/vllm_engine.py | 36 +++++++++-------
 src/llamafactory/data/mm_plugin.py   |  8 ++++
 src/llamafactory/webui/chatter.py    |  9 +++-
 7 files changed, 103 insertions(+), 63 deletions(-)

diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py
index f20e588e..ec3201c3 100644
--- a/src/llamafactory/api/chat.py
+++ b/src/llamafactory/api/chat.py
@@ -69,7 +69,7 @@ ROLE_MAPPING = {
 
 def _process_request(
     request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
     logger.info(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
 
     if len(request.messages) == 0:
@@ -84,7 +84,7 @@ def _process_request(
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
     input_messages = []
-    image = None
+    images = []
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -111,10 +111,11 @@ def _process_request(
                     else:  # web uri
                         image_stream = requests.get(image_url, stream=True).raw
 
-                    image = Image.open(image_stream).convert("RGB")
+                    images.append(Image.open(image_stream).convert("RGB"))
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
 
+    images = None if len(images) == 0 else images
     tool_list = request.tools
     if isinstance(tool_list, list) and len(tool_list):
         try:
@@ -124,7 +125,7 @@ def _process_request(
     else:
         tools = None
 
-    return input_messages, system, tools, image
+    return input_messages, system, tools, images
 
 
 def _create_stream_chat_completion_chunk(
@@ -143,12 +144,12 @@ async def create_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
     completion_id = f"chatcmpl-{uuid.uuid4().hex}"
-    input_messages, system, tools, image = _process_request(request)
+    input_messages, system, tools, images = _process_request(request)
     responses = await chat_model.achat(
         input_messages,
         system,
         tools,
-        image,
+        images,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
@@ -194,7 +195,7 @@ async def create_stream_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
     completion_id = f"chatcmpl-{uuid.uuid4().hex}"
-    input_messages, system, tools, image = _process_request(request)
+    input_messages, system, tools, images = _process_request(request)
     if tools:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
 
@@ -208,7 +209,7 @@ async def create_stream_chat_completion_response(
         input_messages,
         system,
         tools,
-        image,
+        images,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
diff --git a/src/llamafactory/chat/base_engine.py b/src/llamafactory/chat/base_engine.py
index 7087c4e5..700e1eef 100644
--- a/src/llamafactory/chat/base_engine.py
+++ b/src/llamafactory/chat/base_engine.py
@@ -66,8 +66,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
@@ -81,8 +81,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""
diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py
index 4f0e1b83..28f5e439 100644
--- a/src/llamafactory/chat/chat_model.py
+++ b/src/llamafactory/chat/chat_model.py
@@ -64,15 +64,15 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Gets a list of responses of the chat model.
         """
         task = asyncio.run_coroutine_threadsafe(
-            self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
+            self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
         )
         return task.result()
 
@@ -81,28 +81,28 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Asynchronously gets a list of responses of the chat model.
         """
-        return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
+        return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
 
     def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
         r"""
         Gets the response token-by-token of the chat model.
         """
-        generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
+        generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
         while True:
             try:
                 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -115,14 +115,14 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""
         Asynchronously gets the response token-by-token of the chat model.
         """
-        async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
+        async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
             yield new_token
 
     def get_scores(
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 909f8161..5340587c 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
         mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
-        if image is not None:
-            mm_input_dict.update({"images": [image], "imglens": [1]})
+        if images is not None:
+            mm_input_dict.update({"images": images, "imglens": [len(images)]})
             if IMAGE_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
 
-        if video is not None:
-            mm_input_dict.update({"videos": [video], "vidlens": [1]})
+        if videos is not None:
+            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
             if VIDEO_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
+                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
 
         messages = template.mm_plugin.process_messages(
             messages, mm_input_dict["images"], mm_input_dict["videos"], processor
@@ -186,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
+            model,
+            tokenizer,
+            processor,
+            template,
+            generating_args,
+            messages,
+            system,
+            tools,
+            images,
+            videos,
+            input_kwargs,
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
@@ -222,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
+            model,
+            tokenizer,
+            processor,
+            template,
+            generating_args,
+            messages,
+            system,
+            tools,
+            images,
+            videos,
+            input_kwargs,
         )
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
@@ -270,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         if not self.can_generate:
@@ -287,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
             messages,
             system,
             tools,
-            image,
-            video,
+            images,
+            videos,
             input_kwargs,
         )
         async with self.semaphore:
@@ -301,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:
@@ -318,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
             messages,
             system,
             tools,
-            image,
-            video,
+            images,
+            videos,
             input_kwargs,
         )
         async with self.semaphore:
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 2b5a32b4..e228aba4 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -101,14 +101,14 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = f"chatcmpl-{uuid.uuid4().hex}"
-        if image is not None:
+        if images is not None:
             if IMAGE_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
@@ -157,14 +157,18 @@ class VllmEngine(BaseEngine):
             skip_special_tokens=True,
         )
 
-        if image is not None:  # add image features
-            if not isinstance(image, (str, ImageObject)):
-                raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
+        if images is not None:  # add image features
+            image_data = []
+            for image in images:
+                if not isinstance(image, (str, ImageObject)):
+                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
 
-            if isinstance(image, str):
-                image = Image.open(image).convert("RGB")
+                if isinstance(image, str):
+                    image = Image.open(image).convert("RGB")
 
-            multi_modal_data = {"image": image}
+                image_data.append(image)
+
+            multi_modal_data = {"image": image_data}
         else:
             multi_modal_data = None
 
@@ -182,12 +186,12 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         final_output = None
-        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
+        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
         async for request_output in generator:
             final_output = request_output
 
@@ -210,12 +214,12 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""
-        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
+        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
         async for result in generator:
             delta_text = result.outputs[0].text[len(generated_text) :]
             generated_text = result.outputs[0].text
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 4e096c83..f6748883 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -226,6 +226,14 @@ class BasePlugin:
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         r"""
         Builds batched multimodal inputs for VLMs.
+
+        Arguments:
+            images: a list of image inputs, shape (num_images,)
+            videos: a list of video inputs, shape (num_videos,)
+            imglens: number of images in each sample, shape (batch_size,)
+            vidlens: number of videos in each sample, shape (batch_size,)
+            seqlens: number of tokens in each sample, shape (batch_size,)
+            processor: a processor for pre-processing images and videos
         """
         self._validate_input(images, videos)
         return {}
diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py
index 7512887b..78ef3efc 100644
--- a/src/llamafactory/webui/chatter.py
+++ b/src/llamafactory/webui/chatter.py
@@ -141,7 +141,14 @@ class WebChatModel(ChatModel):
         chatbot[-1][1] = ""
         response = ""
         for new_text in self.stream_chat(
-            messages, system, tools, image, video, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            messages,
+            system,
+            tools,
+            images=[image],
+            videos=[video],
+            max_new_tokens=max_new_tokens,
+            top_p=top_p,
+            temperature=temperature,
        ):
             response += new_text
             if tools:
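
Usage sketch. With this patch, the chat engines take `images`/`videos` sequences instead of a single `image`/`video`, the Web UI wraps its single upload as `images=[image]`, and the OpenAI-compatible API collects every `image_url` part of a request into one list rather than keeping only the last image. A minimal way to exercise the new Python API might look like the sketch below; the model name, template, and image paths are illustrative assumptions, not values fixed by the patch.

# Illustrative sketch (not part of the patch): multi-image chat via the updated API.
# The model name, template, and file paths below are assumptions for demonstration.
from PIL import Image

from llamafactory.chat import ChatModel

chat_model = ChatModel(
    dict(
        model_name_or_path="Qwen/Qwen2-VL-7B-Instruct",  # any supported vision-language model
        template="qwen2_vl",
        infer_backend="huggingface",
    )
)

# Two images for a single user turn; per the patched hf_engine, one IMAGE_PLACEHOLDER
# per image is prepended automatically if the message does not already contain them.
images = [
    Image.open("example_1.jpg").convert("RGB"),
    Image.open("example_2.jpg").convert("RGB"),
]
messages = [{"role": "user", "content": "Compare these two pictures."}]

for response in chat_model.chat(messages, images=images):
    print(response.response_text)

On the HTTP side, a request whose user message carries several `image_url` content parts now reaches the engine with all of those images, since `_process_request` appends each decoded image to the returned list.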