From ca736bcab72c73674e6437a7894f2154fccde82b Mon Sep 17 00:00:00 2001
From: Amirreza A
Date: Sat, 28 Sep 2024 19:03:36 +0330
Subject: [PATCH 1/2] made a small change to a warning about fa2 for gemma2 models.

Former-commit-id: e0695a026d822c896cb4f5b33e0c4f88441d75e9
---
 src/llamafactory/model/model_utils/attention.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index 96e2c8a9..dfb42a9f 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -37,7 +37,10 @@ def configure_attn_implementation(
             if is_flash_attn_2_available():
                 require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
                 require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
-                logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
+
+                if model_args.flash_attn != "fa2":
+                    logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
+
                 model_args.flash_attn = "fa2"
             else:
                 logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")

From 1ded3abdf17e681116e11cc77cbcee9b0cfcc16d Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Sun, 29 Sep 2024 10:47:41 +0800
Subject: [PATCH 2/2] Update attention.py

Former-commit-id: 2adf79c195053bb4541e0317573a2c89da28b5bc
---
 src/llamafactory/model/model_utils/attention.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index dfb42a9f..7667b069 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -37,13 +37,11 @@ def configure_attn_implementation(
             if is_flash_attn_2_available():
                 require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
                 require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
-
                 if model_args.flash_attn != "fa2":
                     logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
-
-                model_args.flash_attn = "fa2"
+                    model_args.flash_attn = "fa2"
             else:
-                logger.warning("Gemma-2 should use eager attention, change `flash_attn` to disabled.")
+                logger.warning("FlashAttention-2 is not installed, use eager attention.")
                 model_args.flash_attn = "disabled"
         elif model_args.flash_attn == "sdpa":
             logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
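
Note (not part of the patches): the sketch below shows the net state of the Gemma-2 branch of configure_attn_implementation after both patches are applied, reconstructed from the hunks above. The two outer guards (the "gemma2"/is_trainable check and the "auto"/"fa2" check), the logger setup, and the bare config/model_args parameters are assumptions inferred from the hunk context, not shown verbatim in the diffs.

# Reconstructed sketch of the patched Gemma-2 branch; outer guards are assumed, not from the diff.
import logging

from transformers.utils import is_flash_attn_2_available
from transformers.utils.versions import require_version

logger = logging.getLogger(__name__)


def configure_attn_implementation(config, model_args, is_trainable: bool) -> None:
    if getattr(config, "model_type", None) == "gemma2" and is_trainable:  # assumed guard
        if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":  # assumed guard
            if is_flash_attn_2_available():
                require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
                require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
                # Patch 1/2: only warn when the user did not explicitly request fa2.
                # Patch 2/2: set flash_attn inside the same branch, since it is already "fa2" otherwise.
                if model_args.flash_attn != "fa2":
                    logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
                    model_args.flash_attn = "fa2"
            else:
                # Patch 2/2 rewords this warning: the cause is a missing install, not a wrong setting.
                logger.warning("FlashAttention-2 is not installed, use eager attention.")
                model_args.flash_attn = "disabled"
        elif model_args.flash_attn == "sdpa":
            logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")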