This commit is contained in:
hiyouga
2024-08-05 23:48:19 +08:00
parent c2921b9960
commit b7ca6c8dc1
13 changed files with 111 additions and 69 deletions

View File

@@ -36,7 +36,7 @@ def configure_attn_implementation(
if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
if is_flash_attn_2_available():
require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
require_version("flash_attn>=2.6.0", "To fix: pip install flash_attn>=2.6.0")
require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
model_args.flash_attn = "fa2"
else:

View File

@@ -35,6 +35,7 @@ from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43
if TYPE_CHECKING:
@@ -50,14 +51,15 @@ transformers_logger = logging.get_logger(__name__)
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_attention_forward(
self: "LlamaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
@@ -68,7 +70,11 @@ def llama_attention_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -130,14 +136,15 @@ def llama_attention_forward(
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_flash_attention_2_forward(
self: "LlamaFlashAttention2",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
@@ -151,7 +158,11 @@ def llama_flash_attention_2_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -198,9 +209,24 @@ def llama_flash_attention_2_forward(
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
attn_output: "torch.Tensor" = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
if is_transformers_version_greater_than_4_43():
from transformers.modeling_flash_attention_utils import _flash_attention_forward
attn_output: "torch.Tensor" = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
query_states.size(1),
dropout=dropout_rate,
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)
else:
attn_output: "torch.Tensor" = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
@@ -225,14 +251,15 @@ def llama_flash_attention_2_forward(
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
def llama_sdpa_attention_forward(
self: "LlamaSdpaAttention",
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
hidden_states: "torch.Tensor",
attention_mask: Optional["torch.Tensor"] = None,
position_ids: Optional["torch.LongTensor"] = None,
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
if output_attentions:
transformers_logger.warning_once(
"SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
@@ -258,7 +285,11 @@ def llama_sdpa_attention_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
if position_embeddings is None:
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
@@ -322,7 +353,7 @@ def llama_sdpa_attention_forward(
def _apply_llama_patch() -> None:
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
LlamaAttention.forward = llama_attention_forward
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
LlamaSdpaAttention.forward = llama_sdpa_attention_forward

View File

@@ -41,11 +41,11 @@ from typing import TYPE_CHECKING, Tuple
import torch
import torch.nn.functional as F
import transformers.models
from transformers.utils.versions import require_version
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.logging import get_logger
from ...extras.packages import is_transformers_version_greater_than_4_43
if TYPE_CHECKING:
@@ -114,7 +114,15 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
require_version("transformers>=4.41.2,<=4.43.4", "To fix: pip install transformers>=4.41.2,<=4.43.4")
if is_transformers_version_greater_than_4_43():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
return
import transformers.models
if model_type == "cohere":
transformers.models.cohere.modeling_cohere._get_unpad_data = get_unpad_data
elif model_type == "falcon":