fix mllama cross_mask

Former-commit-id: 598c22e43f3f10a335933339cc612744c4835eb0
This commit is contained in:
hiyouga 2024-11-26 15:54:44 +00:00
parent e0325b1123
commit 006022cadd
2 changed files with 9 additions and 2 deletions

View File

@ -19,6 +19,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
import torch
import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq
@ -98,6 +99,12 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
feature["token_type_ids"] = token_type_ids[i]
features: Dict[str, "torch.Tensor"] = super().__call__(features)
if "cross_attention_mask" in mm_inputs: # for mllama inputs
cross_attention_mask = mm_inputs.pop("cross_attention_mask")
seq_len = features["input_ids"].size(1)
orig_len = cross_attention_mask.size(1)
mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
features.update(mm_inputs)
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
features = features.data # use default_collate() instead of BatchEncoding.to()

View File

@ -241,7 +241,7 @@ class BasePlugin:
videos: a list of video inputs, shape (num_videos,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
batch_ids: input ids of samples, shape (batch_size, seq_len)
batch_ids: token ids of input samples, shape (batch_size, seq_len)
processor: a processor for pre-processing images and videos
"""
self._validate_input(images, videos)
@ -760,7 +760,7 @@ class MllamaPlugin(BasePlugin):
max_num_tiles=max_image_tiles,
length=max(len(input_ids) for input_ids in batch_ids),
)
)
) # shape: (batch_size, length, max_num_images, max_num_tiles)
return mm_inputs