diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 996e9130..9b51c064 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -97,7 +97,7 @@ class ModelArguments:
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
+    flash_attn: Literal["disabled", "sdpa", "fa2", "auto"] = field(
         default="auto",
         metadata={"help": "Enable FlashAttention for faster training and inference."},
     )
diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index 8ff3807b..dfd90936 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -32,7 +32,7 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
     if model_args.flash_attn == "auto":
         return

-    elif model_args.flash_attn == "off":
+    elif model_args.flash_attn == "disabled":
         requested_attn_implementation = "eager"

     elif model_args.flash_attn == "sdpa":
diff --git a/tests/model/model_utils/test_attention.py b/tests/model/model_utils/test_attention.py
index 97ac9dcc..4cae3d7c 100644
--- a/tests/model/model_utils/test_attention.py
+++ b/tests/model/model_utils/test_attention.py
@@ -29,7 +29,7 @@ INFER_ARGS = {


 def test_attention():
-    attention_available = ["off"]
+    attention_available = ["disabled"]
     if is_torch_sdpa_available():
         attention_available.append("sdpa")

@@ -37,7 +37,7 @@ def test_attention():
         attention_available.append("fa2")

     llama_attention_classes = {
-        "off": "LlamaAttention",
+        "disabled": "LlamaAttention",
         "sdpa": "LlamaSdpaAttention",
         "fa2": "LlamaFlashAttention2",
     }
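
For reviewers, a minimal sketch of what the renamed value means downstream: only the "disabled" -> "eager" branch is taken from this diff; the "sdpa" and "fa2" targets are assumed to be the usual Transformers attn_implementation strings, and resolve_attn_implementation is a hypothetical name for illustration, not the project's actual helper.

from typing import Literal, Optional

def resolve_attn_implementation(
    flash_attn: Literal["disabled", "sdpa", "fa2", "auto"],
) -> Optional[str]:
    """Illustrative sketch: translate the user-facing flash_attn value into an
    attn_implementation string for Transformers. Not the project's real code."""
    if flash_attn == "auto":
        return None  # leave the choice to Transformers
    mapping = {
        "disabled": "eager",           # plain attention, no fused kernels (shown in the diff)
        "sdpa": "sdpa",                # assumed: torch scaled_dot_product_attention backend
        "fa2": "flash_attention_2",    # assumed: FlashAttention-2 kernels
    }
    return mapping[flash_attn]

# Configs that previously set flash_attn: off must now pass "disabled".
assert resolve_attn_implementation("disabled") == "eager"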