fix inputs

author hiyouga
date 2024-11-23 18:25:45 +00:00
parent b1e43e56db
commit 446441fdb0
14 changed files with 148 additions and 95 deletions


@@ -35,7 +35,7 @@ from transformers.utils.versions import require_version
 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
-from ...extras.packages import is_transformers_version_greater_than_4_43
+from ...extras.packages import is_transformers_version_greater_than


 if TYPE_CHECKING:
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
     if attention_mask is not None:
         attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)

-    if is_transformers_version_greater_than_4_43():
+    if is_transformers_version_greater_than("4.43.0"):
         from transformers.modeling_flash_attention_utils import _flash_attention_forward

         attn_output: "torch.Tensor" = _flash_attention_forward(
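
The change parameterizes the transformers version check: the hardcoded helper is_transformers_version_greater_than_4_43 is replaced by is_transformers_version_greater_than, which takes the threshold version as a string. A minimal sketch of what such a helper can look like; the _get_package_version name here is hypothetical, and the real code in extras/packages.py may differ:

from functools import lru_cache

from packaging import version


@lru_cache
def _get_package_version(name: str) -> "version.Version":
    # Hypothetical helper: resolve the installed version of a package,
    # falling back to 0.0.0 when the package cannot be found.
    try:
        from importlib.metadata import version as get_version

        return version.parse(get_version(name))
    except Exception:
        return version.parse("0.0.0")


def is_transformers_version_greater_than(content: str) -> bool:
    # Compare the installed transformers version against the given
    # threshold, e.g. is_transformers_version_greater_than("4.43.0").
    return _get_package_version("transformers") >= version.parse(content)

Taking the version as a parameter means a future cutoff (say, "4.46.0") can be checked directly at the call site instead of adding a new single-purpose helper for each release.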