mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 12:42:51 +08:00)

commit 180fd06e61 (parent 9c6d34020c)
fix flash shift short attention

Former-commit-id: 0a356bc897690262190a8112e8ace37d349daee1
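
Both attention classes previously wrote the rolled half of the heads back into the same tensor with an in-place slice assignment; this commit builds the shifted tensor with torch.cat over the untouched and the rolled head halves instead, so the shift is done out-of-place (presumably to avoid in-place modification of tensors that autograd still needs). A minimal standalone sketch of that cat-based shift, assuming the (bsz, seq_len, n_heads, head_dim) layout used here; the function name and demo values are illustrative and not part of this commit:

# Minimal sketch of the cat-based half-head shift (illustrative only).
import torch

def shift_half_heads(state: torch.Tensor, groupsz: int) -> torch.Tensor:
    # state: (bsz, seq_len, n_heads, head_dim). Keep the first half of the
    # heads as-is, roll the second half back by half a group along the
    # sequence dimension, then re-join along the head dimension.
    n_heads = state.size(2)
    return torch.cat((
        state[:, :, :n_heads // 2],
        state[:, :, n_heads // 2:].roll(-groupsz // 2, dims=1),
    ), dim=2)

x = torch.randn(2, 8, 4, 16, requires_grad=True)
y = shift_half_heads(x, groupsz=4)
y.sum().backward()   # works: no in-place write on a tensor requiring grad
print(y.shape)       # torch.Size([2, 8, 4, 16])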
@@ -61,7 +61,9 @@ class LlamaShiftShortAttention(LlamaAttention):
             num_groups = q_len // groupsz
             def shift(state: torch.Tensor) -> torch.Tensor:
                 state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
-                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                state = torch.cat((
+                    state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                ), dim=2)
                 return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
 
             query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
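
In this eager-attention path, shift() receives states in (bsz, n_heads, seq_len, head_dim) layout, so it first transposes, applies the cat-based roll, and then folds the groups into the batch dimension before transposing back. A shape-only sketch of that sequence, with illustrative sizes that are not taken from the commit:

# Shape walk-through of the grouped shift (illustrative sizes).
import torch

bsz, n_heads, q_len, head_dim = 2, 8, 1024, 64
groupsz = 256                          # e.g. group_size_ratio = 1/4
num_groups = q_len // groupsz

state = torch.randn(bsz, n_heads, q_len, head_dim)
state = state.transpose(1, 2)          # (bsz, q_len, n_heads, head_dim)
state = torch.cat((
    state[:, :, :n_heads // 2],
    state[:, :, n_heads // 2:].roll(-groupsz // 2, dims=1),
), dim=2)
state = state.reshape(bsz * num_groups, groupsz, n_heads, head_dim).transpose(1, 2)
print(state.shape)                     # torch.Size([8, 8, 256, 64]) = (bsz*num_groups, n_heads, groupsz, head_dim)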
@@ -80,7 +82,9 @@ class LlamaShiftShortAttention(LlamaAttention):
 
         if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
             attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
-            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            attn_output = torch.cat((
+                attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            ))
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
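
The shift back applies the opposite roll (+groupsz//2) to the second half of the heads so the grouped outputs line up again before the output projection. A small round-trip check, written as a sketch with an explicit dim=2 and made-up shapes, that the forward and backward rolls cancel:

# Round-trip check: shifting and then shifting back recovers the input.
import torch

def roll_half_heads(x: torch.Tensor, shift: int) -> torch.Tensor:
    # x: (bsz, seq_len, n_heads, head_dim); roll only the upper half of heads.
    n_heads = x.size(2)
    return torch.cat((
        x[:, :, :n_heads // 2],
        x[:, :, n_heads // 2:].roll(shift, dims=1),
    ), dim=2)

groupsz = 4
x = torch.randn(2, 8, 6, 16)
shifted = roll_half_heads(x, -groupsz // 2)         # shift
restored = roll_half_heads(shifted, groupsz // 2)   # shift back
print(torch.equal(restored, x))                     # True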
@@ -151,7 +155,9 @@ class LlamaFlashAttention2(LlamaAttention):
             assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
             num_groups = q_len // groupsz
             def shift(state: torch.Tensor) -> torch.Tensor:
-                state[:, :, self.num_heads//2:] = state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                state = torch.cat((
+                    state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
+                ), dim=2)
                 return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
 
             query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
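
Unlike the eager path, the FlashAttention-2 shift() needs no transposes: the states are already in the (bsz, seq_len, n_heads, head_dim) layout that flash-attn consumes, so the function only applies the cat-based roll and folds the groups into the batch dimension. A brief shape sketch with illustrative sizes:

# Flash-path sketch: same cat-based roll, no transpose (illustrative sizes).
import torch

bsz, q_len, n_heads, head_dim = 2, 1024, 8, 64
groupsz = 256
num_groups = q_len // groupsz

state = torch.randn(bsz, q_len, n_heads, head_dim)
state = torch.cat((
    state[:, :, :n_heads // 2],
    state[:, :, n_heads // 2:].roll(-groupsz // 2, dims=1),
), dim=2)
state = state.reshape(bsz * num_groups, groupsz, n_heads, head_dim)
print(state.shape)   # torch.Size([8, 256, 8, 64]) = (bsz*num_groups, groupsz, n_heads, head_dim)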
@@ -184,7 +190,9 @@ class LlamaFlashAttention2(LlamaAttention):
 
         if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
             attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
-            attn_output[:, :, self.num_heads//2:] = attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            attn_output = torch.cat((
+                attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
+            ))
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.o_proj(attn_output)