diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 15a3f36d..829599c3 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -128,8 +128,8 @@ def load_model_and_tokenizer(
             LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
             LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
             logger.info("Using FlashAttention-2 for faster training and inference.")
-        elif getattr(config, "model_type", None) == "qwen":
-            logger.info("Qwen models automatically enable FlashAttention if installed.")
+        elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
+            logger.info("Current model automatically enables FlashAttention if installed.")
         else:
             logger.warning("Current model does not support FlashAttention-2.")
     elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
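
A minimal sketch of how the model_type guard above resolves at load time, assuming transformers is installed; the checkpoint names and the model_type values they report are illustrative assumptions, not taken from this patch:

    # Minimal sketch (not part of this patch): inspect model_type the same way the loader does.
    # Checkpoint names and the values they report are illustrative assumptions.
    from transformers import AutoConfig

    for name in ("Qwen/Qwen-7B", "01-ai/Yi-6B"):
        config = AutoConfig.from_pretrained(name, trust_remote_code=True)
        model_type = getattr(config, "model_type", None)
        if model_type in ["qwen", "Yi"]:
            print(f"{name}: the model's own code handles FlashAttention if installed.")
        else:
            print(f"{name}: falls through to the FlashAttention-2 warning.")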