mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-28 01:30:36 +08:00
fix yi vl vllm infer
Former-commit-id: de54e5d7ec06dd7c20ec82c9ff032fc16cd50244
This commit is contained in:
@@ -2,9 +2,11 @@ import uuid
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import get_device_count, infer_optim_dtype
|
||||
from ..extras.packages import is_vllm_available
|
||||
from ..model import load_config, load_tokenizer
|
||||
from ..model.utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
|
||||
from .base_engine import BaseEngine, Response
|
||||
|
||||
|
||||
@@ -22,6 +24,9 @@ if TYPE_CHECKING:
|
||||
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class VllmEngine(BaseEngine):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -57,13 +62,19 @@ class VllmEngine(BaseEngine):
|
||||
}
|
||||
|
||||
if model_args.visual_inputs:
|
||||
# TODO: auto derive from config
|
||||
# https://github.com/vllm-project/vllm/pull/3042#issuecomment-1984893549
|
||||
self.image_feature_size = 576
|
||||
image_size = config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
self.image_feature_size = (image_size // patch_size) ** 2
|
||||
engine_args["image_input_type"] = "pixel_values"
|
||||
engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
|
||||
engine_args["image_input_shape"] = "1,3,336,336"
|
||||
engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
|
||||
engine_args["image_feature_size"] = self.image_feature_size
|
||||
if getattr(config, "is_yi_vl_derived_model", None):
|
||||
# bug in vllm 0.4.2, see: https://github.com/vllm-project/vllm/pull/4828
|
||||
import vllm.model_executor.models.llava
|
||||
|
||||
logger.info("Detected Yi-VL model, applying projector patch.")
|
||||
vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
|
||||
|
||||
self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
|
||||
if model_args.adapter_name_or_path is not None:
|
||||
|
||||
Reference in New Issue
Block a user