From 53fcc531b55baa674b227da079eed76228e755b1 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 7 Nov 2023 16:28:21 +0800
Subject: [PATCH] update info

Former-commit-id: 17c64a05796cac70fc76ed728705cd60efa41cae
---
 src/llmtuner/tuner/core/loader.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 15a3f36d..829599c3 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -128,8 +128,8 @@ def load_model_and_tokenizer(
             LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
             LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
             logger.info("Using FlashAttention-2 for faster training and inference.")
-        elif getattr(config, "model_type", None) == "qwen":
-            logger.info("Qwen models automatically enable FlashAttention if installed.")
+        elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
+            logger.info("Current model automatically enables FlashAttention if installed.")
         else:
             logger.warning("Current model does not support FlashAttention-2.")
     elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
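
Note (not part of the patch): a minimal sketch of the branch logic the hunk above modifies, for readers skimming the diff. It assumes only that the loader dispatches on the config's `model_type` attribute; the helper name and the stand-in configs below are illustrative, not taken from the repository.

```python
from types import SimpleNamespace

def flash_attn_message(config) -> str:
    """Mirror the patched branch: pick the log message based on model_type."""
    model_type = getattr(config, "model_type", None)
    if model_type == "llama":
        # The real loader also monkey-patches LlamaAttention with FlashAttention-2 here.
        return "Using FlashAttention-2 for faster training and inference."
    elif model_type in ["qwen", "Yi"]:
        return "Current model automatically enables FlashAttention if installed."
    else:
        return "Current model does not support FlashAttention-2."

# Illustrative stand-ins for a transformers PretrainedConfig (only model_type matters here).
for mt in ["llama", "qwen", "Yi", "falcon"]:
    print(mt, "->", flash_attn_message(SimpleNamespace(model_type=mt)))
```

The net effect of the change is that Yi checkpoints now fall into the same informational branch as Qwen instead of hitting the "does not support FlashAttention-2" warning, and the log line is reworded to cover both model families.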