mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00

[misc] fix script (#6977)

Former-commit-id: cc8c7e762b9c873ef79529152465bbed9231053c
parent 1f4a0b11ba · commit 184c5d0882

@@ -46,13 +46,14 @@ def vllm_infer(
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
     pipeline_parallel_size: int = 1,
-    image_resolution: int = 512 * 512,
+    image_max_pixels: int = 768 * 768,
+    image_min_pixels: int = 32 * 32,
 ):
     r"""
     Performs batch generation using vLLM engine, which supports tensor parallelism.
     Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
     """
-    check_version("vllm>=0.4.3,<=0.6.5")
+    check_version("vllm>=0.4.3,<=0.7.2")
     if pipeline_parallel_size > get_device_count():
         raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")

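The hunk above swaps the single image_resolution knob for an explicit pixel budget (image_max_pixels caps the image area, image_min_pixels sets a floor) and raises the supported vLLM ceiling to 0.7.2. A minimal sketch of exercising the new signature directly, using only names and defaults visible in the diff and its docstring (normally the script is driven from the command line):

    # Sketch only: calling the updated vllm_infer() signature in-process.
    # The pixel defaults mirror the ones introduced by this commit.
    from vllm_infer import vllm_infer  # assumes the script is importable as a module

    vllm_infer(
        model_name_or_path="meta-llama/Llama-2-7b-hf",
        template="llama",
        dataset="alpaca_en_demo",
        image_max_pixels=768 * 768,  # upper bound on image area (width * height)
        image_min_pixels=32 * 32,    # lower bound on image area
    )
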
@@ -86,7 +87,9 @@ def vllm_infer(
     for sample in dataset_module["train_dataset"]:
         if sample["images"]:
             multi_modal_data = {
-                "image": template_obj.mm_plugin._regularize_images(sample["images"], image_resolution=image_resolution)
+                "image": template_obj.mm_plugin._regularize_images(
+                    sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                )
             }
         else:
             multi_modal_data = None

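The batch script now forwards both budgets to mm_plugin._regularize_images instead of a single image_resolution. The resizing itself is not part of this diff; the sketch below only illustrates the pixel-budget idea, assuming a plain area clamp (the real helper may also round sizes to patch-size multiples):

    # Illustrative only, not the library's implementation of _regularize_images:
    # keep each image's area within [min_pixels, max_pixels] while preserving
    # its aspect ratio.
    import math

    from PIL import Image

    def clamp_image_area(image: "Image.Image", max_pixels: int, min_pixels: int) -> "Image.Image":
        width, height = image.size
        area = width * height
        if area > max_pixels:  # too large: shrink so width * height <= max_pixels
            scale = math.sqrt(max_pixels / area)
            image = image.resize((int(width * scale), int(height * scale)))
        elif area < min_pixels:  # too small: enlarge so width * height >= min_pixels
            scale = math.sqrt(min_pixels / area)
            image = image.resize((math.ceil(width * scale), math.ceil(height * scale)))
        return image
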
@@ -21,18 +21,13 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import get_device_count
-from ..extras.packages import is_pillow_available, is_vllm_available
+from ..extras.packages import is_vllm_available
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.quantization import QuantizationMethod
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response


-if is_pillow_available():
-    from PIL import Image
-    from PIL.Image import Image as ImageObject
-
-
 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest

@@ -54,6 +49,7 @@ class VllmEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        self.model_args = model_args
         config = load_config(model_args)  # may download model from ms hub
         if getattr(config, "quantization_config", None):  # gptq models should use float16
             quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)

@@ -180,15 +176,13 @@ class VllmEngine(BaseEngine):
         )

         if images is not None:  # add image features
-            multi_modal_data = {"image": []}
-            for image in images:
-                if not isinstance(image, (str, ImageObject)):
-                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
-
-                if isinstance(image, str):
-                    image = Image.open(image).convert("RGB")
-
-                multi_modal_data["image"].append(image)
+            multi_modal_data = {
+                "image": self.template.mm_plugin._regularize_images(
+                    images,
+                    image_max_pixels=self.model_args.image_max_pixels,
+                    image_min_pixels=self.model_args.image_min_pixels,
+                )
+            }
         else:
             multi_modal_data = None

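VllmEngine no longer needs Pillow directly: opening image paths and resizing move behind the template's mm_plugin, driven by the image_max_pixels / image_min_pixels values stored on self.model_args in the previous hunk. For reference, the removed engine-side logic as a standalone helper, showing the two input forms (path string or PIL.Image) that the plugin-side regularization is now expected to cover:

    # Standalone equivalent of the logic removed from VllmEngine in this hunk.
    from typing import List, Union

    from PIL import Image
    from PIL.Image import Image as ImageObject

    def load_images(images: List[Union[str, ImageObject]]) -> List[ImageObject]:
        loaded = []
        for image in images:
            if not isinstance(image, (str, ImageObject)):
                raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
            if isinstance(image, str):
                image = Image.open(image).convert("RGB")  # paths were opened in the engine itself
            loaded.append(image)
        return loaded
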
@@ -1112,9 +1112,13 @@ class Qwen2vlPlugin(BasePlugin):
         self._validate_input(images, videos, audios)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         merge_length: int = getattr(image_processor, "merge_size") ** 2
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-        image_grid_thw = mm_inputs.get("image_grid_thw", [])
-        video_grid_thw = mm_inputs.get("video_grid_thw", [])
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_thw = mm_inputs.get("image_grid_thw", [])
+            video_grid_thw = mm_inputs.get("video_grid_thw", [])
+        else:
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)

         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)