mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 18:02:19 +08:00
[misc] fix script (#6977)

Former-commit-id: 775efa1d8cbdb1b7d122be2a986d47f85214e0a1
parent f5cd17881e
commit be33ef67fb
@@ -46,13 +46,14 @@ def vllm_infer(
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
     pipeline_parallel_size: int = 1,
-    image_resolution: int = 512 * 512,
+    image_max_pixels: int = 768 * 768,
+    image_min_pixels: int = 32 * 32,
 ):
     r"""
     Performs batch generation using vLLM engine, which supports tensor parallelism.
     Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
     """
-    check_version("vllm>=0.4.3,<=0.6.5")
+    check_version("vllm>=0.4.3,<=0.7.2")
     if pipeline_parallel_size > get_device_count():
         raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
 
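The first hunk swaps the single image_resolution cap (an upper bound only) for a max/min pair and widens the supported vLLM range to 0.7.2. The implementation of _regularize_images is not part of this diff; the sketch below only illustrates the kind of aspect-preserving clamp that a max/min pixel budget implies. It is a minimal sketch, not the project's code, and the clamp_pixels helper name is hypothetical.

import math

from PIL import Image


def clamp_pixels(image: Image.Image, image_max_pixels: int = 768 * 768, image_min_pixels: int = 32 * 32) -> Image.Image:
    """Rescale so width * height lands inside [image_min_pixels, image_max_pixels], keeping aspect ratio."""
    pixels = image.width * image.height
    if pixels > image_max_pixels:
        scale = math.sqrt(image_max_pixels / pixels)  # shrink toward the upper bound
    elif pixels < image_min_pixels:
        scale = math.sqrt(image_min_pixels / pixels)  # enlarge toward the lower bound
    else:
        return image  # already within budget
    return image.resize((max(1, int(image.width * scale)), max(1, int(image.height * scale))))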
@@ -86,7 +87,9 @@ def vllm_infer(
     for sample in dataset_module["train_dataset"]:
         if sample["images"]:
             multi_modal_data = {
-                "image": template_obj.mm_plugin._regularize_images(sample["images"], image_resolution=image_resolution)
+                "image": template_obj.mm_plugin._regularize_images(
+                    sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                )
             }
         else:
             multi_modal_data = None
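At the call site the two bounds are threaded through as keyword arguments in place of the old image_resolution. Since they are ordinary parameters of vllm_infer, they should be settable from the command line the same way as the flags already shown in the docstring's usage line, assuming the script exposes its function signature as CLI flags (the pixel values below are illustrative):

python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama \
    --dataset alpaca_en_demo --image_max_pixels 1048576 --image_min_pixels 1024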
@@ -21,18 +21,13 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import get_device_count
-from ..extras.packages import is_pillow_available, is_vllm_available
+from ..extras.packages import is_vllm_available
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.quantization import QuantizationMethod
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response
 
 
-if is_pillow_available():
-    from PIL import Image
-    from PIL.Image import Image as ImageObject
-
-
 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest
@@ -54,6 +49,7 @@ class VllmEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        self.model_args = model_args
         config = load_config(model_args)  # may download model from ms hub
         if getattr(config, "quantization_config", None):  # gptq models should use float16
             quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
@@ -180,15 +176,13 @@ class VllmEngine(BaseEngine):
         )
 
         if images is not None:  # add image features
-            multi_modal_data = {"image": []}
-            for image in images:
-                if not isinstance(image, (str, ImageObject)):
-                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
-
-                if isinstance(image, str):
-                    image = Image.open(image).convert("RGB")
-
-                multi_modal_data["image"].append(image)
+            multi_modal_data = {
+                "image": self.template.mm_plugin._regularize_images(
+                    images,
+                    image_max_pixels=self.model_args.image_max_pixels,
+                    image_min_pixels=self.model_args.image_min_pixels,
+                )
+            }
         else:
             multi_modal_data = None
 
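Combined with the `self.model_args = model_args` line added in the previous hunk, this change deletes the engine's hand-rolled loop (type check, Image.open, RGB conversion) and routes image normalization through the template's mm_plugin, the same helper the batch script now calls, so both entry points share one preprocessing path. A rough sketch of what that shared helper has to cover, assuming it accepts both file paths and PIL images (simplified, hypothetical signature; clamp_pixels is the earlier sketch):

from typing import List, Union

from PIL import Image


def regularize_images_sketch(
    images: List[Union[str, Image.Image]], image_max_pixels: int, image_min_pixels: int
) -> List[Image.Image]:
    """Approximation of the shared path: load paths, force RGB, clamp the pixel budget."""
    results = []
    for image in images:
        if isinstance(image, str):  # the old engine loop opened paths inline
            image = Image.open(image)
        if not isinstance(image, Image.Image):
            raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
        results.append(clamp_pixels(image.convert("RGB"), image_max_pixels, image_min_pixels))
    return results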
@@ -1112,9 +1112,13 @@ class Qwen2vlPlugin(BasePlugin):
         self._validate_input(images, videos, audios)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         merge_length: int = getattr(image_processor, "merge_size") ** 2
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-        image_grid_thw = mm_inputs.get("image_grid_thw", [])
-        video_grid_thw = mm_inputs.get("video_grid_thw", [])
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_thw = mm_inputs.get("image_grid_thw", [])
+            video_grid_thw = mm_inputs.get("video_grid_thw", [])
+        else:
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)
 
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
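In the Qwen2-VL plugin, the grid computation now runs only when expand_mm_tokens is set; otherwise the grids are padded with None so per-image iteration stays aligned without invoking the processor. merge_length = merge_size ** 2 is the number of vision patches folded into one placeholder token, so an image with processor grid (t, h, w) expands to t * h * w // merge_length placeholders. A worked example with the usual Qwen2-VL merge_size of 2 (the grid values are illustrative):

# One image whose processor reports a patch grid of (t, h, w) = (1, 32, 24).
merge_size = 2
merge_length = merge_size ** 2  # 4 patches merge into one placeholder token
image_grid_thw = [(1, 32, 24)]

for t, h, w in image_grid_thw:
    num_tokens = (t * h * w) // merge_length
    print(num_tokens)  # 768 patches // 4 = 192 image placeholder tokens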