Support multi-image inference

Former-commit-id: e80a4819274d46ac9e85db7469dc59d7c4e323c7
This commit is contained in:
hiyouga 2024-11-01 07:25:20 +00:00
parent 9108df2b97
commit 8ecc12ee2a
7 changed files with 103 additions and 63 deletions

View File

@ -69,7 +69,7 @@ ROLE_MAPPING = {
def _process_request( def _process_request(
request: "ChatCompletionRequest", request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]: ) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
logger.info(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}") logger.info(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
if len(request.messages) == 0: if len(request.messages) == 0:
@ -84,7 +84,7 @@ def _process_request(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = [] input_messages = []
image = None images = []
for i, message in enumerate(request.messages): for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]: if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@ -111,10 +111,11 @@ def _process_request(
else: # web uri else: # web uri
image_stream = requests.get(image_url, stream=True).raw image_stream = requests.get(image_url, stream=True).raw
image = Image.open(image_stream).convert("RGB") images.append(Image.open(image_stream).convert("RGB"))
else: else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content}) input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
images = None if len(images) == 0 else images
tool_list = request.tools tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list): if isinstance(tool_list, list) and len(tool_list):
try: try:
@ -124,7 +125,7 @@ def _process_request(
else: else:
tools = None tools = None
return input_messages, system, tools, image return input_messages, system, tools, images
def _create_stream_chat_completion_chunk( def _create_stream_chat_completion_chunk(
@ -143,12 +144,12 @@ async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel" request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse": ) -> "ChatCompletionResponse":
completion_id = f"chatcmpl-{uuid.uuid4().hex}" completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, image = _process_request(request) input_messages, system, tools, images = _process_request(request)
responses = await chat_model.achat( responses = await chat_model.achat(
input_messages, input_messages,
system, system,
tools, tools,
image, images,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
@ -194,7 +195,7 @@ async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel" request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
completion_id = f"chatcmpl-{uuid.uuid4().hex}" completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, image = _process_request(request) input_messages, system, tools, images = _process_request(request)
if tools: if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@ -208,7 +209,7 @@ async def create_stream_chat_completion_response(
input_messages, input_messages,
system, system,
tools, tools,
image, images,
do_sample=request.do_sample, do_sample=request.do_sample,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,

View File

@ -66,8 +66,8 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
r""" r"""
@ -81,8 +81,8 @@ class BaseEngine(ABC):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
r""" r"""

View File

@ -64,15 +64,15 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
r""" r"""
Gets a list of responses of the chat model. Gets a list of responses of the chat model.
""" """
task = asyncio.run_coroutine_threadsafe( task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
) )
return task.result() return task.result()
@ -81,28 +81,28 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
r""" r"""
Asynchronously gets a list of responses of the chat model. Asynchronously gets a list of responses of the chat model.
""" """
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs) return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
def stream_chat( def stream_chat(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
r""" r"""
Gets the response token-by-token of the chat model. Gets the response token-by-token of the chat model.
""" """
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs) generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
while True: while True:
try: try:
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop) task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@ -115,14 +115,14 @@ class ChatModel:
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
r""" r"""
Asynchronously gets the response token-by-token of the chat model. Asynchronously gets the response token-by-token of the chat model.
""" """
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs): async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
yield new_token yield new_token
def get_scores( def get_scores(

View File

@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> Tuple[Dict[str, Any], int]: ) -> Tuple[Dict[str, Any], int]:
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]} mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
if image is not None: if images is not None:
mm_input_dict.update({"images": [image], "imglens": [1]}) mm_input_dict.update({"images": images, "imglens": [len(images)]})
if IMAGE_PLACEHOLDER not in messages[0]["content"]: if IMAGE_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"] messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
if video is not None: if videos is not None:
mm_input_dict.update({"videos": [video], "vidlens": [1]}) mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
if VIDEO_PLACEHOLDER not in messages[0]["content"]: if VIDEO_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"] messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
messages = template.mm_plugin.process_messages( messages = template.mm_plugin.process_messages(
messages, mm_input_dict["images"], mm_input_dict["videos"], processor messages, mm_input_dict["images"], mm_input_dict["videos"], processor
@ -186,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> List["Response"]: ) -> List["Response"]:
gen_kwargs, prompt_length = HuggingfaceEngine._process_args( gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs model,
tokenizer,
processor,
template,
generating_args,
messages,
system,
tools,
images,
videos,
input_kwargs,
) )
generate_output = model.generate(**gen_kwargs) generate_output = model.generate(**gen_kwargs)
response_ids = generate_output[:, prompt_length:] response_ids = generate_output[:, prompt_length:]
@ -222,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
input_kwargs: Optional[Dict[str, Any]] = {}, input_kwargs: Optional[Dict[str, Any]] = {},
) -> Callable[[], str]: ) -> Callable[[], str]:
gen_kwargs, _ = HuggingfaceEngine._process_args( gen_kwargs, _ = HuggingfaceEngine._process_args(
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs model,
tokenizer,
processor,
template,
generating_args,
messages,
system,
tools,
images,
videos,
input_kwargs,
) )
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer gen_kwargs["streamer"] = streamer
@ -270,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
if not self.can_generate: if not self.can_generate:
@ -287,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
messages, messages,
system, system,
tools, tools,
image, images,
video, videos,
input_kwargs, input_kwargs,
) )
async with self.semaphore: async with self.semaphore:
@ -301,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
if not self.can_generate: if not self.can_generate:
@ -318,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
messages, messages,
system, system,
tools, tools,
image, images,
video, videos,
input_kwargs, input_kwargs,
) )
async with self.semaphore: async with self.semaphore:

