Mirror of https://github.com/hiyouga/LLaMA-Factory.git
[misc] upgrade format to py39 (#7256)
@@ -18,7 +18,7 @@
 # limitations under the License.

 import math
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Optional

 import torch
 import torch.nn as nn
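The import change above is the substance of the "py39" upgrade: since Python 3.9, PEP 585 lets the builtin tuple be subscripted in annotations, so typing.Tuple can be dropped. A minimal before/after sketch with a made-up function, not code from this file:

from typing import Optional, Tuple  # pre-3.9 style needs typing.Tuple


def minmax_old(values: Tuple[int, ...]) -> Optional[Tuple[int, int]]:
    return (min(values), max(values)) if values else None


# Python 3.9+ style: builtin generics, the typing.Tuple import disappears
def minmax_new(values: tuple[int, ...]) -> Optional[tuple[int, int]]:
    return (min(values), max(values)) if values else None


print(minmax_new((3, 1, 2)))  # (1, 3)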
@@ -54,14 +54,14 @@ def llama_attention_forward(
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
     cache_position: Optional["torch.LongTensor"] = None,
-    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
+    position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
+) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
     bsz, q_len, _ = hidden_states.size()

-    query_states: "torch.Tensor" = self.q_proj(hidden_states)
-    key_states: "torch.Tensor" = self.k_proj(hidden_states)
-    value_states: "torch.Tensor" = self.v_proj(hidden_states)
+    query_states: torch.Tensor = self.q_proj(hidden_states)
+    key_states: torch.Tensor = self.k_proj(hidden_states)
+    value_states: torch.Tensor = self.v_proj(hidden_states)

     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
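The other repeated edit in this commit drops the quotes from local-variable annotations such as query_states: "torch.Tensor". Quoting exists to defer evaluation of an annotation, but per PEP 526 annotations on local variables are never evaluated at runtime, so the unquoted form behaves identically and reads better. A tiny illustrative snippet (not from the repo) showing that local annotations are not evaluated:

def forward(hidden_states):
    # PEP 526: local variable annotations are never evaluated at runtime,
    # so this runs even though NoSuchType is undefined (a type checker
    # would still flag it).
    projected: NoSuchType = hidden_states * 2  # type: ignore[name-defined]
    return projected


print(forward(3))  # 6 -- no NameError from the annotation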
@@ -139,17 +139,17 @@ def llama_flash_attention_2_forward(
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
     cache_position: Optional["torch.LongTensor"] = None,
-    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
+    position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
+) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
     # LlamaFlashAttention2 attention does not support output_attentions
     output_attentions = False

     bsz, q_len, _ = hidden_states.size()

-    query_states: "torch.Tensor" = self.q_proj(hidden_states)
-    key_states: "torch.Tensor" = self.k_proj(hidden_states)
-    value_states: "torch.Tensor" = self.v_proj(hidden_states)
+    query_states: torch.Tensor = self.q_proj(hidden_states)
+    key_states: torch.Tensor = self.k_proj(hidden_states)
+    value_states: torch.Tensor = self.v_proj(hidden_states)

     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
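The unchanged context lines around these edits keep the standard multi-head reshape: each projection of shape [bsz, q_len, hidden_size] is viewed as [bsz, q_len, num_heads, head_dim] and transposed to [bsz, num_heads, q_len, head_dim]. A standalone sketch with toy sizes (not the repo's configuration):

import torch

bsz, q_len, num_heads, head_dim = 2, 5, 4, 8
hidden_states = torch.randn(bsz, q_len, num_heads * head_dim)

# [bsz, q_len, hidden] -> [bsz, q_len, num_heads, head_dim] -> [bsz, num_heads, q_len, head_dim]
query_states = hidden_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
print(query_states.shape)  # torch.Size([2, 4, 5, 8])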
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
     if is_transformers_version_greater_than("4.43.0"):
         from transformers.modeling_flash_attention_utils import _flash_attention_forward

-        attn_output: "torch.Tensor" = _flash_attention_forward(
+        attn_output: torch.Tensor = _flash_attention_forward(
             query_states,
             key_states,
             value_states,
@@ -221,7 +221,7 @@ def llama_flash_attention_2_forward(
             is_causal=self.is_causal,
         )
     else:
-        attn_output: "torch.Tensor" = self._flash_attention_forward(
+        attn_output: torch.Tensor = self._flash_attention_forward(
             query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
         )

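Both branches of the version gate above end up with the same attn_output: torch.Tensor annotation; only the call site changes with the installed transformers version. A rough sketch of how such a check might look, using packaging.version as an assumed stand-in for the repo's own is_transformers_version_greater_than helper:

# Assumption: a packaging.version-based check; the repo's actual helper may differ.
from packaging import version

import transformers


def is_transformers_version_greater_than(target: str) -> bool:
    return version.parse(transformers.__version__) >= version.parse(target)


if is_transformers_version_greater_than("4.43.0"):
    # newer transformers exposes a module-level flash-attention helper
    from transformers.modeling_flash_attention_utils import _flash_attention_forward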
@@ -254,9 +254,9 @@ def llama_sdpa_attention_forward(
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
     cache_position: Optional["torch.LongTensor"] = None,
-    position_embeddings: Optional[Tuple["torch.Tensor", "torch.Tensor"]] = None,
+    position_embeddings: Optional[tuple["torch.Tensor", "torch.Tensor"]] = None,
     **kwargs,
-) -> Tuple["torch.Tensor", Optional["torch.Tensor"], Optional[Tuple["torch.Tensor"]]]:
+) -> tuple["torch.Tensor", Optional["torch.Tensor"], Optional[tuple["torch.Tensor"]]]:
     if output_attentions:
         transformers_logger.warning_once(
             "SDPA does not support `output_attentions=True`. Falling back to the vanilla attention"
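The warning in this hunk exists because torch.nn.functional.scaled_dot_product_attention returns only the attention output, never the attention weights, so output_attentions=True has to fall back to the eager path. A small self-contained SDPA call with toy tensors:

import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 5, 8)  # [bsz, num_heads, q_len, head_dim]
k = torch.randn(2, 4, 5, 8)
v = torch.randn(2, 4, 5, 8)

# Returns only the attention output; per-head attention weights are not exposed,
# which is why the eager path is used when output_attentions=True.
attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(attn_output.shape)  # torch.Size([2, 4, 5, 8])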
@@ -274,9 +274,9 @@ def llama_sdpa_attention_forward(

     bsz, q_len, _ = hidden_states.size()

-    query_states: "torch.Tensor" = self.q_proj(hidden_states)
-    key_states: "torch.Tensor" = self.k_proj(hidden_states)
-    value_states: "torch.Tensor" = self.v_proj(hidden_states)
+    query_states: torch.Tensor = self.q_proj(hidden_states)
+    key_states: torch.Tensor = self.k_proj(hidden_states)
+    value_states: torch.Tensor = self.v_proj(hidden_states)

     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
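One detail visible in the unchanged context: key/value projections are reshaped with self.num_key_value_heads, which may be smaller than self.num_heads (grouped-query attention). A generic toy sketch of broadcasting fewer KV heads up to the query-head count; this is an illustration, not the repo's exact helper:

import torch

bsz, q_len, head_dim = 2, 5, 8
num_heads, num_key_value_heads = 8, 2  # 4 query heads share each KV head

key_states = torch.randn(bsz, num_key_value_heads, q_len, head_dim)

# Expand [bsz, num_kv_heads, q_len, head_dim] -> [bsz, num_heads, q_len, head_dim]
n_rep = num_heads // num_key_value_heads
key_states = key_states.repeat_interleave(n_rep, dim=1)
print(key_states.shape)  # torch.Size([2, 8, 5, 8])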