mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
Support multi-image inference
This commit is contained in:
@@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
|
||||
if image is not None:
|
||||
mm_input_dict.update({"images": [image], "imglens": [1]})
|
||||
if images is not None:
|
||||
mm_input_dict.update({"images": images, "imglens": [len(images)]})
|
||||
if IMAGE_PLACEHOLDER not in messages[0]["content"]:
|
||||
messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
|
||||
messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
|
||||
|
||||
if video is not None:
|
||||
mm_input_dict.update({"videos": [video], "vidlens": [1]})
|
||||
if videos is not None:
|
||||
mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
|
||||
if VIDEO_PLACEHOLDER not in messages[0]["content"]:
|
||||
messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
|
||||
messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
|
||||
|
||||
messages = template.mm_plugin.process_messages(
|
||||
messages, mm_input_dict["images"], mm_input_dict["videos"], processor
|
||||
@@ -186,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> List["Response"]:
|
||||
gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
|
||||
model,
|
||||
tokenizer,
|
||||
processor,
|
||||
template,
|
||||
generating_args,
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
images,
|
||||
videos,
|
||||
input_kwargs,
|
||||
)
|
||||
generate_output = model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
@@ -222,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
input_kwargs: Optional[Dict[str, Any]] = {},
|
||||
) -> Callable[[], str]:
|
||||
gen_kwargs, _ = HuggingfaceEngine._process_args(
|
||||
model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
|
||||
model,
|
||||
tokenizer,
|
||||
processor,
|
||||
template,
|
||||
generating_args,
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
images,
|
||||
videos,
|
||||
input_kwargs,
|
||||
)
|
||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
@@ -270,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
if not self.can_generate:
|
||||
@@ -287,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
video,
|
||||
images,
|
||||
videos,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self.semaphore:
|
||||
@@ -301,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages: Sequence[Dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
images: Optional[Sequence["ImageInput"]] = None,
|
||||
videos: Optional[Sequence["VideoInput"]] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if not self.can_generate:
|
||||
@@ -318,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
image,
|
||||
video,
|
||||
images,
|
||||
videos,
|
||||
input_kwargs,
|
||||
)
|
||||
async with self.semaphore:
|
||||
|
||||
Reference in New Issue
Block a user