diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py
index c794b7c7b..44bdbed83 100644
--- a/scripts/vllm_infer.py
+++ b/scripts/vllm_infer.py
@@ -154,25 +154,24 @@ def vllm_infer(

         batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
         for j in range(len(batch["input_ids"])):
+            multi_modal_data = {}
+            video_metadata_kwargs = None
+
             if batch["images"][j] is not None:
                 image = batch["images"][j]
-                multi_modal_data = {
-                    "image": template_obj.mm_plugin._regularize_images(
-                        image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-                    )["images"]
-                }
-            elif batch["videos"][j] is not None:
-                video_metadata, video_metadata_kwargs = None, None
+                multi_modal_data["image"] = template_obj.mm_plugin._regularize_images(
+                    image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                )["images"]
+
+            if batch["videos"][j] is not None:
                 video = batch["videos"][j]
-                multi_modal_data = {
-                    "video": template_obj.mm_plugin._regularize_videos(
-                        video,
-                        image_max_pixels=image_max_pixels,
-                        image_min_pixels=image_min_pixels,
-                        video_fps=video_fps,
-                        video_maxlen=video_maxlen,
-                    )["videos"]
-                }
+                multi_modal_data["video"] = template_obj.mm_plugin._regularize_videos(
+                    video,
+                    image_max_pixels=image_max_pixels,
+                    image_min_pixels=image_min_pixels,
+                    video_fps=video_fps,
+                    video_maxlen=video_maxlen,
+                )["videos"]
                 if need_video_kwargs:
                     container = av.open(video[0], "r")
                     video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -192,18 +191,17 @@ def vllm_infer(
                         video_backend="opencv",
                     )
                     multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
-            elif batch["audios"][j] is not None:
+
+            if batch["audios"][j] is not None:
                 audio = batch["audios"][j]
                 audio_data = template_obj.mm_plugin._regularize_audios(
                     audio,
                     sampling_rate=16000,
                 )
-                multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-            else:
-                multi_modal_data = None
+                multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])

-            vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
-            if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
+            vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data or None}
+            if video_metadata_kwargs is not None:
                 vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
             vllm_inputs.append(vllm_input_data)

diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 075924a2f..2ef72b5be 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -180,35 +180,32 @@ class VllmEngine(BaseEngine):
             else self.generating_args["skip_special_tokens"],
         )

+        multi_modal_data = {}
         if images is not None:  # add image features
-            multi_modal_data = {
-                "image": self.template.mm_plugin._regularize_images(
-                    images,
-                    image_max_pixels=self.model_args.image_max_pixels,
-                    image_min_pixels=self.model_args.image_min_pixels,
-                )["images"]
-            }
-        elif videos is not None:
-            multi_modal_data = {
-                "video": self.template.mm_plugin._regularize_videos(
-                    videos,
-                    image_max_pixels=self.model_args.video_max_pixels,
-                    image_min_pixels=self.model_args.video_min_pixels,
-                    video_fps=self.model_args.video_fps,
-                    video_maxlen=self.model_args.video_maxlen,
-                )["videos"]
-            }
-        elif audios is not None:
+            multi_modal_data["image"] = self.template.mm_plugin._regularize_images(
+                images,
+                image_max_pixels=self.model_args.image_max_pixels,
+                image_min_pixels=self.model_args.image_min_pixels,
+            )["images"]
+
+        if videos is not None:
+            multi_modal_data["video"] = self.template.mm_plugin._regularize_videos(
+                videos,
+                image_max_pixels=self.model_args.video_max_pixels,
+                image_min_pixels=self.model_args.video_min_pixels,
+                video_fps=self.model_args.video_fps,
+                video_maxlen=self.model_args.video_maxlen,
+            )["videos"]
+
+        if audios is not None:
             audio_data = self.template.mm_plugin._regularize_audios(
                 audios,
                 sampling_rate=self.model_args.audio_sampling_rate,
             )
-            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-        else:
-            multi_modal_data = None
+            multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])

         result_generator = self.model.generate(
-            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
+            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data or None},
             sampling_params=sampling_params,
             request_id=request_id,
             lora_request=self.lora_request,
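
Both files now follow the same pattern: each modality fills its own key in a single multi_modal_data dict instead of an exclusive if/elif/else chain, so one request can carry image, video, and audio features together, and "multi_modal_data or None" keeps the text-only path passing None as before. A minimal sketch of that pattern, with placeholder values standing in for the regularized features:

    # Sketch of the additive pattern introduced above; the values are
    # placeholders, not the real output of the _regularize_* helpers.
    prompt_ids = [1, 2, 3]
    images, audios = ["<image feature>"], None

    multi_modal_data = {}
    if images is not None:
        multi_modal_data["image"] = images  # keys accumulate rather than
    if audios is not None:                  # overwriting one another
        multi_modal_data["audio"] = audios

    # An empty dict is coerced to None, matching the old text-only behavior.
    vllm_input = {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data or None}
    print(vllm_input)

Initializing video_metadata_kwargs = None at the top of the loop also replaces the fragile '"video_metadata_kwargs" in locals()' check, which depended on whether an earlier iteration had happened to take the video branch.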