Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-18 12:50:38 +08:00)
fix inputs
@@ -35,7 +35,7 @@ from transformers.utils.versions import require_version
 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
-from ...extras.packages import is_transformers_version_greater_than_4_43
+from ...extras.packages import is_transformers_version_greater_than
 
 
 if TYPE_CHECKING:
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
 
-    if is_transformers_version_greater_than_4_43():
+    if is_transformers_version_greater_than("4.43.0"):
         from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
         attn_output: "torch.Tensor" = _flash_attention_forward(
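For reference, a minimal sketch of what a parameterized version check such as is_transformers_version_greater_than could look like. The _get_package_version helper, the caching, and the fallback behavior are assumptions for illustration, not the repository's actual ...extras.packages implementation:

# Sketch only: a parameterized version-check helper (assumed names and
# structure, not the repository's actual ...extras.packages module).
from functools import lru_cache

from packaging import version


def _get_package_version(name: str) -> "version.Version":
    # Return the installed version of a package, or 0.0.0 if it is missing.
    try:
        from importlib.metadata import version as get_installed_version

        return version.parse(get_installed_version(name))
    except Exception:
        return version.parse("0.0.0")


@lru_cache
def is_transformers_version_greater_than(content: str) -> bool:
    # Takes the threshold as an argument, e.g.
    # is_transformers_version_greater_than("4.43.0"), instead of hard-coding
    # it in the helper name as is_transformers_version_greater_than_4_43() did.
    return _get_package_version("transformers") >= version.parse(content)

Parameterizing the threshold avoids adding a new one-off helper every time a transformers release moves an internal API such as _flash_attention_forward.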