mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-05 02:12:14 +08:00
[misc] fix import error (#9296)
This commit is contained in:
parent
8c341cbaae
commit
a442fa90ad
@ -14,8 +14,6 @@
|
|||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
|
|
||||||
|
|
||||||
from ...extras import logging
|
from ...extras import logging
|
||||||
from ...extras.constants import AttentionFunction
|
from ...extras.constants import AttentionFunction
|
||||||
|
|
||||||
@ -30,6 +28,8 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
|
def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
|
||||||
|
from transformers.utils import is_flash_attn_2_available
|
||||||
|
|
||||||
if getattr(config, "model_type", None) == "gemma2":
|
if getattr(config, "model_type", None) == "gemma2":
|
||||||
if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
|
if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
|
||||||
if is_flash_attn_2_available():
|
if is_flash_attn_2_available():
|
||||||
@ -51,6 +51,8 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
|
|||||||
requested_attn_implementation = "eager"
|
requested_attn_implementation = "eager"
|
||||||
|
|
||||||
elif model_args.flash_attn == AttentionFunction.SDPA:
|
elif model_args.flash_attn == AttentionFunction.SDPA:
|
||||||
|
from transformers.utils import is_torch_sdpa_available
|
||||||
|
|
||||||
if not is_torch_sdpa_available():
|
if not is_torch_sdpa_available():
|
||||||
logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
|
logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
|
||||||
return
|
return
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user