Mirror of https://github.com/hiyouga/LLaMA-Factory.git
@@ -96,7 +96,8 @@ def llama_attention_forward(
             (
                 attn_output[:, :, : self.num_heads // 2],
                 attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
-            )
+            ),
+            dim=2,
         )
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
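The same one-line change is applied in all three attention forwards below: the two half-head groups are stitched back together along the head dimension, so torch.cat needs an explicit dim=2 (its default is dim=0). A minimal, self-contained sketch of the effect; the shape [bsz, q_len, num_heads, head_dim] is an assumption inferred from the reshape to hidden_size that follows, and groupsz is the shift-short-attention group size:

import torch

# Assumed shapes, inferred from the surrounding patch.
bsz, q_len, num_heads, head_dim, groupsz = 1, 8, 4, 2, 4
attn_output = torch.randn(bsz, q_len, num_heads, head_dim)

# First half of the heads is kept as-is; the second half is rolled back by
# half a group along the sequence axis (dims=1), undoing the shift applied
# before attention. dim=2 rejoins the halves along the head dimension.
restored = torch.cat(
    (
        attn_output[:, :, : num_heads // 2],
        attn_output[:, :, num_heads // 2 :].roll(groupsz // 2, dims=1),
    ),
    dim=2,
)
assert restored.shape == (bsz, q_len, num_heads, head_dim)

# Without dim=2, torch.cat defaults to dim=0 and stacks the halves along
# the batch axis instead, which breaks the reshape to hidden_size:
stacked = torch.cat(
    (
        attn_output[:, :, : num_heads // 2],
        attn_output[:, :, num_heads // 2 :].roll(groupsz // 2, dims=1),
    )
)
assert stacked.shape == (2 * bsz, q_len, num_heads // 2, head_dim)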
@@ -194,7 +195,8 @@ def llama_flash_attention_2_forward(
             (
                 attn_output[:, :, : self.num_heads // 2],
                 attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
-            )
+            ),
+            dim=2,
         )
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -293,7 +295,8 @@ def llama_sdpa_attention_forward(
             (
                 attn_output[:, :, : self.num_heads // 2],
                 attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
-            )
+            ),
+            dim=2,
         )
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -303,7 +306,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2")
+    require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
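A minimal sketch of the monkey-patch pattern used by _apply_llama_patch; the function and parameter names here are hypothetical, and only the class names, the require_version call, and the new version pin come from the diff:

from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaFlashAttention2,
    LlamaSdpaAttention,
)
from transformers.utils.versions import require_version


def apply_llama_patch_sketch(eager_forward, flash_forward, sdpa_forward) -> None:
    # Fail fast if the installed transformers release differs from the one the
    # patched forwards were written against: attention signatures and cache
    # handling can change between minor versions, hence the bump to 4.41.2.
    require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
    # Assigning to the class attribute replaces forward for every instance,
    # existing or future, because Python resolves methods on the class at call
    # time. Each attention backend gets its matching patched forward.
    LlamaAttention.forward = eager_forward
    LlamaFlashAttention2.forward = flash_forward
    LlamaSdpaAttention.forward = sdpa_forward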