mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-03 04:02:49 +08:00)
parent 414049ba20
commit 68540734fb
@@ -41,9 +41,9 @@ def llama_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
+    query_states: "torch.Tensor" = self.q_proj(hidden_states)
+    key_states: "torch.Tensor" = self.k_proj(hidden_states)
+    value_states: "torch.Tensor" = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
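
For context, a minimal self-contained sketch (not from this repository; the config values 32/4/2/8 below are made up for illustration) of the projection → view → transpose pattern that these hunks annotate. The `: "torch.Tensor"` additions in the diff are string (forward-reference) type hints: they have no runtime effect and only help static checkers and editors.

    # Illustrative sketch only; hidden_size, head counts and head_dim are assumptions.
    import torch
    import torch.nn as nn

    bsz, q_len, hidden_size = 2, 5, 32
    num_heads, num_key_value_heads, head_dim = 4, 2, 8

    hidden_states = torch.randn(bsz, q_len, hidden_size)
    q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
    k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=False)
    v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=False)

    # Same shape of statements as in the patched forward: annotate, then split into heads.
    query_states: "torch.Tensor" = q_proj(hidden_states)
    key_states: "torch.Tensor" = k_proj(hidden_states)
    value_states: "torch.Tensor" = v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, num_key_value_heads, head_dim).transpose(1, 2)

    print(query_states.shape)  # torch.Size([2, 4, 5, 8]) -> (bsz, num_heads, q_len, head_dim)
    print(key_states.shape)    # torch.Size([2, 2, 5, 8]) -> (bsz, num_key_value_heads, q_len, head_dim)
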
@@ -125,9 +125,9 @@ def llama_flash_attention_2_forward(
 
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
+    query_states: "torch.Tensor" = self.q_proj(hidden_states)
+    key_states: "torch.Tensor" = self.k_proj(hidden_states)
+    value_states: "torch.Tensor" = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -233,9 +233,9 @@ def llama_sdpa_attention_forward(
 
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = self.q_proj(hidden_states)
-    key_states = self.k_proj(hidden_states)
-    value_states = self.v_proj(hidden_states)
+    query_states: "torch.Tensor" = self.q_proj(hidden_states)
+    key_states: "torch.Tensor" = self.k_proj(hidden_states)
+    value_states: "torch.Tensor" = self.v_proj(hidden_states)
 
     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
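
A brief aside on why the key/value reshapes above use `num_key_value_heads` rather than `num_heads`: Llama uses grouped-query attention, so each KV head is later expanded to serve several query heads. The sketch below (illustrative only; sizes are assumptions, and the expansion mirrors the idea of the repeat_kv helper in the transformers Llama attention code rather than importing it).

    # Illustrative sketch of grouped-query KV head expansion; all sizes are assumptions.
    import torch

    bsz, q_len, head_dim = 2, 5, 8
    num_heads, num_key_value_heads = 4, 2     # hypothetical config values
    n_rep = num_heads // num_key_value_heads  # each KV head serves n_rep query heads

    key_states = torch.randn(bsz, num_key_value_heads, q_len, head_dim)

    # Expand KV heads so they line up with the query heads (expand + reshape, no copy of data
    # beyond the final reshape).
    expanded = key_states[:, :, None, :, :].expand(bsz, num_key_value_heads, n_rep, q_len, head_dim)
    expanded = expanded.reshape(bsz, num_heads, q_len, head_dim)
    print(expanded.shape)  # torch.Size([2, 4, 5, 8])
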
@@ -270,8 +270,9 @@ def llama_sdpa_attention_forward(
 
     causal_mask = attention_mask
     if attention_mask is not None:
-        causal_mask = causal_mask[:, :, :, :groupsz]
+        causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
 
+    if query_states.device.type == "cuda" and causal_mask is not None:
         query_states = query_states.contiguous()
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
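
Below is a minimal, self-contained sketch (not part of the commit; all tensor sizes are made up) of the two behaviours this last hunk touches: slicing the 4D attention mask to the current key length, and making the query/key/value tensors contiguous before calling torch.nn.functional.scaled_dot_product_attention on CUDA.

    # Illustrative sketch only; shapes and the padded mask width are assumptions.
    import torch
    import torch.nn.functional as F

    bsz, num_heads, q_len, kv_len, head_dim = 2, 4, 5, 7, 8

    query_states = torch.randn(bsz, num_heads, q_len, head_dim)
    key_states = torch.randn(bsz, num_heads, kv_len, head_dim)
    value_states = torch.randn(bsz, num_heads, kv_len, head_dim)

    # A prepared 4D additive mask can be wider than the current key/value length; slicing to
    # key_states.shape[-2] keeps it aligned with the keys actually attended to.
    attention_mask = torch.zeros(bsz, 1, q_len, kv_len + 3)
    causal_mask = attention_mask
    if attention_mask is not None:
        causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

    # transpose(1, 2) earlier in the forward pass leaves these tensors non-contiguous; some SDPA
    # backends are picky about non-contiguous CUDA inputs with a custom mask, hence the guard.
    if query_states.device.type == "cuda" and causal_mask is not None:
        query_states = query_states.contiguous()
        key_states = key_states.contiguous()
        value_states = value_states.contiguous()

    attn_output = F.scaled_dot_product_attention(
        query_states, key_states, value_states, attn_mask=causal_mask
    )
    print(attn_output.shape)  # torch.Size([2, 4, 5, 8])
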