fix inputs

Former-commit-id: 446441fdb020b5a102480251cb8536dd8b3f8f99
This commit is contained in:
hiyouga 2024-11-23 18:25:45 +00:00
parent 23fc0c863e
commit e99031daa4
14 changed files with 148 additions and 95 deletions

3
.gitignore vendored
View File

@ -159,6 +159,9 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# vscode
.vscode/
# custom .gitignore
ms_cache/
hf_cache/

View File

@ -186,6 +186,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |

View File

@ -187,6 +187,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
| [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |

View File

@ -164,7 +164,7 @@ class HuggingfaceEngine(BaseEngine):
logits_processor=get_logits_processor(),
)
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
for key, value in mm_inputs.items():
if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value): # for pixtral inputs
value = torch.stack(value) # assume they have same sizes

View File

@ -79,7 +79,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
processor: Optional["ProcessorMixin"] = None
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
for feature in features:
images = feature.pop("images", None) or []
videos = feature.pop("videos", None) or []
@ -87,10 +87,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
batch_videos.extend(videos)
batch_imglens.append(len(images))
batch_vidlens.append(len(videos))
batch_seqlens.append(len(feature["input_ids"]))
batch_input_ids.append(feature["input_ids"])
mm_inputs = self.template.mm_plugin.get_mm_inputs(
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
)
if "token_type_ids" in mm_inputs:
token_type_ids = mm_inputs.pop("token_type_ids")

View File

