From 9aa640f27bce0c827aded79e4e1a8176fce07d1b Mon Sep 17 00:00:00 2001
From: stceum <50257864+stceum@users.noreply.github.com>
Date: Mon, 24 Jun 2024 20:39:20 +0800
Subject: [PATCH 1/3] Bug Fix: `off` is parsed as `False` in yaml file, changed
 to `disabled` to avoid this.

Former-commit-id: 3ed063f281d1c2563df1b9eb3800543208c9dc16
---
 src/llamafactory/hparams/model_args.py          | 2 +-
 src/llamafactory/hparams/parser.py              | 4 ++++
 src/llamafactory/model/model_utils/attention.py | 2 +-
 3 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 996e9130..9b51c064 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -97,7 +97,7 @@ class ModelArguments:
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Literal["off", "sdpa", "fa2", "auto"] = field(
+    flash_attn: Literal["disabled", "sdpa", "fa2", "auto"] = field(
         default="auto",
         metadata={"help": "Enable FlashAttention for faster training and inference."},
     )
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index a593bf45..9ef2d607 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -102,6 +102,10 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
 
+    # In case that `flash_attn` is set to `off` in the yaml file, and parsed as `False` afterwards.
+    if model_args.flash_attn == False:
+        raise ValueError("flash_attn should be \"disabled\", \"sdpa\", \"fa2\" or \"auto\".")
+
 
 def _check_extra_dependencies(
     model_args: "ModelArguments",
diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index 8ff3807b..dfd90936 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -32,7 +32,7 @@ def configure_attn_implementation(config: "PretrainedConfig", model_args: "Model
     if model_args.flash_attn == "auto":
         return
 
-    elif model_args.flash_attn == "off":
+    elif model_args.flash_attn == "disabled":
         requested_attn_implementation = "eager"
 
     elif model_args.flash_attn == "sdpa":

From a9f10a9abd8f5fa54e7ac92f7502658777bb6e46 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 24 Jun 2024 21:35:34 +0800
Subject: [PATCH 2/3] Update test_attention.py

Former-commit-id: a9b3d91952dd5a51ff97fbb40a2dd88885d380b8
---
 tests/model/model_utils/test_attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/model/model_utils/test_attention.py b/tests/model/model_utils/test_attention.py
index 97ac9dcc..4cae3d7c 100644
--- a/tests/model/model_utils/test_attention.py
+++ b/tests/model/model_utils/test_attention.py
@@ -29,7 +29,7 @@ INFER_ARGS = {
 
 
 def test_attention():
-    attention_available = ["off"]
+    attention_available = ["disabled"]
     if is_torch_sdpa_available():
         attention_available.append("sdpa")
 
@@ -37,7 +37,7 @@ def test_attention():
     attention_available.append("fa2")
 
     llama_attention_classes = {
-        "off": "LlamaAttention",
+        "disabled": "LlamaAttention",
         "sdpa": "LlamaSdpaAttention",
         "fa2": "LlamaFlashAttention2",
     }

From e74fcdf7b1a58eeade554c01bc66d6e21c6cd243 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 24 Jun 2024 21:37:42 +0800
Subject: [PATCH 3/3] Update parser.py

Former-commit-id: e90c424f55b17e4971f8b9d85b6aeac89bb6b98e
---
 src/llamafactory/hparams/parser.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 9ef2d607..a593bf45 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -102,10 +102,6 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
         if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
             raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
 
-    # In case that `flash_attn` is set to `off` in the yaml file, and parsed as `False` afterwards.
-    if model_args.flash_attn == False:
-        raise ValueError("flash_attn should be \"disabled\", \"sdpa\", \"fa2\" or \"auto\".")
-
 
 def _check_extra_dependencies(
     model_args: "ModelArguments",
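
Illustration (added for review, not part of the patch series): the rename exists because YAML 1.1 loaders resolve the bare scalars off/on/no/yes as booleans, so a config line such as `flash_attn: off` reaches the argument parser as the boolean `False` rather than the string `"off"`. A minimal sketch of that behavior, assuming the config file is read with a YAML 1.1 loader such as PyYAML's `yaml.safe_load`:

    # Shows the YAML 1.1 boolean resolution behind the "off" -> "disabled" rename.
    import yaml

    print(yaml.safe_load("flash_attn: off"))       # {'flash_attn': False}
    print(yaml.safe_load("flash_attn: disabled"))  # {'flash_attn': 'disabled'}
    print(yaml.safe_load("flash_attn: 'off'"))     # {'flash_attn': 'off'} (quoting would also avoid the issue)

Patch 1/3 pairs the rename with an explicit guard for the `False` case; patch 3/3 drops that guard again, leaving the `disabled` rename itself as the fix.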