Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-01 11:12:50 +08:00
tiny fix
Former-commit-id: 3f24337a8a995b145b1e8075bc23878eaa363844
parent d6632fefc9
commit e3baa5aa08
@@ -182,11 +182,9 @@ def llama_flash_attention_2_forward(
         query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
-    else:
-        groupsz = q_len
 
     attn_output: torch.Tensor = self._flash_attention_forward(
-        query_states, key_states, value_states, attention_mask, groupsz, dropout=dropout_rate
+        query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate
     )
 
     if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
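For context: after the LongLoRA-style shift applied when group_size_ratio is set during training, query_states is regrouped so that its second dimension equals the group size, while without the shift it is simply the original sequence length. Passing query_states.size(1) to _flash_attention_forward therefore covers both cases, which is why the else: groupsz = q_len fallback can be dropped. Below is a minimal sketch of that shape bookkeeping; the toy shift helper and the concrete sizes are illustrative assumptions, not the repository's implementation.

import torch

def shift(state: torch.Tensor, groupsz: int) -> torch.Tensor:
    # Toy regrouping: (bsz, q_len, num_heads, head_dim) -> (bsz * num_groups, groupsz, num_heads, head_dim).
    # The repository's shift also rolls half of the heads; that is omitted here since only the shapes matter.
    bsz, q_len, num_heads, head_dim = state.size()
    num_groups = q_len // groupsz
    return state.reshape(bsz * num_groups, groupsz, num_heads, head_dim)

bsz, q_len, num_heads, head_dim = 2, 8, 4, 16
groupsz = 4  # e.g. group_size_ratio = 0.5

query_states = torch.randn(bsz, q_len, num_heads, head_dim)

# Training path with group_size_ratio set: groups move into the batch dimension,
# so the sequence length seen by flash attention is the group size.
assert shift(query_states, groupsz).size(1) == groupsz

# Plain path (no shift): size(1) is q_len, the value the removed
# `else: groupsz = q_len` branch used to supply explicitly.
assert query_states.size(1) == q_len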