support multiimage inference
@@ -66,8 +66,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
@@ -81,8 +81,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""

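Note: callers now pass a sequence of images or videos instead of a single object. A minimal usage sketch against the new signature (model name, template, placeholder literal, and file paths are hypothetical, not part of this commit):

    from llamafactory.chat import ChatModel

    # Hypothetical args; any multimodal template supported by the project works.
    chat_model = ChatModel({"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"})
    # "<image>" is assumed to be the project's IMAGE_PLACEHOLDER token.
    messages = [{"role": "user", "content": "<image><image>Compare these two photos."}]
    for response in chat_model.chat(messages, images=["photo1.jpg", "photo2.jpg"]):
        print(response.response_text)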
@@ -64,15 +64,15 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Gets a list of responses of the chat model.
         """
         task = asyncio.run_coroutine_threadsafe(
-            self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
+            self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
         )
         return task.result()

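Note: `ChatModel.chat` stays synchronous by scheduling the coroutine on a background event loop with `asyncio.run_coroutine_threadsafe` and blocking on `task.result()`. A self-contained sketch of that pattern (names are illustrative, not the project's code):

    import asyncio
    import threading

    class SyncFacade:
        def __init__(self):
            # Run a private event loop in a daemon thread, as ChatModel does.
            self._loop = asyncio.new_event_loop()
            threading.Thread(target=self._loop.run_forever, daemon=True).start()

        async def _awork(self, x: int) -> int:
            await asyncio.sleep(0.1)  # stand-in for async generation
            return x * 2

        def work(self, x: int) -> int:
            # Submit the coroutine to the background loop; block until it finishes.
            task = asyncio.run_coroutine_threadsafe(self._awork(x), self._loop)
            return task.result()

    print(SyncFacade().work(21))  # 42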
@@ -81,28 +81,28 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Asynchronously gets a list of responses of the chat model.
         """
-        return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
+        return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)

     def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
         r"""
         Gets the response token-by-token of the chat model.
         """
-        generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
+        generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
         while True:
             try:
                 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -115,14 +115,14 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""
         Asynchronously gets the response token-by-token of the chat model.
         """
-        async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
+        async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
             yield new_token

     def get_scores(

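Note: `stream_chat` drains the async generator one token at a time: each `__anext__()` call is submitted to the background loop, and the `StopAsyncIteration` raised when the generator is exhausted surfaces from `task.result()` to end the loop. A minimal sketch of the bridge (illustrative names):

    import asyncio
    import threading

    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def atokens():
        for tok in ["Hello", " ", "world"]:
            yield tok  # stand-in for streamed model tokens

    def tokens():
        agen = atokens()
        while True:
            task = asyncio.run_coroutine_threadsafe(agen.__anext__(), loop)
            try:
                yield task.result()  # re-raises StopAsyncIteration at end of stream
            except StopAsyncIteration:
                break

    print("".join(tokens()))  # Hello world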
@@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
         mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
-        if image is not None:
-            mm_input_dict.update({"images": [image], "imglens": [1]})
+        if images is not None:
+            mm_input_dict.update({"images": images, "imglens": [len(images)]})
             if IMAGE_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

-        if video is not None:
-            mm_input_dict.update({"videos": [video], "vidlens": [1]})
+        if videos is not None:
+            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
             if VIDEO_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
+                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

         messages = template.mm_plugin.process_messages(
             messages, mm_input_dict["images"], mm_input_dict["videos"], processor
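Note: for N inputs the prompt must carry N placeholders, so the hunk above prepends `IMAGE_PLACEHOLDER * len(images)` when the user supplied none. A quick illustration (the placeholder literal is assumed to be "<image>"):

    IMAGE_PLACEHOLDER = "<image>"  # assumed value of the project constant

    messages = [{"role": "user", "content": "Compare these photos."}]
    images = ["a.jpg", "b.jpg", "c.jpg"]
    if IMAGE_PLACEHOLDER not in messages[0]["content"]:
        messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
    print(messages[0]["content"])  # <image><image><image>Compare these photos.

Observe that the guard only fires when the prompt contains no placeholder at all; a prompt with fewer placeholders than images is passed through unchanged.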
@@ -186,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
+            model,
+            tokenizer,
+            processor,
+            template,
+            generating_args,
+            messages,
+            system,
+            tools,
+            images,
+            videos,
+            input_kwargs,
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
@@ -222,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
+            model,
+            tokenizer,
+            processor,
+            template,
+            generating_args,
+            messages,
+            system,
+            tools,
+            images,
+            videos,
+            input_kwargs,
         )
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
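Note: `TextIteratorStreamer` turns `model.generate` into an iterator: generation runs in a worker thread while decoded text chunks are consumed from the streamer. A generic sketch of that pattern with a hypothetical small model (not the engine's exact wiring):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical small model
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=20)

    Thread(target=model.generate, kwargs=gen_kwargs).start()
    for new_text in streamer:  # yields decoded chunks as they are generated
        print(new_text, end="", flush=True)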
@@ -270,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         if not self.can_generate:
@@ -287,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
             messages,
             system,
             tools,
-            image,
-            video,
+            images,
+            videos,
             input_kwargs,
         )
         async with self.semaphore:
@@ -301,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:
@@ -318,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
             messages,
             system,
             tools,
-            image,
-            video,
+            images,
+            videos,
             input_kwargs,
         )
         async with self.semaphore:

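Note: `async with self.semaphore:` caps how many generate calls run on the engine at once. A minimal sketch of the same throttling idea (illustrative names and limits):

    import asyncio

    semaphore = asyncio.Semaphore(2)  # at most two concurrent "generations"

    async def generate(i: int) -> str:
        async with semaphore:
            await asyncio.sleep(0.5)  # stand-in for model inference
            return f"response {i}"

    async def main():
        print(await asyncio.gather(*(generate(i) for i in range(4))))

    asyncio.run(main())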
@@ -101,14 +101,14 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = f"chatcmpl-{uuid.uuid4().hex}"
-        if image is not None:
+        if images is not None:
             if IMAGE_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
@@ -157,14 +157,18 @@ class VllmEngine(BaseEngine):
             skip_special_tokens=True,
         )

-        if image is not None:  # add image features
-            if not isinstance(image, (str, ImageObject)):
-                raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
+        if images is not None:  # add image features
+            image_data = []
+            for image in images:
+                if not isinstance(image, (str, ImageObject)):
+                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")

-            if isinstance(image, str):
-                image = Image.open(image).convert("RGB")
+                if isinstance(image, str):
+                    image = Image.open(image).convert("RGB")

-            multi_modal_data = {"image": image}
+                image_data.append(image)
+
+            multi_modal_data = {"image": image_data}
         else:
             multi_modal_data = None

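Note: the vLLM path now collects every input into one list, loading string paths into RGB `PIL.Image` objects and shipping them as `multi_modal_data={"image": [...]}`. A standalone sketch of the normalization step (assumes Pillow; the accepted shape of `multi_modal_data` may vary across vLLM versions):

    from PIL import Image
    from PIL.Image import Image as ImageObject

    def normalize_images(images):
        """Load paths and pass through PIL images, mirroring the new loop."""
        image_data = []
        for image in images:
            if not isinstance(image, (str, ImageObject)):
                raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
            if isinstance(image, str):
                image = Image.open(image).convert("RGB")
            image_data.append(image)
        return {"image": image_data}

    # multi_modal_data = normalize_images(["a.jpg", "b.jpg"])  # hypothetical paths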
@@ -182,12 +186,12 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         final_output = None
-        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
+        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
         async for request_output in generator:
             final_output = request_output

@@ -210,12 +214,12 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""
-        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
+        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
         async for result in generator:
             delta_text = result.outputs[0].text[len(generated_text) :]
             generated_text = result.outputs[0].text
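Note: vLLM's `result.outputs[0].text` is cumulative, so the streamer slices off the previously emitted prefix to yield only the new delta. The arithmetic in isolation, with mock cumulative outputs:

    generated_text = ""
    for full_text in ["Hel", "Hello", "Hello, wor", "Hello, world!"]:
        delta_text = full_text[len(generated_text):]  # only the newly generated suffix
        generated_text = full_text
        print(delta_text, end="|")  # Hel|lo|, wor|ld!|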