From 14abb751263215a362c1baa28ce9173b2eba2650 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=AD=85=E5=BD=B1?= <46097299+frozenleaves@users.noreply.github.com>
Date: Tue, 4 Nov 2025 19:32:30 +0800
Subject: [PATCH] [model] enable using FA in npu (#9397)

Co-authored-by: frozenleaves
---
 src/llamafactory/extras/packages.py             | 4 ++++
 src/llamafactory/model/model_utils/attention.py | 8 ++++----
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py
index 904207ae..c4f6dc7a 100644
--- a/src/llamafactory/extras/packages.py
+++ b/src/llamafactory/extras/packages.py
@@ -110,6 +110,10 @@ def is_starlette_available():
 def is_transformers_version_greater_than(content: str):
     return _get_package_version("transformers") >= version.parse(content)
 
 
+@lru_cache
+def is_torch_version_greater_than(content: str):
+    return _get_package_version("torch") >= version.parse(content)
+
 def is_uvicorn_available():
     return _is_package_available("uvicorn")

diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index d901a9a8..655c5814 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING
 
 from ...extras import logging
 from ...extras.constants import AttentionFunction
+from ...extras.packages import is_torch_version_greater_than
 
 
 if TYPE_CHECKING:
@@ -51,15 +52,14 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
         requested_attn_implementation = "eager"
 
     elif model_args.flash_attn == AttentionFunction.SDPA:
-        from transformers.utils import is_torch_sdpa_available
-
-        if not is_torch_sdpa_available():
+        if not is_torch_version_greater_than("2.1.1"):
             logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return
 
         requested_attn_implementation = "sdpa"
     elif model_args.flash_attn == AttentionFunction.FA2:
-        if not is_flash_attn_2_available():
+        from transformers import is_torch_npu_available
+        if not (is_flash_attn_2_available() or is_torch_npu_available()):
             logger.warning_rank0("FlashAttention-2 is not installed.")
             return
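
For context, below is a minimal standalone sketch (not part of the patch) of the selection logic the diff introduces: SDPA requires torch>=2.1.1, and FlashAttention-2 is now accepted when either the flash-attn package or an Ascend NPU backend is available. The function name pick_attn_implementation, the use of importlib.metadata in place of the repo's _get_package_version helper, and print in place of logger.warning_rank0 are illustrative assumptions; is_flash_attn_2_available and is_torch_npu_available are the real transformers utilities used in the patch.

    # sketch.py -- illustrative only, mirrors the patched configure_attn_implementation()
    from importlib.metadata import version as pkg_version
    from typing import Optional

    from packaging import version
    from transformers import is_torch_npu_available
    from transformers.utils import is_flash_attn_2_available


    def is_torch_version_greater_than(content: str) -> bool:
        # Same idea as the helper added to extras/packages.py: compare the
        # installed torch version against a required minimum.
        return version.parse(pkg_version("torch")) >= version.parse(content)


    def pick_attn_implementation(requested: str) -> Optional[str]:
        # "sdpa" needs torch>=2.1.1; "fa2" needs either the flash-attn wheel
        # (GPU) or an NPU backend, which ships its own fused attention kernel.
        if requested == "sdpa":
            if not is_torch_version_greater_than("2.1.1"):
                print("torch>=2.1.1 is required for SDPA attention.")
                return None
            return "sdpa"

        if requested == "fa2":
            if not (is_flash_attn_2_available() or is_torch_npu_available()):
                print("FlashAttention-2 is not installed.")
                return None
            return "flash_attention_2"

        return "eager"


    if __name__ == "__main__":
        # On an NPU host without the flash-attn wheel this now resolves to
        # "flash_attention_2" instead of falling back with a warning.
        print(pick_attn_implementation("fa2"))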