@ -9,7 +9,7 @@ from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import is_pillow_available, is_pyav_available
from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
if is_pillow_available():
@ -21,6 +21,13 @@ if is_pyav_available():
import av
if is_transformers_version_greater_than("4.45.0"):
from transformers.models.mllama.processing_mllama import (
convert_sparse_cross_attention_mask_to_dense,
get_cross_attention_token_mask,
)
if TYPE_CHECKING:
from av.stream import Stream
from transformers import PreTrainedTokenizer, ProcessorMixin
@ -75,8 +82,8 @@ class BasePlugin:
Pre-processes a single image.
"""
image_resolution: int = kwargs.get("image_resolution")
if max(image.width, image.height) > image_resolution:
resize_factor = image_resolution / max(image.width, image.height)
if image.width * image.height > image_resolution:
resize_factor = math.sqrt(image_resolution / (image.width * image.height))
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
image = image.resize((width, height), resample=Image.NEAREST)
@ -165,15 +172,15 @@ class BasePlugin:
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 512),
image_resolution=getattr(processor, "image_resolution", 512 * 512),
)
input_dict["images"] = images
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 128),
video_fps=getattr(processor, "video_fps", 1.0),
image_resolution=getattr(processor, "video_resolution", 128 * 128),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 64),
)
input_dict["videos"] = videos
@ -223,7 +230,7 @@ class BasePlugin:
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
r"""
@ -234,7 +241,7 @@ class BasePlugin:
videos: a list of video inputs, shape (num_videos,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
seqlens: number of tokens in each sample, shape (batch_size,)
batch_ids: input ids of samples, shape (batch_size, seq_len)
processor: a processor for pre-processing images and videos
"""
self._validate_input(images, videos)
@ -258,12 +265,12 @@ class LlavaPlugin(BasePlugin):
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
message["content"] = content.replace("{{image}}", self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@ -274,7 +281,7 @@ class LlavaPlugin(BasePlugin):
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
@ -296,23 +303,27 @@ class LlavaNextPlugin(BasePlugin):
mm_inputs = self._get_mm_inputs(images, videos, processor)
if "image_sizes" in mm_inputs:
image_sizes = iter(mm_inputs["image_sizes"])
if "pixel_values" in mm_inputs:
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
content = message["content"]
while self.image_token in content:
while IMAGE_PLACEHOLDER in content:
image_size = next(image_sizes)
orig_height, orig_width = image_size
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if processor.vision_feature_select_strategy == "default":
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
num_image_tokens += 1
content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
message["content"] = content.replace("{{image}}", self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
@ -322,12 +333,11 @@ class LlavaNextPlugin(BasePlugin):
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
res = self._get_mm_inputs(images, videos, processor)
return res
return self._get_mm_inputs(images, videos, processor)
class LlavaNextVideoPlugin(BasePlugin):
@ -340,8 +350,7 @@ class LlavaNextVideoPlugin(BasePlugin):
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
num_video_tokens = 0
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
if "pixel_values" in mm_inputs:
@ -349,15 +358,15 @@ class LlavaNextVideoPlugin(BasePlugin):
height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
for message in messages:
content = message["content"]
while self.image_token in content:
while IMAGE_PLACEHOLDER in content:
image_size = next(image_sizes)
orig_height, orig_width = image_size
image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
if processor.vision_feature_select_strategy == "default":
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
num_image_tokens += 1
content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
message["content"] = content.replace("{{image}}", self.image_token)
@ -367,19 +376,19 @@ class LlavaNextVideoPlugin(BasePlugin):
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
video_seqlen = image_seqlen // 4 * num_frames # divide by 4 needed for avg pooling layer
for message in messages:
content = message["content"]
while self.video_token in content:
while VIDEO_PLACEHOLDER in content:
num_video_tokens += 1
content = content.replace(self.video_token, "{{video}}", 1)
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
message["content"] = content.replace("{{video}}", self.video_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {IMAGE_PLACEHOLDER} tokens")
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@ -390,7 +399,7 @@ class LlavaNextVideoPlugin(BasePlugin):
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
@ -418,7 +427,7 @@ class PaliGemmaPlugin(BasePlugin):
message["content"] = content.replace("{{image}}", "")
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@ -449,10 +458,11 @@ class PaliGemmaPlugin(BasePlugin):
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
seqlens = [len(input_ids) for input_ids in batch_ids]
mm_inputs = self._get_mm_inputs(images, videos, processor)
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
return mm_inputs
@ -481,7 +491,7 @@ class PixtralPlugin(BasePlugin):
content = message["content"]
while IMAGE_PLACEHOLDER in content:
if image_input_sizes is None:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
raise ValueError("Cannot get image input sizes.")
image_size = image_input_sizes[0][num_image_tokens]
height, width = image_size
@ -497,7 +507,7 @@ class PixtralPlugin(BasePlugin):
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@ -508,7 +518,7 @@ class PixtralPlugin(BasePlugin):
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
@ -592,10 +602,10 @@ class Qwen2vlPlugin(BasePlugin):
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens")
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@ -606,7 +616,7 @@ class Qwen2vlPlugin(BasePlugin):
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
@ -623,42 +633,45 @@ class VideoLlavaPlugin(BasePlugin):
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
num_video_tokens = 0
num_image_tokens, num_video_tokens = 0, 0
messages = deepcopy(messages)
mm_inputs = self._get_mm_inputs(images, videos, processor)
num_frames = 0
exist_images = "pixel_values_images" in mm_inputs
exist_videos = "pixel_values_videos" in mm_inputs
if exist_videos or exist_images:
if exist_images:
has_images = "pixel_values_images" in mm_inputs
has_videos = "pixel_values_videos" in mm_inputs
if has_images or has_videos:
if has_images:
height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
num_frames = 1
if exist_videos:
if has_videos:
pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
height, width = get_image_size(pixel_values_video[0])
num_frames = pixel_values_video.shape[0] # frame dim is always after batch dim
image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
video_seqlen = image_seqlen * num_frames
if processor.vision_feature_select_strategy == "default":
if getattr(processor, "vision_feature_select_strategy") == "default":
image_seqlen -= 1
for message in messages:
content = message["content"]
while self.image_token in content:
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(self.image_token, "{{image}}", 1)
while self.video_token in content:
num_video_tokens += 1
content = content.replace(self.video_token, "{{video}}", 1)
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
content = content.replace("{{image}}", self.image_token * image_seqlen)
message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
while VIDEO_PLACEHOLDER in content:
num_video_tokens += 1
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
content = content.replace("{{image}}", self.image_token)
message["content"] = content.replace("{{video}}", self.video_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {self.image_token} tokens")
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {self.video_token} tokens")
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
return messages
@ -669,7 +682,7 @@ class VideoLlavaPlugin(BasePlugin):
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
@ -686,30 +699,67 @@ class MllamaPlugin(BasePlugin):
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
content = content.replace(IMAGE_PLACEHOLDER, "<|image|>", 1)
message["content"] = content
num_image_tokens += content.count(IMAGE_PLACEHOLDER)
message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
return messages
@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
) -> Dict[str, "torch.Tensor"]:
r"""
Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
Returns:
pixel_values: tensor with shape
(batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
For example, (2, 1, 4, 3, 560, 560).
aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
return image_processor([[image] for image in images], return_tensors="pt")
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
seqlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._get_mm_inputs(images, videos, processor)
if images is not None:
images = [Image.open(image) if isinstance(image, str) else image for image in images]
image_features = processor.image_processor(images)
_ = image_features.pop("num_tiles")
image_features = {k: v if isinstance(v, torch.Tensor) else torch.tensor(v) for k, v in image_features.items()}
return image_features
self._validate_input(images, videos)
if len(images) != len(batch_ids):
raise ValueError("Mllama only supports one image per sample.")
mm_inputs = self._get_mm_inputs(images, videos, processor)
num_tiles = mm_inputs.pop("num_tiles")
image_token_id = getattr(processor, "image_token_id")
max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
cross_attention_token_mask = [
get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
]
mm_inputs["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
cross_attention_token_mask,
num_tiles=num_tiles,
max_num_tiles=max_image_tiles,
length=max(len(input_ids) for input_ids in batch_ids),
)
return mm_inputs
PLUGINS = {

View File

@ -785,7 +785,7 @@ _register_template(
stop_words=["<|eot_id|>"],
replace_eos=True,
replace_jinja_template=False,
mm_plugin=get_mm_plugin(name="mllama", image_token="<image>"),
mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
)

View File

@ -861,7 +861,7 @@ register_model_group(
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision-Instruct",
},
"LlamaVision3.2-90B-Instruct": {
"Llama-3.2-90B-Vision-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision-Instruct",
},

View File

@ -75,8 +75,8 @@ def is_starlette_available():
@lru_cache
def is_transformers_version_greater_than_4_43():
return _get_package_version("transformers") >= version.parse("4.43.0")
def is_transformers_version_greater_than(content: str):
return _get_package_version("transformers") >= version.parse(content)
@lru_cache

View File

@ -59,12 +59,12 @@ class ProcessorArguments:
"""
image_resolution: int = field(
default=512,
metadata={"help": "Keeps the height or width of image below this resolution."},
default=512 * 512,
metadata={"help": "Keeps the number of pixels of image below this resolution."},
)
video_resolution: int = field(
default=128,
metadata={"help": "Keeps the height or width of video below this resolution."},
default=128 * 128,
metadata={"help": "Keeps the number of pixels of video below this resolution."},
)
video_fps: float = field(
default=2.0,

View File

@ -35,7 +35,7 @@ from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
from ...extras.packages import is_transformers_version_greater_than_4_43
from ...extras.packages import is_transformers_version_greater_than
if TYPE_CHECKING:
@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
if attention_mask is not None:
attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
if is_transformers_version_greater_than_4_43():
if is_transformers_version_greater_than("4.43.0"):
from transformers.modeling_flash_attention_utils import _flash_attention_forward
attn_output: "torch.Tensor" = _flash_attention_forward(

View File

@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
from ...extras import logging
from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
from ...extras.packages import is_transformers_version_greater_than_4_43
from ...extras.packages import is_transformers_version_greater_than
if TYPE_CHECKING:
@ -115,7 +115,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
def _patch_for_block_diag_attn(model_type: str) -> None:
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
if is_transformers_version_greater_than_4_43():
if is_transformers_version_greater_than("4.43.0"):
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data

View File

@ -26,7 +26,7 @@ from ...extras import logging
if TYPE_CHECKING:
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
from ...hparams import FinetuningArguments, ModelArguments
@ -159,27 +159,25 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
image_seqlen = config.vision_config.num_image_tokens
else:
image_seqlen = -1
elif model_type == "mllama":
image_seqlen = (
(config.vision_config.image_size // config.vision_config.patch_size) ** 2 + 1
) * config.vision_config.max_num_tiles
return image_seqlen
def get_patch_size(config: "PretrainedConfig") -> int:
def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r"""
Computes the patch size of the vit.
"""
patch_size = getattr(config.vision_config, "patch_size", -1)
patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
return patch_size
def get_vision_feature_select_strategy(config: "PretrainedConfig") -> int:
def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
r"""
Get the vision_feature_select_strategy.
"""
vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
vision_feature_select_strategy = getattr(
config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
)
return vision_feature_select_strategy

View File

@ -66,11 +66,11 @@ def patch_processor(
setattr(processor, "tokenizer", tokenizer)
setattr(processor, "image_seqlen", get_image_seqlen(config))
setattr(processor, "image_resolution", model_args.image_resolution)
setattr(processor, "patch_size", get_patch_size(config))
setattr(processor, "patch_size", get_patch_size(config, processor))
setattr(processor, "video_resolution", model_args.video_resolution)
setattr(processor, "video_fps", model_args.video_fps)
setattr(processor, "video_maxlen", model_args.video_maxlen)
setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
def patch_config(