diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 34126adf..e4c05478 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -1,12 +1,10 @@
 import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 
-from packaging import version
-
 from ..data import get_template_and_fix_tokenizer
 from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
-from ..extras.packages import is_vllm_available, _get_package_version
+from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response
@@ -16,7 +14,7 @@ if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest
 
-    if _get_package_version("vllm") >= version.parse("0.5.0"):
+    if is_vllm_version_greater_than_0_5():
         from vllm.multimodal.image import ImagePixelData
     else:
         from vllm.sequence import MultiModalData
@@ -112,9 +110,9 @@ class VllmEngine(BaseEngine):
         if self.processor is not None and image is not None:  # add image features
             image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
             pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
-            if _get_package_version("vllm") >= version.parse("0.5.0"):
-                multi_modal_data = ImagePixelData(pixel_values)
-            else:
+            if is_vllm_version_greater_than_0_5():
+                multi_modal_data = ImagePixelData(image=pixel_values)
+            else:  # TODO: remove vllm 0.4.3 support
                 multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
         else:
             multi_modal_data = None
diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py
index 4c9e6492..0746bb4f 100644
--- a/src/llamafactory/extras/packages.py
+++ b/src/llamafactory/extras/packages.py
@@ -1,5 +1,6 @@
 import importlib.metadata
 import importlib.util
+from functools import lru_cache
 from typing import TYPE_CHECKING
 
 from packaging import version
@@ -24,10 +25,6 @@ def is_fastapi_available():
     return _is_package_available("fastapi")
 
 
-def is_flash_attn2_available():
-    return _is_package_available("flash_attn") and _get_package_version("flash_attn") > version.parse("2.0.0")
-
-
 def is_galore_available():
     return _is_package_available("galore_torch")
 
@@ -36,18 +33,10 @@ def is_gradio_available():
     return _is_package_available("gradio")
 
 
-def is_jieba_available():
-    return _is_package_available("jieba")
-
-
 def is_matplotlib_available():
     return _is_package_available("matplotlib")
 
 
-def is_nltk_available():
-    return _is_package_available("nltk")
-
-
 def is_pillow_available():
     return _is_package_available("PIL")
 
@@ -60,10 +49,6 @@ def is_rouge_available():
     return _is_package_available("rouge_chinese")
 
 
-def is_sdpa_available():
-    return _get_package_version("torch") > version.parse("2.1.1")
-
-
 def is_starlette_available():
     return _is_package_available("sse_starlette")
 
@@ -74,3 +59,8 @@ def is_uvicorn_available():
 
 def is_vllm_available():
     return _is_package_available("vllm")
+
+
+@lru_cache
+def is_vllm_version_greater_than_0_5():
+    return _get_package_version("vllm") >= version.parse("0.5.0")
diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index b52ddc86..2bd36fdc 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -1,7 +1,8 @@
 from typing import TYPE_CHECKING
 
+from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+
 from ...extras.logging import get_logger
-from ...extras.packages import is_flash_attn2_available, is_sdpa_available
 
 
 if TYPE_CHECKING:
@@ -21,13 +22,13 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
         requested_attn_implementation = "eager"
 
     elif model_args.flash_attn == "sdpa":
-        if not is_sdpa_available():
+        if not is_torch_sdpa_available():
             logger.warning("torch>=2.1.1 is required for SDPA attention.")
             return
 
         requested_attn_implementation = "sdpa"
     elif model_args.flash_attn == "fa2":
-        if not is_flash_attn2_available():
+        if not is_flash_attn_2_available():
             logger.warning("FlashAttention-2 is not installed.")
             return
 
diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py
index b135fcfb..6ed356c1 100644
--- a/src/llamafactory/train/sft/metric.py
+++ b/src/llamafactory/train/sft/metric.py
@@ -2,9 +2,10 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
 
 import numpy as np
+from transformers.utils import is_jieba_available, is_nltk_available
 
 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
+from ...extras.packages import is_rouge_available
 
 
 if TYPE_CHECKING: