Merge pull request #4446 from stceum/bug-fix

Bug Fix: `off` is parsed as `False` in YAML files

Former-commit-id: cc452c32c7f882c7f024a9d229352918a9eaa925
Commit: fe407e8de6
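The underlying issue: YAML 1.1 treats `off` (like `on`, `yes`, and `no`) as a boolean keyword, so a config line such as `flash_attn: off` reaches the argument parser as `False` rather than the string `"off"` and never matches the declared choices. Renaming the option to `disabled` avoids the keyword entirely. A minimal sketch of the behavior, assuming PyYAML is the config loader:

    import yaml

    # YAML 1.1 boolean forms: "off" resolves to False, not to the string "off".
    print(yaml.safe_load("flash_attn: off"))       # {'flash_attn': False}

    # "disabled" is not a YAML keyword, so it survives parsing as a plain string.
    print(yaml.safe_load("flash_attn: disabled"))  # {'flash_attn': 'disabled'}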
@@ -97,7 +97,7 @@ class ModelArguments:
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
+    flash_attn: Literal["disabled", "sdpa", "fa2", "auto"] = field(
         default="auto",
         metadata={"help": "Enable FlashAttention for faster training and inference."},
     )
@@ -32,7 +32,7 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments"
     if model_args.flash_attn == "auto":
         return

-    elif model_args.flash_attn == "off":
+    elif model_args.flash_attn == "disabled":
         requested_attn_implementation = "eager"

     elif model_args.flash_attn == "sdpa":
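For context, `configure_attn_implementation` translates the `flash_attn` choice into the `attn_implementation` string that `transformers` understands, with the renamed `disabled` option mapping to plain eager attention. A rough sketch of that mapping under those assumptions (the helper name below is hypothetical, not the project's code):

    from typing import Optional

    def resolve_attn_implementation(flash_attn: str) -> Optional[str]:
        # Hypothetical helper illustrating the mapping; the real function also
        # checks whether SDPA / FlashAttention-2 are actually available.
        if flash_attn == "auto":
            return None  # let transformers pick a backend
        return {
            "disabled": "eager",         # vanilla attention
            "sdpa": "sdpa",              # torch scaled_dot_product_attention
            "fa2": "flash_attention_2",  # FlashAttention-2 kernels
        }[flash_attn]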
@@ -29,7 +29,7 @@ INFER_ARGS = {


 def test_attention():
-    attention_available = ["off"]
+    attention_available = ["disabled"]
     if is_torch_sdpa_available():
         attention_available.append("sdpa")

@@ -37,7 +37,7 @@ def test_attention():
         attention_available.append("fa2")

     llama_attention_classes = {
-        "off": "LlamaAttention",
+        "disabled": "LlamaAttention",
         "sdpa": "LlamaSdpaAttention",
         "fa2": "LlamaFlashAttention2",
     }
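The test maps each available backend to the Llama attention class it should produce; a hedged sketch of how the assertion could then proceed (the tiny-model loader below is hypothetical, the real test goes through LLaMA-Factory's own loader):

    for requested in attention_available:
        expected = llama_attention_classes[requested]
        model = load_tiny_llama(flash_attn=requested)  # hypothetical helper returning a small LlamaForCausalLM
        for layer in model.model.layers:
            # every decoder layer's self-attention module should be the expected class
            assert layer.self_attn.__class__.__name__ == expected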