View File

@ -101,14 +101,14 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncIterator["RequestOutput"]: ) -> AsyncIterator["RequestOutput"]:
request_id = f"chatcmpl-{uuid.uuid4().hex}" request_id = f"chatcmpl-{uuid.uuid4().hex}"
if image is not None: if images is not None:
if IMAGE_PLACEHOLDER not in messages[0]["content"]: if IMAGE_PLACEHOLDER not in messages[0]["content"]:
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"] messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
paired_messages = messages + [{"role": "assistant", "content": ""}] paired_messages = messages + [{"role": "assistant", "content": ""}]
system = system or self.generating_args["default_system"] system = system or self.generating_args["default_system"]
@ -157,14 +157,18 @@ class VllmEngine(BaseEngine):
skip_special_tokens=True, skip_special_tokens=True,
) )
if image is not None: # add image features if images is not None: # add image features
if not isinstance(image, (str, ImageObject)): image_data = []
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.") for image in images:
if not isinstance(image, (str, ImageObject)):
raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
if isinstance(image, str): if isinstance(image, str):
image = Image.open(image).convert("RGB") image = Image.open(image).convert("RGB")
multi_modal_data = {"image": image} image_data.append(image)
multi_modal_data = {"image": image_data}
else: else:
multi_modal_data = None multi_modal_data = None
@ -182,12 +186,12 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
final_output = None final_output = None
generator = await self._generate(messages, system, tools, image, video, **input_kwargs) generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
async for request_output in generator: async for request_output in generator:
final_output = request_output final_output = request_output
@ -210,12 +214,12 @@ class VllmEngine(BaseEngine):
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
image: Optional["ImageInput"] = None, images: Optional[Sequence["ImageInput"]] = None,
video: Optional["VideoInput"] = None, videos: Optional[Sequence["VideoInput"]] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
generated_text = "" generated_text = ""
generator = await self._generate(messages, system, tools, image, video, **input_kwargs) generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
async for result in generator: async for result in generator:
delta_text = result.outputs[0].text[len(generated_text) :] delta_text = result.outputs[0].text[len(generated_text) :]
generated_text = result.outputs[0].text generated_text = result.outputs[0].text

View File

@ -226,6 +226,14 @@ class BasePlugin:
) -> Dict[str, Union[List[int], "torch.Tensor"]]: ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r""" r"""
Builds batched multimodal inputs for VLMs. Builds batched multimodal inputs for VLMs.
Arguments:
images: a list of image inputs, shape (num_images,)
videos: a list of video inputs, shape (num_videos,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
seqlens: number of tokens in each sample, shape (batch_size,)
processor: a processor for pre-processing images and videos
""" """
self._validate_input(images, videos) self._validate_input(images, videos)
return {} return {}

View File

@ -141,7 +141,14 @@ class WebChatModel(ChatModel):
chatbot[-1][1] = "" chatbot[-1][1] = ""
response = "" response = ""
for new_text in self.stream_chat( for new_text in self.stream_chat(
messages, system, tools, image, video, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature messages,
system,
tools,
images=[image],
videos=[video],
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
): ):
response += new_text response += new_text
if tools: if tools: