fix yi vl vllm infer

Former-commit-id: de54e5d7ec06dd7c20ec82c9ff032fc16cd50244
This commit is contained in:
hiyouga
2024-05-15 19:25:48 +08:00
parent 096677b989
commit 40e3d3fbdd
16 changed files with 113 additions and 24 deletions

View File

@@ -2,9 +2,11 @@ import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_device_count, infer_optim_dtype
from ..extras.packages import is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response
@@ -22,6 +24,9 @@ if TYPE_CHECKING:
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
class VllmEngine(BaseEngine):
def __init__(
self,
@@ -57,13 +62,19 @@ class VllmEngine(BaseEngine):
}
if model_args.visual_inputs:
# TODO: auto derive from config
# https://github.com/vllm-project/vllm/pull/3042#issuecomment-1984893549
self.image_feature_size = 576
image_size = config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.image_feature_size = (image_size // patch_size) ** 2
engine_args["image_input_type"] = "pixel_values"
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
engine_args["image_input_shape"] = "1,3,336,336"
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
engine_args["image_feature_size"] = self.image_feature_size
if getattr(config, "is_yi_vl_derived_model", None):
# bug in vllm 0.4.2, see: https://github.com/vllm-project/vllm/pull/4828
import vllm.model_executor.models.llava
logger.info("Detected Yi-VL model, applying projector patch.")
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
if model_args.adapter_name_or_path is not None: