mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-14 10:56:56 +08:00
Merge pull request #4446 from stceum/bug-fix
Bug Fix: `off` is parsed as `False` in YAML files
Former-commit-id: cc452c32c7
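The root cause is YAML 1.1 boolean resolution: PyYAML treats `off` (like `on`, `yes`, and `no`) as a boolean literal, so a config line such as `flash_attn: off` is loaded as the boolean `False` rather than the string `"off"`, and the value never matches the declared choices. Below is a minimal sketch of that behavior, assuming the config file is read with PyYAML's `safe_load`; the key is only illustrative.

```python
# Minimal demonstration of the YAML parsing behavior behind this fix.
# Assumes PyYAML; the config key is illustrative.
import yaml

old = yaml.safe_load("flash_attn: off")
new = yaml.safe_load("flash_attn: disabled")

print(old)  # {'flash_attn': False}        "off" resolved as a YAML 1.1 boolean
print(new)  # {'flash_attn': 'disabled'}   plain string, matches the declared choices
```

Renaming the option to `disabled` sidesteps the implicit boolean conversion without requiring users to quote the value in their config files.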
@@ -97,7 +97,7 @@ class ModelArguments:
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
+    flash_attn: Literal["disabled", "sdpa", "fa2", "auto"] = field(
         default="auto",
         metadata={"help": "Enable FlashAttention for faster training and inference."},
     )
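For context, here is a self-contained sketch of the field after this change; the surrounding fields of `ModelArguments` are omitted and only the renamed option is shown.

```python
# Trimmed, illustrative reconstruction of the dataclass field touched above;
# all other ModelArguments fields are omitted for brevity.
from dataclasses import dataclass, field
from typing import Literal


@dataclass
class ModelArguments:
    flash_attn: Literal["disabled", "sdpa", "fa2", "auto"] = field(
        default="auto",
        metadata={"help": "Enable FlashAttention for faster training and inference."},
    )
```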
@@ -32,7 +32,7 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
     if model_args.flash_attn == "auto":
         return
 
-    elif model_args.flash_attn == "off":
+    elif model_args.flash_attn == "disabled":
         requested_attn_implementation = "eager"
 
     elif model_args.flash_attn == "sdpa":
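Downstream, the renamed value is consumed when choosing the attention implementation. The sketch below is a simplified, illustrative version of that dispatch: the `"disabled"` to `"eager"` mapping comes from the hunk above, while the `"sdpa"` and `"fa2"` mappings are assumptions based on the implementation names Transformers accepts, not copied from the patch.

```python
# Simplified, hypothetical dispatcher mirroring the branch structure in the
# hunk above. Only the "disabled" -> "eager" branch is taken verbatim from
# the patch; the other mappings are assumptions.
from typing import Optional


def resolve_attn_implementation(flash_attn: str) -> Optional[str]:
    if flash_attn == "auto":
        return None  # defer to the library default
    elif flash_attn == "disabled":
        return "eager"  # plain attention, no FlashAttention or SDPA
    elif flash_attn == "sdpa":
        return "sdpa"  # assumed: PyTorch scaled_dot_product_attention backend
    elif flash_attn == "fa2":
        return "flash_attention_2"  # assumed: FlashAttention-2 kernels
    raise ValueError(f"Unexpected flash_attn value: {flash_attn}")
```

Since `"off"` is removed from the accepted choices, existing configs that set `flash_attn: off` need to switch to `flash_attn: disabled` after this change.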