From 180fd06e61fed232bf293fc1986a36db53e82a5e Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 9 Oct 2023 17:54:48 +0800
Subject: [PATCH] fix flash shift short attention

Former-commit-id: 0a356bc897690262190a8112e8ace37d349daee1
---
 src/llmtuner/extras/patches/llama_patch.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/extras/patches/llama_patch.py
index 046fe094..e516df76 100644
--- a/src/llmtuner/extras/patches/llama_patch.py
+++ b/src/llmtuner/extras/patches/llama_patch.py
@@ -61,7 +61,9 @@ class LlamaShiftShortAttention(LlamaAttention):
             num_groups = q_len // groupsz
             def shift(state: torch.Tensor) -> torch.Tensor:
                 state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
-                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                state = torch.cat((
+                    state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                ), dim=2)
                 return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
 
             query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
@@ -80,7 +82,9 @@ class LlamaShiftShortAttention(LlamaAttention):
 
         if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
             attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
-            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            attn_output = torch.cat((
+                attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            ))
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
@@ -151,7 +155,9 @@ class LlamaFlashAttention2(LlamaAttention):
             assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
             num_groups = q_len // groupsz
             def shift(state: torch.Tensor) -> torch.Tensor:
-                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                state = torch.cat((
+                    state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                ), dim=2)
                 return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
 
             query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
@@ -184,7 +190,9 @@ class LlamaFlashAttention2(LlamaAttention):
 
         if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
             attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
-            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            attn_output = torch.cat((
+                attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            ))
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.o_proj(attn_output)
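
Editor's note: the patch replaces in-place slice assignments with `roll` by an out-of-place
`torch.cat`, so the grouped "shift short attention" (S^2-Attn-style) shift no longer mutates a
tensor that autograd may still need. Below is a minimal, self-contained sketch of that shift /
shift-back pattern, assuming the flash-attention layout (bsz, seq_len, n_heads, head_dim) used in
the second pair of hunks. The function names, toy shapes, and the explicit `dim=2` on the
shift-back `torch.cat` (which the hunks above leave at torch.cat's default of dim=0) are
illustrative choices of this sketch, not code from llama_patch.py.

import torch

def shift_half_heads(state: torch.Tensor, groupsz: int) -> torch.Tensor:
    # state: (bsz, seq_len, n_heads, head_dim). Roll the second half of the heads
    # backwards by half a group along the sequence dim, out-of-place via torch.cat,
    # then fold the sequence into groups of size groupsz.
    bsz, seq_len, n_heads, head_dim = state.size()
    num_groups = seq_len // groupsz
    state = torch.cat((
        state[:, :, :n_heads // 2],
        state[:, :, n_heads // 2:].roll(-groupsz // 2, dims=1),
    ), dim=2)  # concatenate along the head dim; no in-place mutation
    return state.reshape(bsz * num_groups, groupsz, n_heads, head_dim)

def unshift_half_heads(attn_output: torch.Tensor, bsz: int, groupsz: int) -> torch.Tensor:
    # attn_output: (bsz * num_groups, groupsz, n_heads, head_dim). Merge the groups
    # back into the full sequence, then roll the second half of the heads forward again.
    n_heads, head_dim = attn_output.size(-2), attn_output.size(-1)
    seq_len = attn_output.size(0) // bsz * groupsz
    attn_output = attn_output.reshape(bsz, seq_len, n_heads, head_dim)
    return torch.cat((
        attn_output[:, :, :n_heads // 2],
        attn_output[:, :, n_heads // 2:].roll(groupsz // 2, dims=1),
    ), dim=2)  # dim=2 spelled out; torch.cat defaults to dim=0 when omitted

if __name__ == "__main__":
    x = torch.randn(2, 16, 8, 4)  # (bsz, seq_len, n_heads, head_dim)
    y = unshift_half_heads(shift_half_heads(x, groupsz=4), bsz=2, groupsz=4)
    assert torch.equal(x, y)  # shift followed by shift-back is the identity

The non-flash LlamaShiftShortAttention hunks differ only in that the states arrive as
(bsz, n_heads, seq_len, head_dim) and are transposed before and after the same shift.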