fix gemma2 attention

Former-commit-id: aeafc68e169ae0ea5939cc81cb0cf89f0ca044b6
Author: hiyouga
Date: 2024-07-13 23:33:45 +08:00
parent 37c075de1c
commit 9cd850c3b9
7 changed files with 53 additions and 26 deletions


@@ -326,7 +326,7 @@ def llama_sdpa_attention_forward(
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.42.3", "To fix: pip install transformers>=4.41.2,<=4.42.3")
+    require_version("transformers>=4.41.2,<=4.42.4", "To fix: pip install transformers>=4.41.2,<=4.42.4")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
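
For context, the only change in this hunk is the upper bound of the version gate that guards the attention monkey patch. A minimal, self-contained sketch of how that gate behaves (illustrative only, not part of the commit): `require_version` from `transformers.utils.versions` raises an ImportError carrying the "To fix" hint whenever the installed transformers release falls outside the requested range, so widening the bound to 4.42.4 lets the patch be applied on that release.

from transformers.utils.versions import require_version

try:
    # Same requirement string and hint as the patched line above.
    require_version(
        "transformers>=4.41.2,<=4.42.4",
        "To fix: pip install transformers>=4.41.2,<=4.42.4",
    )
    print("transformers version is within the patched range")
except ImportError as err:
    # e.g. transformers 4.43.x would be rejected here, before the
    # LlamaAttention forward methods are ever replaced.
    print(f"version gate rejected the install: {err}")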