[misc] upgrade format to py39 (#7256)

This commit is contained in:
hoshi-hiyouga
2025-03-12 00:08:41 +08:00
committed by GitHub
parent bcd287848c
commit efa86e730c
113 changed files with 984 additions and 1407 deletions

View File

@@ -18,7 +18,7 @@
# limitations under the License.
import math
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional
import torch
import torch.nn as nn
@@ -54,14 +54,14 @@ def llama_attention_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -139,17 +139,17 @@ def llama_flash_attention_2_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
if is_transformers_version_greater_than("4.43.0"):
from transformers.modeling_flash_attention_utils import _flash_attention_forward
attn_output: "torch.Tensor" = _flash_attention_forward(
attn_output: torch.Tensor = _flash_attention_forward(
query_states,
key_states,
value_states,
@@ -221,7 +221,7 @@ def llama_flash_attention_2_forward(
is_causal=self.is_causal,
)
else:
attn_output: "torch.Tensor" = self._flash_attention_forward(
attn_output: torch.Tensor = self._flash_attention_forward(
query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
)
@@ -254,9 +254,9 @@ def llama_sdpa_attention_forward(
past_key_value: Optional["Cache"] = None,
output_attentions: bool = False,
cache_position: Optional["torch.LongTensor"] = None,
position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
**kwargs,
) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
if output_attentions:
transformers_logger.warning_once(
"SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
@@ -274,9 +274,9 @@ def llama_sdpa_attention_forward(
bsz, q_len, _ = hidden_states.size()
query_states: "torch.Tensor" = self.q_proj(hidden_states)
key_states: "torch.Tensor" = self.k_proj(hidden_states)
value_states: "torch.Tensor" = self.v_proj(hidden_states)
query_states: torch.Tensor = self.q_proj(hidden_states)
key_states: torch.Tensor = self.k_proj(hidden_states)
value_states: torch.Tensor = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)