Former-commit-id: 3b306478d4ccbf037ae1acc122f6dca11c718731
hiyouga 2023-09-11 18:27:08 +08:00
parent 4410387859
commit 8ac7ec0b48


@@ -7,30 +7,27 @@
 # With fix from Alex Birch: https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/17
 import torch
-from typing import Optional, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 from transformers.utils import logging
-from transformers.models.llama.configuration_llama import LlamaConfig
+
+if TYPE_CHECKING:
+    from transformers.models.llama.configuration_llama import LlamaConfig
 
 try:
     from flash_attn.flash_attn_interface import (
         flash_attn_kvpacked_func,
-        flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_kvpacked_func
     )
-    from flash_attn.bert_padding import unpad_input, pad_input
-    flash_attn_v2_installed = True
-    print('>>>> Flash Attention installed')
+    from flash_attn.bert_padding import pad_input, unpad_input
+    print(">>>> FlashAttention installed")
 except ImportError:
-    flash_attn_v2_installed = False
-    raise ImportError('Please install Flash Attention: `pip install flash-attn --no-build-isolation`')
+    raise ImportError("Please install FlashAttention from https://github.com/Dao-AILab/flash-attention")
 
 try:
     from flash_attn.layers.rotary import apply_rotary_emb_func
-    flash_rope_installed = True
-    print('>>>> Flash RoPE installed')
+    print(">>>> Flash RoPE installed")
 except ImportError:
-    flash_rope_installed = False
-    raise ImportError('Please install RoPE kernels: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`')
+    raise ImportError("Please install RoPE kernels from https://github.com/Dao-AILab/flash-attention")
 
 logger = logging.get_logger(__name__)
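
For context on the import changes above: the TYPE_CHECKING guard makes LlamaConfig visible to static type checkers while skipping the import at runtime, so the patch module only needs the config class when annotations are analyzed. Below is a minimal sketch of that pattern; the describe_config helper and its body are illustrative assumptions and do not appear in this commit.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only static analyzers (mypy, pyright, IDEs) evaluate this block;
    # TYPE_CHECKING is False at runtime, so the import is skipped then.
    from transformers.models.llama.configuration_llama import LlamaConfig


def describe_config(config: "LlamaConfig") -> str:  # hypothetical helper, not from the commit
    # The quoted annotation satisfies type checkers without the runtime
    # import; attribute access works on whatever config object is passed in.
    return f"hidden_size={config.hidden_size}, heads={config.num_attention_heads}"

Separately, the commit drops the flash_attn_v2_installed and flash_rope_installed flags and raises ImportError directly, so importing the patched module now fails fast when flash-attn is absent; a caller that needs to probe availability beforehand can use a standard-library check such as importlib.util.find_spec("flash_attn") is not None.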