support multi-image inference

This commit is contained in:
hiyouga
2024-11-01 07:25:20 +00:00
parent 641d0dab08
commit e80a481927
7 changed files with 103 additions and 63 deletions

View File

@@ -69,7 +69,7 @@ ROLE_MAPPING = {
def _process_request(
request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
logger.info(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
if len(request.messages) == 0:
@@ -84,7 +84,7 @@ def _process_request(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
input_messages = []
image = None
images = []
for i, message in enumerate(request.messages):
if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -111,10 +111,11 @@ def _process_request(
else: # web uri
image_stream = requests.get(image_url, stream=True).raw
image = Image.open(image_stream).convert("RGB")
images.append(Image.open(image_stream).convert("RGB"))
else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
images = None if len(images) == 0 else images
tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list):
try:
@@ -124,7 +125,7 @@ def _process_request(
else:
tools = None
return input_messages, system, tools, image
return input_messages, system, tools, images
def _create_stream_chat_completion_chunk(
@@ -143,12 +144,12 @@ async def create_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, image = _process_request(request)
input_messages, system, tools, images = _process_request(request)
responses = await chat_model.achat(
input_messages,
system,
tools,
image,
images,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,
@@ -194,7 +195,7 @@ async def create_stream_chat_completion_response(
request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
completion_id = f"chatcmpl-{uuid.uuid4().hex}"
input_messages, system, tools, image = _process_request(request)
input_messages, system, tools, images = _process_request(request)
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@@ -208,7 +209,7 @@ async def create_stream_chat_completion_response(
input_messages,
system,
tools,
image,
images,
do_sample=request.do_sample,
temperature=request.temperature,
top_p=request.top_p,