disentangle model from tuner and rename modules

hiyouga committed on 2023-11-15 16:29:09 +08:00
parent 2f02f688e1
commit 4736344eb1
57 changed files with 324 additions and 263 deletions


@@ -3,16 +3,19 @@ import torch
 import torch.nn as nn
 from typing import Optional, Tuple
 from transformers.utils import logging
-from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
-
-
-is_flash_attn_2_available = False
+from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb
+
 try:
+    from transformers.models.llama.modeling_llama import repeat_kv
+except ImportError:
+    print("Please upgrade `transformers`.")
+
+from llmtuner.extras.packages import is_flash_attn2_available
+
+
+if is_flash_attn2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
     from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
-    is_flash_attn_2_available = True
-except ImportError:
-    is_flash_attn_2_available = False
 
 
 logger = logging.get_logger(__name__)
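
The visible change in this hunk drops the module-level is_flash_attn_2_available flag, previously set by a try/except around the flash_attn import, and instead guards the import behind a helper imported from llmtuner.extras.packages; it also wraps the repeat_kv import in a try/except for older transformers versions. The helper's body is not part of this diff; a minimal sketch of such an availability check, assuming a plain importlib-based probe, could look like:

# Hypothetical sketch of a helper like llmtuner.extras.packages.is_flash_attn2_available;
# the real implementation is not shown in this hunk. The idea is to probe for the optional
# flash-attn dependency without importing it at module load time.
import importlib.metadata
import importlib.util


def _is_package_available(name: str) -> bool:
    # find_spec returns None when the module cannot be located on sys.path
    return importlib.util.find_spec(name) is not None


def is_flash_attn2_available() -> bool:
    # flash-attn 1.x and 2.x share the same top-level module name,
    # so also check the installed distribution version (assumed heuristic)
    if not _is_package_available("flash_attn"):
        return False
    try:
        return importlib.metadata.version("flash-attn").startswith("2.")
    except importlib.metadata.PackageNotFoundError:
        return False

Keeping the probe in a shared extras module lets other patched files reuse the same check instead of each maintaining its own import-time flag.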