clean code

Former-commit-id: 2ed8270112
hiyouga
2024-06-13 01:58:16 +08:00
parent 7366647b43
commit 833aa324c2
4 changed files with 17 additions and 27 deletions

@@ -1,7 +1,8 @@
 from typing import TYPE_CHECKING
 
+from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+
 from ...extras.logging import get_logger
-from ...extras.packages import is_flash_attn2_available, is_sdpa_available
 
 
 if TYPE_CHECKING:
@@ -21,13 +22,13 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
         requested_attn_implementation = "eager"
 
     elif model_args.flash_attn == "sdpa":
-        if not is_sdpa_available():
+        if not is_torch_sdpa_available():
             logger.warning("torch>=2.1.1 is required for SDPA attention.")
             return
 
         requested_attn_implementation = "sdpa"
     elif model_args.flash_attn == "fa2":
-        if not is_flash_attn2_available():
+        if not is_flash_attn_2_available():
             logger.warning("FlashAttention-2 is not installed.")
             return
 
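For context, both availability checks now come from transformers.utils, so the custom helpers in extras.packages are no longer needed. Below is a minimal standalone sketch (not this repository's function) of the same fallback logic; the helper name pick_attn_implementation and the fallback order are illustrative assumptions, while the returned strings match the attn_implementation values transformers accepts.

from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available


def pick_attn_implementation(preferred: str = "fa2") -> str:
    # Hypothetical helper, for illustration only: mirrors the checks in the diff above.
    if preferred == "fa2" and is_flash_attn_2_available():
        return "flash_attention_2"  # the flash-attn package is installed

    if is_torch_sdpa_available():
        return "sdpa"  # PyTorch scaled dot-product attention (torch>=2.1.1)

    return "eager"  # plain attention implementation, always available


print(pick_attn_implementation())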