mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 04:32:50 +08:00
tiny fix
Former-commit-id: 3b306478d4ccbf037ae1acc122f6dca11c718731
This commit is contained in:
parent
4410387859
commit
8ac7ec0b48
@ -7,30 +7,27 @@
|
||||
# With fix from Alex Birch: https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/17
|
||||
|
||||
import torch
|
||||
from typing import Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from transformers.utils import logging
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_kvpacked_func,
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
flash_attn_varlen_kvpacked_func
|
||||
)
|
||||
from flash_attn.bert_padding import unpad_input, pad_input
|
||||
flash_attn_v2_installed = True
|
||||
print('>>>> Flash Attention installed')
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
print(">>>> FlashAttention installed")
|
||||
except ImportError:
|
||||
flash_attn_v2_installed = False
|
||||
raise ImportError('Please install Flash Attention: `pip install flash-attn --no-build-isolation`')
|
||||
raise ImportError("Please install FlashAttention from https://github.com/Dao-AILab/flash-attention")
|
||||
|
||||
try:
|
||||
from flash_attn.layers.rotary import apply_rotary_emb_func
|
||||
flash_rope_installed = True
|
||||
print('>>>> Flash RoPE installed')
|
||||
print(">>>> Flash RoPE installed")
|
||||
except ImportError:
|
||||
flash_rope_installed = False
|
||||
raise ImportError('Please install RoPE kernels: `pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary`')
|
||||
raise ImportError("Please install RoPE kernels from https://github.com/Dao-AILab/flash-attention")
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
Loading…
x
Reference in New Issue
Block a user