From 2f164c2c41993a93129edef5ce692db483d31423 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Tue, 11 Jun 2024 00:37:17 +0800
Subject: [PATCH] fix #4160

The split heads should be concatenated in dim=2

Former-commit-id: a793e8456b664ea0b48f0ba162999f18d06b4c2f
---
 src/llamafactory/model/model_utils/longlora.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py
index c8dc52f5..cd468979 100644
--- a/src/llamafactory/model/model_utils/longlora.py
+++ b/src/llamafactory/model/model_utils/longlora.py
@@ -96,7 +96,8 @@ def llama_attention_forward(
             (
                 attn_output[:, :, : self.num_heads // 2],
                 attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
-            )
+            ),
+            dim=2,
         )
 
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -194,7 +195,8 @@ def llama_flash_attention_2_forward(
             (
                 attn_output[:, :, : self.num_heads // 2],
                 attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
-            )
+            ),
+            dim=2,
         )
 
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -293,7 +295,8 @@ def llama_sdpa_attention_forward(
             (
                 attn_output[:, :, : self.num_heads // 2],
                 attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1),
-            )
+            ),
+            dim=2,
         )
 
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -303,7 +306,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers==4.40.2", "To fix: pip install transformers==4.40.2")
+    require_version("transformers==4.41.2", "To fix: pip install transformers==4.41.2")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
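
For context: at the point of the fix, attn_output has shape (bsz, q_len, num_heads, head_dim), so the two head halves in LongLoRA's shift-back step must be rejoined along the head dimension (dim=2). Without the explicit dim argument, torch.cat defaults to dim=0 and stacks the halves along the batch dimension, and the later reshape to (bsz, q_len, hidden_size) then silently scrambles data across batch entries. Below is a minimal standalone sketch of that shape behavior; the tensor sizes and the groupsz value are made up for illustration, and only the torch.cat calls mirror the patched code.

import torch

# Made-up sizes; only torch.cat's dim argument is the point here.
bsz, q_len, num_heads, head_dim = 2, 8, 4, 16
groupsz = 4  # assumed group size for the shifted attention pattern

attn_output = torch.randn(bsz, q_len, num_heads, head_dim)

# Buggy behavior: torch.cat defaults to dim=0, stacking the head halves
# along the batch dimension -> (2 * bsz, q_len, num_heads // 2, head_dim).
wrong = torch.cat(
    (
        attn_output[:, :, : num_heads // 2],
        attn_output[:, :, num_heads // 2 :].roll(groupsz // 2, dims=1),
    )
)
print(wrong.shape)  # torch.Size([4, 8, 2, 16])

# Fixed behavior: concatenate along the head dimension, restoring the
# original (bsz, q_len, num_heads, head_dim) layout before the reshape.
right = torch.cat(
    (
        attn_output[:, :, : num_heads // 2],
        attn_output[:, :, num_heads // 2 :].roll(groupsz // 2, dims=1),
    ),
    dim=2,
)
print(right.shape)  # torch.Size([2, 8, 4, 16])

# Only the dim=2 result reshapes cleanly into (bsz, q_len, hidden_size).
hidden_size = num_heads * head_dim
print(right.reshape(bsz, q_len, hidden_size).shape)  # torch.Size([2, 8, 64])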