mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-11 14:36:00 +08:00
[infer] support mixed multimodal payloads (#10225)
Signed-off-by: Philip Ottesen <phiott256@gmail.com>
This commit is contained in:
@@ -154,25 +154,24 @@ def vllm_infer(
|
|||||||
batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
|
batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
|
||||||
|
|
||||||
for j in range(len(batch["input_ids"])):
|
for j in range(len(batch["input_ids"])):
|
||||||
|
multi_modal_data = {}
|
||||||
|
video_metadata_kwargs = None
|
||||||
|
|
||||||
if batch["images"][j] is not None:
|
if batch["images"][j] is not None:
|
||||||
image = batch["images"][j]
|
image = batch["images"][j]
|
||||||
multi_modal_data = {
|
multi_modal_data["image"] = template_obj.mm_plugin._regularize_images(
|
||||||
"image": template_obj.mm_plugin._regularize_images(
|
image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
|
||||||
image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
|
)["images"]
|
||||||
)["images"]
|
|
||||||
}
|
if batch["videos"][j] is not None:
|
||||||
elif batch["videos"][j] is not None:
|
|
||||||
video_metadata, video_metadata_kwargs = None, None
|
|
||||||
video = batch["videos"][j]
|
video = batch["videos"][j]
|
||||||
multi_modal_data = {
|
multi_modal_data["video"] = template_obj.mm_plugin._regularize_videos(
|
||||||
"video": template_obj.mm_plugin._regularize_videos(
|
video,
|
||||||
video,
|
image_max_pixels=image_max_pixels,
|
||||||
image_max_pixels=image_max_pixels,
|
image_min_pixels=image_min_pixels,
|
||||||
image_min_pixels=image_min_pixels,
|
video_fps=video_fps,
|
||||||
video_fps=video_fps,
|
video_maxlen=video_maxlen,
|
||||||
video_maxlen=video_maxlen,
|
)["videos"]
|
||||||
)["videos"]
|
|
||||||
}
|
|
||||||
if need_video_kwargs:
|
if need_video_kwargs:
|
||||||
container = av.open(video[0], "r")
|
container = av.open(video[0], "r")
|
||||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
@@ -192,18 +191,17 @@ def vllm_infer(
|
|||||||
video_backend="opencv",
|
video_backend="opencv",
|
||||||
)
|
)
|
||||||
multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
|
multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
|
||||||
elif batch["audios"][j] is not None:
|
|
||||||
|
if batch["audios"][j] is not None:
|
||||||
audio = batch["audios"][j]
|
audio = batch["audios"][j]
|
||||||
audio_data = template_obj.mm_plugin._regularize_audios(
|
audio_data = template_obj.mm_plugin._regularize_audios(
|
||||||
audio,
|
audio,
|
||||||
sampling_rate=16000,
|
sampling_rate=16000,
|
||||||
)
|
)
|
||||||
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
|
multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
|
||||||
else:
|
|
||||||
multi_modal_data = None
|
|
||||||
|
|
||||||
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
|
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data or None}
|
||||||
if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
|
if video_metadata_kwargs is not None:
|
||||||
vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
|
vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
|
||||||
|
|
||||||
vllm_inputs.append(vllm_input_data)
|
vllm_inputs.append(vllm_input_data)
|
||||||
|
|||||||
@@ -180,35 +180,32 @@ class VllmEngine(BaseEngine):
|
|||||||
else self.generating_args["skip_special_tokens"],
|
else self.generating_args["skip_special_tokens"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
multi_modal_data = {}
|
||||||
if images is not None: # add image features
|
if images is not None: # add image features
|
||||||
multi_modal_data = {
|
multi_modal_data["image"] = self.template.mm_plugin._regularize_images(
|
||||||
"image": self.template.mm_plugin._regularize_images(
|
images,
|
||||||
images,
|
image_max_pixels=self.model_args.image_max_pixels,
|
||||||
image_max_pixels=self.model_args.image_max_pixels,
|
image_min_pixels=self.model_args.image_min_pixels,
|
||||||
image_min_pixels=self.model_args.image_min_pixels,
|
)["images"]
|
||||||
)["images"]
|
|
||||||
}
|
if videos is not None:
|
||||||
elif videos is not None:
|
multi_modal_data["video"] = self.template.mm_plugin._regularize_videos(
|
||||||
multi_modal_data = {
|
videos,
|
||||||
"video": self.template.mm_plugin._regularize_videos(
|
image_max_pixels=self.model_args.video_max_pixels,
|
||||||
videos,
|
image_min_pixels=self.model_args.video_min_pixels,
|
||||||
image_max_pixels=self.model_args.video_max_pixels,
|
video_fps=self.model_args.video_fps,
|
||||||
image_min_pixels=self.model_args.video_min_pixels,
|
video_maxlen=self.model_args.video_maxlen,
|
||||||
video_fps=self.model_args.video_fps,
|
)["videos"]
|
||||||
video_maxlen=self.model_args.video_maxlen,
|
|
||||||
)["videos"]
|
if audios is not None:
|
||||||
}
|
|
||||||
elif audios is not None:
|
|
||||||
audio_data = self.template.mm_plugin._regularize_audios(
|
audio_data = self.template.mm_plugin._regularize_audios(
|
||||||
audios,
|
audios,
|
||||||
sampling_rate=self.model_args.audio_sampling_rate,
|
sampling_rate=self.model_args.audio_sampling_rate,
|
||||||
)
|
)
|
||||||
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
|
multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
|
||||||
else:
|
|
||||||
multi_modal_data = None
|
|
||||||
|
|
||||||
result_generator = self.model.generate(
|
result_generator = self.model.generate(
|
||||||
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data or None},
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=self.lora_request,
|
lora_request=self.lora_request,
|
||||||
|
|||||||
Reference in New Issue
Block a user