The split heads should be concatenated along dim=2


Former-commit-id: a793e8456b664ea0b48f0ba162999f18d06b4c2f
Author: hiyouga
Date:   2024-06-11 00:37:17 +08:00
Parent: d984776d35
Commit: 2f164c2c41


@@ -96,7 +96,8 @@ def llama_attention_forward(
             (
                 attn_output[:, :, : self.num_heads // 2],
                 attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
-            )
+            ),
+            dim=2,
         )
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
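The change in each hunk is only the added dim=2, but the default matters: at this point attn_output is assumed to have shape (bsz, q_len, num_heads, head_dim), which is what the following reshape to (bsz, q_len, hidden_size) expects, so the two rolled head halves must be glued back together along the head axis, dim=2. Without the argument, torch.cat concatenates along dim=0, and because the element count is unchanged the later reshape still succeeds, silently mixing heads across batch samples. A minimal sketch with toy sizes (the variable names follow the diff; the concrete numbers are made up):

import torch

# Hypothetical sizes: batch, sequence length, heads, head dim, shift group size.
bsz, q_len, num_heads, head_dim, groupsz = 2, 8, 4, 16, 4
attn_output = torch.randn(bsz, q_len, num_heads, head_dim)

halves = (
    attn_output[:, :, : num_heads // 2],
    attn_output[:, :, num_heads // 2 :].roll(groupsz // 2, dims=1),
)

wrong = torch.cat(halves)         # default dim=0 -> (2*bsz, q_len, num_heads//2, head_dim)
fixed = torch.cat(halves, dim=2)  # head axis restored -> (bsz, q_len, num_heads, head_dim)

print(wrong.shape)  # torch.Size([4, 8, 2, 16])
print(fixed.shape)  # torch.Size([2, 8, 4, 16])
# Both tensors hold the same number of elements, so reshape(bsz, q_len, -1)
# does not raise on the wrong one; it just scrambles heads across the batch.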
@@ -194,7 +195,8 @@ def llama_flash_attention_2_forward(
             (
                 attn_output[:, :, : self.num_heads // 2],
                 attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
-            )
+            ),
+            dim=2,
         )
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -293,7 +295,8 @@ def llama_sdpa_attention_forward(
             (
                 attn_output[:, :, : self.num_heads // 2],
                 attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
-            )
+            ),
+            dim=2,
         )
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -303,7 +306,7 @@ def llama_sdpa_attention_forward(
 def _apply_llama_patch() -> None:
-    require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2")
+    require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
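The last hunk only moves the pinned transformers version from 4.40.2 to 4.41.2: the patched forwards mirror the internals of one specific release, so the pin is enforced at patch time and has to track whichever release the copied code came from. A minimal sketch of how that check behaves (require_version is the helper the patch calls, provided by transformers.utils.versions; the try/except wrapper and the print are only for illustration):

from transformers.utils.versions import require_version

try:
    # Passes silently when transformers==4.41.2 is installed; otherwise raises
    # ImportError with the hint string appended to the error message.
    require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
except ImportError as err:
    print(err)  # e.g. when 4.40.2 is still installed after this change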