mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-06 19:56:01 +08:00
[infer] support mixed multimodal payloads (#10225)
Signed-off-by: Philip Ottesen <phiott256@gmail.com>
This commit is contained in:
@@ -180,35 +180,32 @@ class VllmEngine(BaseEngine):
|
||||
else self.generating_args["skip_special_tokens"],
|
||||
)
|
||||
|
||||
multi_modal_data = {}
|
||||
if images is not None: # add image features
|
||||
multi_modal_data = {
|
||||
"image": self.template.mm_plugin._regularize_images(
|
||||
images,
|
||||
image_max_pixels=self.model_args.image_max_pixels,
|
||||
image_min_pixels=self.model_args.image_min_pixels,
|
||||
)["images"]
|
||||
}
|
||||
elif videos is not None:
|
||||
multi_modal_data = {
|
||||
"video": self.template.mm_plugin._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=self.model_args.video_max_pixels,
|
||||
image_min_pixels=self.model_args.video_min_pixels,
|
||||
video_fps=self.model_args.video_fps,
|
||||
video_maxlen=self.model_args.video_maxlen,
|
||||
)["videos"]
|
||||
}
|
||||
elif audios is not None:
|
||||
multi_modal_data["image"] = self.template.mm_plugin._regularize_images(
|
||||
images,
|
||||
image_max_pixels=self.model_args.image_max_pixels,
|
||||
image_min_pixels=self.model_args.image_min_pixels,
|
||||
)["images"]
|
||||
|
||||
if videos is not None:
|
||||
multi_modal_data["video"] = self.template.mm_plugin._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=self.model_args.video_max_pixels,
|
||||
image_min_pixels=self.model_args.video_min_pixels,
|
||||
video_fps=self.model_args.video_fps,
|
||||
video_maxlen=self.model_args.video_maxlen,
|
||||
)["videos"]
|
||||
|
||||
if audios is not None:
|
||||
audio_data = self.template.mm_plugin._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=self.model_args.audio_sampling_rate,
|
||||
)
|
||||
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
|
||||
else:
|
||||
multi_modal_data = None
|
||||
multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
|
||||
|
||||
result_generator = self.model.generate(
|
||||
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
||||
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data or None},
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
lora_request=self.lora_request,
|
||||
|
||||
Reference in New Issue
Block a user