Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 12:42:51 +08:00)
fix flashattn warning
Former-commit-id: 4bd8e3906d09bf6ec4b8f6b553a347fca9db4f80
parent 55e097aaac
commit 982e0e79c2
@@ -5,11 +5,14 @@ from typing import Optional, Tuple
 from transformers.utils import logging
 from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
 
 
+is_flash_attn_2_available = False
+
 try:
     from flash_attn import flash_attn_func, flash_attn_varlen_func  # type: ignore
     from flash_attn.bert_padding import pad_input, unpad_input  # type: ignore
+    is_flash_attn_2_available = True
 except ImportError:
-    print("FlashAttention-2 is not installed, ignore this if you are not using FlashAttention.")
+    is_flash_attn_2_available = False
 
 logger = logging.get_logger(__name__)
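The hunk above drops the unconditional print() at import time and records the result of the optional import in a module-level flag that downstream code can check. A minimal, self-contained sketch of the same pattern, with illustrative names (attention_backend is not part of the project), assuming only that flash_attn may or may not be installed:

    is_flash_attn_2_available = False

    try:
        # Optional dependency: set the flag only if the import actually succeeds.
        from flash_attn import flash_attn_func  # type: ignore  # noqa: F401
        is_flash_attn_2_available = True
    except ImportError:
        # Stay silent here; the caller decides whether the missing package deserves a warning.
        is_flash_attn_2_available = False


    def attention_backend() -> str:
        # Callers branch on the flag instead of retrying the import.
        return "flash_attn" if is_flash_attn_2_available else "eager"


    if __name__ == "__main__":
        print(attention_backend())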
@@ -123,9 +123,12 @@ def load_model_and_tokenizer(
     # Set FlashAttention-2
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
-            LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
-            LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
-            logger.info("Using FlashAttention-2 for faster training and inference.")
+            if LlamaPatches.is_flash_attn_2_available:
+                LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
+                LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
+                logger.info("Using FlashAttention-2 for faster training and inference.")
+            else:
+                logger.warning("FlashAttention-2 is not installed.")
         elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
             logger.info("Current model automatically enables FlashAttention if installed.")
         else:
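On the loader side, the patch is now applied only when the flag is set, and the missing-dependency case is reported once through the logger at load time instead of printing on every import. A hedged sketch of that conditional monkey-patch, with llama_module and llama_patches passed in as stand-ins for the real modules (the helper name is illustrative, not the project's API):

    import logging
    import types

    logger = logging.getLogger(__name__)


    def maybe_patch_llama_attention(llama_module: types.ModuleType, llama_patches: types.ModuleType) -> None:
        # Replace the stock attention class only if FlashAttention-2 imported successfully.
        if getattr(llama_patches, "is_flash_attn_2_available", False):
            llama_module.LlamaAttention = llama_patches.LlamaFlashAttention2
            llama_module.LlamaModel._prepare_decoder_attention_mask = (
                llama_patches._prepare_decoder_attention_mask
            )
            logger.info("Using FlashAttention-2 for faster training and inference.")
        else:
            # Warn at load time rather than at import time.
            logger.warning("FlashAttention-2 is not installed.")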