fix flashattn warning

hiyouga
2023-11-10 18:34:54 +08:00
parent a0c31c68c4
commit 4bd8e3906d
2 changed files with 10 additions and 4 deletions


@@ -5,11 +5,14 @@ from typing import Optional, Tuple
from transformers.utils import logging
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv

is_flash_attn_2_available = False
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func  # type: ignore
    from flash_attn.bert_padding import pad_input, unpad_input  # type: ignore
    is_flash_attn_2_available = True
except ImportError:
    print("FlashAttention-2 is not installed, ignore this if you are not using FlashAttention.")
    is_flash_attn_2_available = False

logger = logging.get_logger(__name__)
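
The point of the module-level flag is to record the outcome of a one-time import probe so that later code can decide whether to warn or patch. The following is a minimal sketch, not part of this commit; `apply_llama_patch` and `use_flash_attn` are illustrative names, not identifiers from the repository. It shows one way such a flag is typically consumed so the warning only matters when FlashAttention is actually requested.

    # Hypothetical consumer of the availability flag (illustrative only).
    is_flash_attn_2_available = False
    try:
        from flash_attn import flash_attn_func  # type: ignore  # noqa: F401
        is_flash_attn_2_available = True
    except ImportError:
        pass


    def apply_llama_patch(use_flash_attn: bool) -> None:
        """Decide whether a FlashAttention-2 forward pass can be swapped in."""
        if not use_flash_attn:
            return  # nothing to do; the stock attention implementation stays in place
        if not is_flash_attn_2_available:
            # The import probe failed, so warn only now that FlashAttention was requested.
            print("FlashAttention-2 is not installed, falling back to the default attention.")
            return
        # flash_attn is importable at this point; a monkey patch defined elsewhere
        # could now safely replace LlamaAttention.forward.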