Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-15 03:10:35 +08:00)
@@ -35,7 +35,7 @@ from transformers.utils.versions import require_version

 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
-from ...extras.packages import is_transformers_version_greater_than_4_43
+from ...extras.packages import is_transformers_version_greater_than


 if TYPE_CHECKING:
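Note: the hard-coded is_transformers_version_greater_than_4_43 check is replaced by a parameterized helper that takes the version string as an argument. A minimal sketch of what such a helper can look like (the real definition lives in the project's extras.packages module; the exact body below is an assumption):

# Sketch only: assumed shape of the parameterized version check; the actual
# implementation in extras/packages.py may differ in details.
from functools import lru_cache

from packaging import version
import transformers


@lru_cache
def is_transformers_version_greater_than(content: str) -> bool:
    # True when the installed transformers version is at least `content`,
    # e.g. is_transformers_version_greater_than("4.43.0").
    return version.parse(transformers.__version__) >= version.parse(content)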
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)

-    if is_transformers_version_greater_than_4_43():
+    if is_transformers_version_greater_than("4.43.0"):
         from transformers.modeling_flash_attention_utils import _flash_attention_forward

         attn_output: "torch.Tensor" = _flash_attention_forward(
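Note: the gate above exists because the standalone _flash_attention_forward helper is only available in newer transformers releases (roughly 4.43 and later). A small, self-contained probe of what the gate is effectively checking for (illustrative only, not part of the patch):

# Sketch only: check whether the installed transformers exposes the standalone
# flash-attention helper that the >= 4.43 code path imports.
import importlib

try:
    flash_utils = importlib.import_module("transformers.modeling_flash_attention_utils")
    has_standalone_forward = hasattr(flash_utils, "_flash_attention_forward")
except ImportError:
    has_standalone_forward = False

print(has_standalone_forward)  # expected True on transformers >= 4.43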
@@ -45,7 +45,7 @@ from transformers.utils.versions import require_version

 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
-from ...extras.packages import is_transformers_version_greater_than_4_43
+from ...extras.packages import is_transformers_version_greater_than


 if TYPE_CHECKING:
@@ -115,7 +115,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "torch.Tensor", int]:

 def _patch_for_block_diag_attn(model_type: str) -> None:
     require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-    if is_transformers_version_greater_than_4_43():
+    if is_transformers_version_greater_than("4.43.0"):
         import transformers.modeling_flash_attention_utils

         transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
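Note: in this hunk get_unpad_data is monkey-patched over the transformers helper so that FlashAttention-2 treats packed training examples as separate sequences, i.e. with a block-diagonal attention pattern. The key piece of information it hands over is the cumulative sequence lengths; a minimal illustration with made-up numbers:

# Sketch only: for a row that packs three documents of lengths 5, 3 and 4, the
# unpadding helper provides cumulative lengths so attention stays block-diagonal
# across the packed documents instead of attending across document boundaries.
import torch
import torch.nn.functional as F

seqlens_in_batch = torch.tensor([5, 3, 4], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
max_seqlen_in_batch = int(seqlens_in_batch.max())

print(cu_seqlens.tolist())  # [0, 5, 8, 12]
print(max_seqlen_in_batch)  # 5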
@@ -26,7 +26,7 @@ from ...extras import logging


 if TYPE_CHECKING:
-    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin

     from ...hparams import FinetuningArguments, ModelArguments

@@ -159,27 +159,25 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
|
||||
image_seqlen = config.vision_config.num_image_tokens
|
||||
else:
|
||||
image_seqlen = -1
|
||||
elif model_type == "mllama":
|
||||
image_seqlen = (
|
||||
(config.vision_config.image_size // config.vision_config.patch_size) ** 2 + 1
|
||||
) * config.vision_config.max_num_tiles
|
||||
|
||||
return image_seqlen
|
||||
|
||||
|
||||
def get_patch_size(config: "PretrainedConfig") -> int:
|
||||
def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
|
||||
r"""
|
||||
Computes the patch size of the vit.
|
||||
"""
|
||||
patch_size = getattr(config.vision_config, "patch_size", -1)
|
||||
patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
|
||||
return patch_size
|
||||
|
||||
|
||||
def get_vision_feature_select_strategy(config: "PretrainedConfig") -> int:
|
||||
def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
|
||||
r"""
|
||||
Get the vision_feature_select_strategy.
|
||||
"""
|
||||
vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
|
||||
vision_feature_select_strategy = getattr(
|
||||
config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
|
||||
)
|
||||
return vision_feature_select_strategy
|
||||
|
||||
|
||||
|
||||
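Note: the two helpers above now fall back to the processor when the model config does not carry patch_size or vision_feature_select_strategy, which newer transformers releases tend to store on the processor side. A tiny stand-alone illustration of the lookup order, using stand-in objects rather than real transformers classes:

# Sketch only: config first, processor second, hard-coded default last.
from types import SimpleNamespace

config = SimpleNamespace(vision_config=SimpleNamespace())  # config lacks patch_size
processor = SimpleNamespace(patch_size=14)                 # processor provides it

patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
print(patch_size)  # 14, taken from the processor fallback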
@@ -66,11 +66,11 @@ def patch_processor(
     setattr(processor, "tokenizer", tokenizer)
     setattr(processor, "image_seqlen", get_image_seqlen(config))
     setattr(processor, "image_resolution", model_args.image_resolution)
-    setattr(processor, "patch_size", get_patch_size(config))
+    setattr(processor, "patch_size", get_patch_size(config, processor))
     setattr(processor, "video_resolution", model_args.video_resolution)
     setattr(processor, "video_fps", model_args.video_fps)
     setattr(processor, "video_maxlen", model_args.video_maxlen)
-    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
+    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))


 def patch_config(