[model] enable using FA in npu (#9397)

Co-authored-by: frozenleaves <frozen@Mac.local>
Authored by 魅影 on 2025-11-04 19:32:30 +08:00, committed by GitHub
parent 5a9939050e
commit 14abb75126
2 changed files with 8 additions and 4 deletions


@@ -110,6 +110,10 @@ def is_starlette_available():
 def is_transformers_version_greater_than(content: str):
     return _get_package_version("transformers") >= version.parse(content)
 
+@lru_cache
+def is_torch_version_greater_than(content: str):
+    return _get_package_version("torch") >= version.parse(content)
+
 
 def is_uvicorn_available():
     return _is_package_available("uvicorn")
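Note: the real helper reads the installed version through the module's internal _get_package_version; the standalone sketch below approximates the same gate with importlib.metadata and packaging, which is an assumption of this sketch rather than the repo's implementation.

# Standalone sketch of the version gate added above (not the repo's own code):
# importlib.metadata stands in for the internal _get_package_version helper.
from functools import lru_cache
from importlib import metadata

from packaging import version


@lru_cache
def is_torch_version_greater_than(content: str) -> bool:
    # True when the installed torch is at least the requested version.
    return version.parse(metadata.version("torch")) >= version.parse(content)


# Usage mirroring the SDPA gate in the attention diff below.
if not is_torch_version_greater_than("2.1.1"):
    print("torch>=2.1.1 is required for SDPA attention.")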


@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING
 from ...extras import logging
 from ...extras.constants import AttentionFunction
+from ...extras.packages import is_torch_version_greater_than
 
 if TYPE_CHECKING:
@@ -51,15 +52,14 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
         requested_attn_implementation = "eager"
     elif model_args.flash_attn == AttentionFunction.SDPA:
-        from transformers.utils import is_torch_sdpa_available
-
-        if not is_torch_sdpa_available():
+        if not is_torch_version_greater_than("2.1.1"):
             logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return
 
         requested_attn_implementation = "sdpa"
     elif model_args.flash_attn == AttentionFunction.FA2:
-        if not is_flash_attn_2_available():
+        from transformers import is_torch_npu_available
+        if not (is_flash_attn_2_available() or is_torch_npu_available()):
             logger.warning_rank0("FlashAttention-2 is not installed.")
             return
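Note: the effect of the relaxed FA2 gate is that a request for FlashAttention-2 on an Ascend NPU host is no longer rejected just because the flash-attn wheel is absent. The sketch below reproduces only the gate condition; the wrapper function and print calls are illustrative stand-ins for the surrounding configure_attn_implementation logic and logger.

# Illustration of the relaxed FlashAttention-2 gate from the diff above.
# is_flash_attn_2_available and is_torch_npu_available are transformers helpers;
# everything else here is a stand-in for the surrounding function.
from transformers import is_torch_npu_available
from transformers.utils import is_flash_attn_2_available


def fa2_request_allowed() -> bool:
    # Either a working flash-attn install or an NPU backend satisfies the check.
    if not (is_flash_attn_2_available() or is_torch_npu_available()):
        print("FlashAttention-2 is not installed.")
        return False
    return True


if fa2_request_allowed():
    requested_attn_implementation = "flash_attention_2"  # handed to the model config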