diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py
index 2550f5ba..4cdc290c 100644
--- a/scripts/vllm_infer.py
+++ b/scripts/vllm_infer.py
@@ -46,13 +46,14 @@ def vllm_infer(
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
     pipeline_parallel_size: int = 1,
-    image_resolution: int = 512 * 512,
+    image_max_pixels: int = 768 * 768,
+    image_min_pixels: int = 32 * 32,
 ):
     r"""
     Performs batch generation using vLLM engine, which supports tensor parallelism.
     Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
     """
-    check_version("vllm>=0.4.3,<=0.6.5")
+    check_version("vllm>=0.4.3,<=0.7.2")
     if pipeline_parallel_size > get_device_count():
         raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")

@@ -86,7 +87,9 @@ def vllm_infer(
     for sample in dataset_module["train_dataset"]:
         if sample["images"]:
             multi_modal_data = {
-                "image": template_obj.mm_plugin._regularize_images(sample["images"], image_resolution=image_resolution)
+                "image": template_obj.mm_plugin._regularize_images(
+                    sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                )
             }
         else:
             multi_modal_data = None
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 8627ab96..894d3260 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -21,18 +21,13 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import get_device_count
-from ..extras.packages import is_pillow_available, is_vllm_available
+from ..extras.packages import is_vllm_available
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.quantization import QuantizationMethod
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response


-if is_pillow_available():
-    from PIL import Image
-    from PIL.Image import Image as ImageObject
-
-
 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest
@@ -54,6 +49,7 @@ class VllmEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        self.model_args = model_args
         config = load_config(model_args)  # may download model from ms hub
         if getattr(config, "quantization_config", None):  # gptq models should use float16
             quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
@@ -180,15 +176,13 @@ class VllmEngine(BaseEngine):
         )

         if images is not None:  # add image features
-            multi_modal_data = {"image": []}
-            for image in images:
-                if not isinstance(image, (str, ImageObject)):
-                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
-
-                if isinstance(image, str):
-                    image = Image.open(image).convert("RGB")
-
-                multi_modal_data["image"].append(image)
+            multi_modal_data = {
+                "image": self.template.mm_plugin._regularize_images(
+                    images,
+                    image_max_pixels=self.model_args.image_max_pixels,
+                    image_min_pixels=self.model_args.image_min_pixels,
+                )
+            }
         else:
             multi_modal_data = None
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 1ace77f8..8a0fdab1 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -1112,9 +1112,13 @@ class Qwen2vlPlugin(BasePlugin):
         self._validate_input(images, videos, audios)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         merge_length: int = getattr(image_processor, "merge_size") ** 2
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-        image_grid_thw = mm_inputs.get("image_grid_thw", [])
-        video_grid_thw = mm_inputs.get("video_grid_thw", [])
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_thw = mm_inputs.get("image_grid_thw", [])
+            video_grid_thw = mm_inputs.get("video_grid_thw", [])
+        else:
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)

         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
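
A minimal usage sketch (not part of the diff) of the new call convention: _regularize_images now takes separate image_max_pixels / image_min_pixels budgets in place of the old single image_resolution argument. The image path below is a placeholder, and the pixel values simply mirror the new defaults from scripts/vllm_infer.py; template_obj is assumed to be built the same way as in vllm_infer() above.

    # illustrative caller only; mirrors the refactored code path shown in the diff
    images = template_obj.mm_plugin._regularize_images(
        ["/path/to/example.jpg"],        # placeholder input; paths or already-loaded images, as in the engine code above
        image_max_pixels=768 * 768,      # new default upper bound (replaces image_resolution=512 * 512)
        image_min_pixels=32 * 32,        # new default lower bound
    )
    multi_modal_data = {"image": images}  # passed to vLLM as the multi-modal input dict

The same two budgets are read from model_args inside VllmEngine (hence the added self.model_args assignment), which removes the manual PIL loading loop and the is_pillow_available() guard.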