# coding=utf-8
# Modified from:
# [1] https://huggingface.co/Birchlabs/flash_llama/blob/main/modeling_flash_llama.py
# [2] https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama2_flash_attn_monkey_patch.py
# [3] https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/modeling_flash_llama.py
# [4] https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# With fix from Alex Birch: https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/discussions/17

import torch
from typing import TYPE_CHECKING, Optional, Tuple
from transformers.utils import logging

if TYPE_CHECKING:
    from transformers.models.llama.configuration_llama import LlamaConfig

try:
    from flash_attn.flash_attn_interface import (
        flash_attn_kvpacked_func,
        flash_attn_varlen_kvpacked_func
    )
    from flash_attn.bert_padding import pad_input, unpad_input
    print(">>>> FlashAttention installed")
except ImportError:
    raise ImportError("Please install FlashAttention from https://github.com/Dao-AILab/flash-attention")

try:
    from flash_attn.layers.rotary import apply_rotary_emb_func
    print(">>>> Flash RoPE installed")
except ImportError:
    raise ImportError("Please install RoPE kernels from https://github.com/Dao-AILab/flash-attention")


logger = logging.get_logger(__name__)

class LlamaRMSNorm(torch.nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
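        # Compute the RMS statistic in float32 for numerical stability:
        # y = weight * x / sqrt(mean(x^2) + eps), cast back to the input dtype at the end.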
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)  # for fp32 weight

class FlashRotaryEmbedding(torch.nn.Module):

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        scaling_factor=1.0,
        pos_idx_in_fp32=True,
        device=None
    ):
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non-trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        self.scaling_factor = scaling_factor
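        # xPos-style length-extrapolation scale, only built when scale_base is given;
        # with the default scale_base=None this stays disabled and plain RoPE is used.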
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None else None
        )
        self.register_buffer("scale", scale)

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
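        # Standard RoPE inverse frequencies: inv_freq[i] = 1 / base^(2i / dim).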
        return 1 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
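        # Rebuild the cos/sin tables only when the requested length exceeds the cache,
        # when the cached tensors live on the wrong device or have the wrong dtype,
        # or when a table built under inference mode is reused during training.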
        if (
            seqlen > self._seq_len_cached or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                t /= self.scaling_factor
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self.inv_freq.to(torch.float32)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                t /= self.scaling_factor
                inv_freq = self.inv_freq
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    (torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2) / self.scale_base
                )
                scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        q: (batch, seqlen, nheads, headdim)
        k: (batch, seqlen, nheads, headdim)
        seqlen_offset: can be used in generation, where the q and k passed in contain only
            the newly generated tokens (the previous `seqlen_offset` positions are cached).
        """
        self._update_cos_sin_cache(q.shape[1] + seqlen_offset, device=q.device, dtype=q.dtype)
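        # Apply the fused rotary kernel from flash-attn; it rotates q and k in place
        # using the cos/sin tables starting at position `seqlen_offset`.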
        if self.scale is None:
            return apply_rotary_emb_func(
                q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                self.interleaved, True  # inplace=True
            ), apply_rotary_emb_func(
                k, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
                self.interleaved, True  # inplace=True
            )
        else:
            assert False

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    r"""
    This is the equivalent of torch.repeat_interleave(x, dim=3, repeats=n_rep) for a packed
    key/value tensor. The hidden states go from (batch, seqlen, 2, num_key_value_heads, head_dim)
    to (batch, seqlen, 2, num_key_value_heads * n_rep, head_dim).
    """
    batch, slen, _, num_key_value_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
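    # Insert a broadcast dimension and expand instead of copying, then let reshape
    # materialize the repeated key/value heads next to each other.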
    hidden_states = hidden_states[:, :, :, :, None, :].expand(batch, slen, 2, num_key_value_heads, n_rep, head_dim)
    return hidden_states.reshape(batch, slen, 2, num_key_value_heads * n_rep, head_dim)

class LlamaAttention(torch.nn.Module):

    def __init__(self, config: "LlamaConfig"):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = torch.nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = torch.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = torch.nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
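        # norm_factor = sqrt(head_dim); the flash-attn kernels below receive
        # softmax_scale = 1 / norm_factor, matching standard scaled dot-product attention.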
        self.register_buffer(
            "norm_factor",
            torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
            persistent=False,
        )

        if self.config.rope_scaling is None:
            scaling_factor = 1
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
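            # Only linear (position-interpolation) RoPE scaling is supported by this patch.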
            assert scaling_type == "linear"

        self.rotary_emb = FlashRotaryEmbedding(
            self.head_dim, base=10000, interleaved=False, scaling_factor=scaling_factor
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, h_size = hidden_states.size()

        has_layer_past = past_key_value is not None

        if has_layer_past:
            past_kv = past_key_value[0]
            past_len = past_key_value[1]
        else:
            past_len = 0

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        q = q.view(bsz, q_len, self.num_heads, self.head_dim)
        k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        v = v.view(bsz, q_len, self.num_key_value_heads, self.head_dim)

        q, k = self.rotary_emb(q, k, past_len)

        kv = torch.stack([k, v], 2)
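        # Grouped-query attention: replicate the key/value heads so the packed kv tensor
        # matches the number of query heads expected by the flash-attn kernels.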
        kv = repeat_kv(kv, self.num_key_value_groups)

        # Cache KV values
        if has_layer_past:
            new_len = past_len + q.size(1)
            if new_len > past_kv.size(1):
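                # Grow the cache in blocks of 256 positions so it is not reallocated
                # on every decoding step.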
                past_kv = torch.cat(
                    [past_kv, torch.empty(bsz, 256, 2, kv.size(3), kv.size(4), dtype=kv.dtype, device=kv.device)], 1
                )
            past_kv[:, past_len:new_len] = kv
            kv = past_kv[:, :new_len]
        else:
            past_kv = kv

        past_key_value = (past_kv, past_len + q.size(1)) if use_cache else None

        if attention_mask is not None:
            # varlen, ignore padding tokens, efficient for large batch with many paddings
            logger.warning_once("Padded sequences are less efficient for FlashAttention.")
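            # unpad_input flattens the non-padding tokens into one long sequence and returns
            # the cumulative sequence lengths (cu_seqlens) that the varlen kernel indexes with.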
            unpadded_kv, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(kv, attention_mask)
            unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask[:, -q.size(1):])
            attn_outputs = flash_attn_varlen_kvpacked_func(
                unpadded_q, unpadded_kv, cu_seqlens_q, cu_seqlens_k,
                max_seqlen_q, max_seqlen_k,
                dropout_p=0.0, softmax_scale=1.0 / self.norm_factor,
                causal=(not has_layer_past), return_attn_probs=output_attentions
            )

            attn_output = attn_outputs[0] if output_attentions else attn_outputs
            attn_output = pad_input(
                attn_output, indices_q, bsz, q_len
            ).reshape(bsz, q_len, h_size)
            attn_weights = attn_outputs[2] if output_attentions else None

        else:
            # no padding tokens, more efficient
            attn_outputs = flash_attn_kvpacked_func(
                q, kv, dropout_p=0.0, softmax_scale=1.0 / self.norm_factor,
                causal=(not has_layer_past), return_attn_probs=output_attentions
            )
            attn_output = attn_outputs[0] if output_attentions else attn_outputs
            attn_output = attn_output.reshape(bsz, q_len, h_size)
            attn_weights = attn_outputs[2] if output_attentions else None

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


# Disable the transformation of the attention mask in LlamaModel as flash attention
# takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    if past_key_values_length > 0 and attention_mask is not None:
        attention_mask = torch.cat(
            (
                torch.full(
                    (input_shape[0], past_key_values_length),
                    True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device
                ),
                attention_mask
            ),
            dim=-1
        )

    if attention_mask is not None and torch.all(attention_mask):
        return None  # This uses the faster call when training with full samples

    return attention_mask
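

# Usage sketch (an assumed wiring, not guaranteed to match the surrounding project):
# the classes and functions above are meant to be monkey-patched into transformers
# before the model is instantiated, roughly along these lines:
#
#     from transformers.models.llama import modeling_llama
#     modeling_llama.LlamaRMSNorm = LlamaRMSNorm
#     modeling_llama.LlamaAttention = LlamaAttention
#     modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask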