diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 996e9130..9b51c064 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -97,7 +97,7 @@ class ModelArguments:
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
+    flash_attn: Literal["disabled", "sdpa", "fa2", "auto"] = field(
         default="auto",
         metadata={"help": "Enable FlashAttention for faster training and inference."},
     )
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index a593bf45..9ef2d607 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -102,6 +102,10 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
     if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
         raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
 
+    # `flash_attn: off` in a YAML file is parsed as the boolean `False`, so reject it with a clear error.
+    if model_args.flash_attn is False:
+        raise ValueError("flash_attn should be \"disabled\", \"sdpa\", \"fa2\" or \"auto\".")
+
 
 def _check_extra_dependencies(
     model_args: "ModelArguments",
diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index 8ff3807b..dfd90936 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -32,7 +32,7 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
     if model_args.flash_attn == "auto":
         return
 
-    elif model_args.flash_attn == "off":
+    elif model_args.flash_attn == "disabled":
         requested_attn_implementation = "eager"
 
     elif model_args.flash_attn == "sdpa":
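Why the rename is needed: a minimal sketch (not part of the patch), assuming the config file is read with PyYAML's safe_load, which resolves YAML 1.1 booleans. A bare `off` never reaches _verify_model_args as a string, whereas `disabled` or a quoted "off" does.

# Illustration only; assumes PyYAML is installed.
import yaml

print(yaml.safe_load("flash_attn: off"))        # {'flash_attn': False} -- YAML 1.1 boolean coercion
print(yaml.safe_load('flash_attn: "off"'))      # {'flash_attn': 'off'}  -- quoting avoids the coercion
print(yaml.safe_load("flash_attn: disabled"))   # {'flash_attn': 'disabled'}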