This commit is contained in:
hiyouga
2024-04-01 21:35:18 +08:00
parent eb259cc573
commit aee634cd20
7 changed files with 37 additions and 16 deletions

View File

@@ -193,6 +193,6 @@ def llama_flash_attn_forward(
def apply_llama_patch() -> None:
require_version("transformers==4.39.1", "To fix: pip install transformers==4.39.1")
require_version("transformers==4.39.2", "To fix: pip install transformers==4.39.2")
LlamaAttention.forward = llama_torch_attn_forward
LlamaFlashAttention2.forward = llama_flash_attn_forward