From 78cf256067fe844083a81d310e70b7faf9831c03 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Sun, 8 Sep 2024 02:26:20 +0800
Subject: [PATCH] support vllm 0.6.0

Former-commit-id: b6681d7198acf4acbebfe271dd22095e236bc430
---
 setup.py                             |  2 +-
 src/llamafactory/chat/hf_engine.py   |  1 -
 src/llamafactory/chat/vllm_engine.py | 48 ++++++----------------------
 src/llamafactory/extras/packages.py  | 10 ------
 src/llamafactory/hparams/parser.py   |  2 +-
 5 files changed, 12 insertions(+), 51 deletions(-)

diff --git a/setup.py b/setup.py
index ef2b666d..e3184823 100644
--- a/setup.py
+++ b/setup.py
@@ -45,7 +45,7 @@ extra_require = {
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3"],
+    "vllm": ["vllm>=0.4.3,<=0.6.0"],
     "galore": ["galore-torch"],
     "badam": ["badam>=1.2.1"],
     "adam-mini": ["adam-mini"],
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index e0755ac2..8819dc79 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -103,7 +103,6 @@ class HuggingfaceEngine(BaseEngine):
         prompt_ids, _ = template.mm_plugin.process_token_ids(
             prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
         )
-
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)
         attention_mask = torch.ones_like(inputs, dtype=torch.bool)
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 0e8e8b1c..46bbac00 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -21,7 +21,7 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras.constants import IMAGE_PLACEHOLDER
 from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
-from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
+from ..extras.packages import is_vllm_available
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.quantization import QuantizationMethod
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
@@ -32,17 +32,8 @@ if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest

-    if is_vllm_version_greater_than_0_5_1():
-        pass
-    elif is_vllm_version_greater_than_0_5():
-        from vllm.multimodal.image import ImagePixelData
-    else:
-        from vllm.sequence import MultiModalData
-

 if TYPE_CHECKING:
-    from transformers.image_processing_utils import BaseImageProcessor
-
     from ..data.mm_plugin import ImageInput, VideoInput
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

@@ -88,19 +79,11 @@ class VllmEngine(BaseEngine):
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }

-        if getattr(config, "model_type", None) == "llava":
-            image_size = config.vision_config.image_size
-            patch_size = config.vision_config.patch_size
-            self.image_feature_size = (image_size // patch_size) ** 2
-            engine_args["image_input_type"] = "pixel_values"
-            engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
-            engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
-            engine_args["image_feature_size"] = self.image_feature_size
-            if getattr(config, "is_yi_vl_derived_model", None):
-                import vllm.model_executor.models.llava
+        if getattr(config, "is_yi_vl_derived_model", None):
+            import vllm.model_executor.models.llava

-                logger.info("Detected Yi-VL model, applying projector patch.")
-                vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
+            logger.info("Detected Yi-VL model, applying projector patch.")
+            vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM

         self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
         if model_args.adapter_name_or_path is not None:
@@ -118,29 +101,13 @@ class VllmEngine(BaseEngine):
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-
         if image is not None:
             if IMAGE_PLACEHOLDER not in messages[0]["content"]:
                 messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
-            messages = self.template.mm_plugin.process_messages(messages, [image], self.processor)
-
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
-
-        if self.processor is not None and image is not None:  # add image features
-            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
-            pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
-            if is_vllm_version_greater_than_0_5_1():
-                multi_modal_data = {"image": pixel_values}
-            elif is_vllm_version_greater_than_0_5():
-                multi_modal_data = ImagePixelData(image=pixel_values)
-            else:  # TODO: remove vllm 0.4.3 support
-                multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
-        else:
-            multi_modal_data = None
-
         prompt_length = len(prompt_ids)
         use_beam_search: bool = self.generating_args["num_beams"] > 1
@@ -185,6 +152,11 @@ class VllmEngine(BaseEngine):
             skip_special_tokens=True,
         )

+        if image is not None:  # add image features
+            multi_modal_data = {"image": image}
+        else:
+            multi_modal_data = None
+
         result_generator = self.model.generate(
             inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
             sampling_params=sampling_params,
diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py
index 3465a92a..f9cdf146 100644
--- a/src/llamafactory/extras/packages.py
+++ b/src/llamafactory/extras/packages.py
@@ -85,13 +85,3 @@ def is_uvicorn_available():

 def is_vllm_available():
     return _is_package_available("vllm")
-
-
-@lru_cache
-def is_vllm_version_greater_than_0_5():
-    return _get_package_version("vllm") >= version.parse("0.5.0")
-
-
-@lru_cache
-def is_vllm_version_greater_than_0_5_1():
-    return _get_package_version("vllm") >= version.parse("0.5.1")
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index a7dfb0bd..fd112607 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -123,7 +123,7 @@ def _check_extra_dependencies(
         require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")

     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.3", "To fix: pip install vllm>=0.4.3")
+        require_version("vllm>=0.4.3,<=0.6.0", "To fix: pip install vllm>=0.4.3,<=0.6.0")

     if finetuning_args.use_galore:
         require_version("galore_torch", "To fix: pip install galore_torch")
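Note: the rewritten _generate path above relies on vLLM 0.5.1+ (capped here at 0.6.0) accepting a raw image object under the "image" key of multi_modal_data, which is why the manual pixel_values preprocessing and the per-version import branches could be dropped. The snippet below is a minimal offline sketch of that interface; the model name and image path are illustrative placeholders and are not taken from this patch.

# Sketch only: vLLM (>= 0.5.1, here <= 0.6.0) consuming a raw image through
# multi_modal_data, the interface this patch switches to. Model name and
# image path are placeholders, not part of the patch.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # any vLLM-supported multimodal model
image = Image.open("example.jpg")            # local test image (placeholder)

outputs = llm.generate(
    {
        "prompt": "USER: <image>\nWhat is in this image?\nASSISTANT:",
        "multi_modal_data": {"image": image},  # raw image, no manual pixel_values
    },
    SamplingParams(temperature=0.2, max_tokens=64),
)
print(outputs[0].outputs[0].text)

The async engine used by VllmEngine takes the same dictionary through AsyncLLMEngine.generate(inputs=...), as the last vllm_engine.py hunk shows.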