From 8ac7ec0b48de9a0013e4892b67abb6852845a3d9 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 11 Sep 2023 18:27:08 +0800
Subject: [PATCH] tiny fix

Former-commit-id: 3b306478d4ccbf037ae1acc122f6dca11c718731
---
 src/llmtuner/extras/models/flash_llama.py | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/src/llmtuner/extras/models/flash_llama.py b/src/llmtuner/extras/models/flash_llama.py
index 670c3e8f..d6c078bd 100644
--- a/src/llmtuner/extras/models/flash_llama.py
+++ b/src/llmtuner/extras/models/flash_llama.py
@@ -7,30 +7,27 @@
 # With fix from Alex Birch: https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/17
 
 import torch
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 from transformers.utils import logging
-from transformers.models.llama.configuration_llama import LlamaConfig
 
+if TYPE_CHECKING:
+    from transformers.models.llama.configuration_llama import LlamaConfig
 
 try:
     from flash_attn.flash_attn_interface import (
         flash_attn_kvpacked_func,
-        flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_kvpacked_func
     )
-    from flash_attn.bert_padding import unpad_input, pad_input
-    flash_attn_v2_installed = True
-    print('>>>> Flash Attention installed')
+    from flash_attn.bert_padding import pad_input, unpad_input
+    print(">>>> FlashAttention installed")
 except ImportError:
-    flash_attn_v2_installed = False
-    raise ImportError('Please install Flash Attention: `pip install flash-attn --no-build-isolation`')
+    raise ImportError("Please install FlashAttention from https://github.com/Dao-AILab/flash-attention")
 
 try:
     from flash_attn.layers.rotary import apply_rotary_emb_func
-    flash_rope_installed = True
-    print('>>>> Flash RoPE installed')
+    print(">>>> Flash RoPE installed")
 except ImportError:
-    flash_rope_installed = False
-    raise ImportError('Please install RoPE kernels: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`')
+    raise ImportError("Please install RoPE kernels from https://github.com/Dao-AILab/flash-attention")
 
 
 logger = logging.get_logger(__name__)
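
For reference, the hunk above combines two common idioms: a typing.TYPE_CHECKING guard so LlamaConfig is imported only by static type checkers, and eager try/except imports so a missing optional kernel package fails at module load with an actionable message. Below is a minimal, self-contained sketch of the same pattern, not part of the patch; the describe_config helper is hypothetical and only shows why the guarded import still works for annotations.

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by static type checkers (mypy, pyright);
        # no runtime import of transformers happens here.
        from transformers.models.llama.configuration_llama import LlamaConfig

    try:
        # Optional compiled kernels: importing eagerly makes a missing
        # package fail at load time instead of deep inside the forward pass.
        from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
    except ImportError:
        raise ImportError("Please install FlashAttention from https://github.com/Dao-AILab/flash-attention")

    def describe_config(config: "LlamaConfig") -> str:
        # Hypothetical helper: the string annotation means LlamaConfig is
        # never needed at runtime, only when type checking.
        return "hidden_size={}".format(config.hidden_size)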