Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)
[misc] fix script (#6977)

parent 1f4a0b11ba
commit 184c5d0882
Former-commit-id: cc8c7e762b9c873ef79529152465bbed9231053c
scripts/vllm_infer.py

@@ -46,13 +46,14 @@ def vllm_infer(
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
     pipeline_parallel_size: int = 1,
-    image_resolution: int = 512 * 512,
+    image_max_pixels: int = 768 * 768,
+    image_min_pixels: int = 32 * 32,
 ):
     r"""
     Performs batch generation using vLLM engine, which supports tensor parallelism.
     Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
     """
-    check_version("vllm>=0.4.3,<=0.6.5")
+    check_version("vllm>=0.4.3,<=0.7.2")
     if pipeline_parallel_size > get_device_count():
         raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")

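For context, a hedged sketch of how a caller picks up the renamed options. The import path is hypothetical; the function and argument names come from the diff above, and the values shown are the new defaults:

    # Hedged sketch: calling the updated entry point directly. The import
    # path is hypothetical; adjust it to wherever the script is importable.
    from vllm_infer import vllm_infer

    vllm_infer(
        model_name_or_path="meta-llama/Llama-2-7b-hf",
        template="llama",
        dataset="alpaca_en_demo",
        image_max_pixels=768 * 768,  # replaces the single image_resolution=512*512 budget
        image_min_pixels=32 * 32,    # new lower bound on image area
    )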
@@ -86,7 +87,9 @@ def vllm_infer(
     for sample in dataset_module["train_dataset"]:
         if sample["images"]:
             multi_modal_data = {
-                "image": template_obj.mm_plugin._regularize_images(sample["images"], image_resolution=image_resolution)
+                "image": template_obj.mm_plugin._regularize_images(
+                    sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                )
             }
         else:
             multi_modal_data = None
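The two budgets bound each image's area in pixels rather than fixing one target resolution. A minimal, self-contained sketch of the assumed scaling math (the real _regularize_images may differ in details such as rounding):

    import math

    # Assumed semantics: scale an image's area into [min_pixels, max_pixels]
    # while preserving aspect ratio; images already inside the band are untouched.
    def target_size(width: int, height: int, max_pixels: int, min_pixels: int) -> tuple:
        pixels = width * height
        if pixels > max_pixels:
            scale = math.sqrt(max_pixels / pixels)
        elif pixels < min_pixels:
            scale = math.sqrt(min_pixels / pixels)
        else:
            return width, height
        return max(int(width * scale), 1), max(int(height * scale), 1)

    # A 4000x3000 photo exceeds the default 768*768 = 589824 budget,
    # so it is scaled down until width * height fits under it.
    print(target_size(4000, 3000, 768 * 768, 32 * 32))  # (886, 665)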
src/llamafactory/chat/vllm_engine.py

@@ -21,18 +21,13 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import get_device_count
-from ..extras.packages import is_pillow_available, is_vllm_available
+from ..extras.packages import is_vllm_available
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.quantization import QuantizationMethod
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response


-if is_pillow_available():
-    from PIL import Image
-    from PIL.Image import Image as ImageObject
-
-
 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest
@@ -54,6 +49,7 @@ class VllmEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        self.model_args = model_args
         config = load_config(model_args)  # may download model from ms hub
         if getattr(config, "quantization_config", None):  # gptq models should use float16
             quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
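Why the one added line matters: the engine previously only read model_args inside __init__; storing it on the instance lets later calls (the image regularization in the next hunk) reach image_max_pixels and image_min_pixels. A minimal sketch of the pattern, with hypothetical names:

    class Engine:
        def __init__(self, model_args) -> None:
            self.model_args = model_args  # keep for use outside __init__

        def preprocess(self, images):
            # accessible here only because __init__ stored model_args above
            budget = self.model_args.image_max_pixels
            ...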
@@ -180,15 +176,13 @@ class VllmEngine(BaseEngine):
         )

         if images is not None:  # add image features
-            multi_modal_data = {"image": []}
-            for image in images:
-                if not isinstance(image, (str, ImageObject)):
-                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
-
-                if isinstance(image, str):
-                    image = Image.open(image).convert("RGB")
-
-                multi_modal_data["image"].append(image)
+            multi_modal_data = {
+                "image": self.template.mm_plugin._regularize_images(
+                    images,
+                    image_max_pixels=self.model_args.image_max_pixels,
+                    image_min_pixels=self.model_args.image_min_pixels,
+                )
+            }
         else:
             multi_modal_data = None

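This removes the engine-side PIL loop (and explains the PIL imports dropped above): image loading and resizing now live in one place, the template's mm_plugin, shared with the script change earlier. A hedged, self-contained sketch of what the consolidated helper plausibly does; the real _regularize_images may differ:

    from PIL import Image
    from PIL.Image import Image as ImageObject

    # Hedged sketch, not the actual mm_plugin._regularize_images: accept paths
    # or PIL images (as the deleted loop did), convert to RGB, then clamp each
    # image's area into [image_min_pixels, image_max_pixels] by resizing.
    def regularize_images(images, image_max_pixels: int, image_min_pixels: int):
        results = []
        for image in images:
            if isinstance(image, str):
                image = Image.open(image)
            elif not isinstance(image, ImageObject):
                raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")

            image = image.convert("RGB")
            pixels = image.width * image.height
            if pixels > image_max_pixels or pixels < image_min_pixels:
                bound = image_max_pixels if pixels > image_max_pixels else image_min_pixels
                factor = (bound / pixels) ** 0.5
                image = image.resize((int(image.width * factor), int(image.height * factor)))

            results.append(image)

        return results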
|
@ -1112,9 +1112,13 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
self._validate_input(images, videos, audios)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||
else:
|
||||
image_grid_thw = [None] * len(images)
|
||||
video_grid_thw = [None] * len(videos)
|
||||
|
||||
num_image_tokens, num_video_tokens = 0, 0
|
||||
messages = deepcopy(messages)
|
||||
|
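For context, what the new guard feeds: downstream, each image placeholder expands into pad tokens proportional to its vision-patch grid, so the else branch's [None] * len(images) keeps per-image alignment when expansion is disabled instead of computing features it won't use. A simplified, hedged sketch (token strings follow Qwen2-VL conventions; the plugin's actual code may differ):

    import math

    # Hedged sketch: with expansion enabled, one <|image_pad|> per merged patch
    # (grid t*h*w divided by merge_size**2); with it disabled, grid_thw is None
    # and a single pad token stands in for the whole image.
    def expand_image_placeholder(content: str, grid_thw, merge_length: int) -> str:
        seqlen = math.prod(grid_thw) // merge_length if grid_thw is not None else 1
        return content.replace(
            "<image>", "<|vision_start|>" + "<|image_pad|>" * seqlen + "<|vision_end|>", 1
        )

    # a 1x16x16 grid with merge_size=2 (merge_length=4) yields 64 pad tokens
    print(expand_image_placeholder("caption: <image>", (1, 16, 16), 4))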