fix inputs

Former-commit-id: 446441fdb0
This commit is contained in:
hiyouga
2024-11-23 18:25:45 +00:00
parent 23fc0c863e
commit e99031daa4
14 changed files with 148 additions and 95 deletions

View File

@@ -35,7 +35,7 @@ from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
-from ...extras.packages import is_transformers_version_greater_than_4_43
+from ...extras.packages import is_transformers_version_greater_than
if TYPE_CHECKING:
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
-    if is_transformers_version_greater_than_4_43():
+    if is_transformers_version_greater_than("4.43.0"):
from transformers.modeling_flash_attention_utils import _flash_attention_forward
attn_output: "torch.Tensor" = _flash_attention_forward(

View File

@@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
-from ...extras.packages import is_transformers_version_greater_than_4_43
+from ...extras.packages import is_transformers_version_greater_than
if TYPE_CHECKING:
@@ -115,7 +115,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-    if is_transformers_version_greater_than_4_43():
+    if is_transformers_version_greater_than("4.43.0"):
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data

View File

@@ -26,7 +26,7 @@ from ...extras import logging
if TYPE_CHECKING:
-    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
from ...hparams import FinetuningArguments, ModelArguments
@@ -159,27 +159,25 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
image_seqlen = config.vision_config.num_image_tokens
else:
image_seqlen = -1
elif model_type == "mllama":
image_seqlen = (
(config.vision_config.image_size // config.vision_config.patch_size) ** 2 + 1
) * config.vision_config.max_num_tiles
return image_seqlen
-def get_patch_size(config: "PretrainedConfig") -> int:
+def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r"""
Computes the patch size of the vit.
"""
-    patch_size = getattr(config.vision_config, "patch_size", -1)
+    patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
return patch_size
-def get_vision_feature_select_strategy(config: "PretrainedConfig") -> int:
+def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r"""
Get the vision_feature_select_strategy.
"""
-    vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
+    vision_feature_select_strategy = getattr(
+        config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
+    )
return vision_feature_select_strategy

View File

@@ -66,11 +66,11 @@ def patch_processor(
setattr(processor, "tokenizer", tokenizer)
setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "image_resolution", model_args.image_resolution)
-    setattr(processor, "patch_size", get_patch_size(config))
+    setattr(processor, "patch_size", get_patch_size(config, processor))
setattr(processor, "video_resolution", model_args.video_resolution)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
-    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
+    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
def patch_config(