Merge pull request #6151 from hiyouga/hiyouga/fix_mllama

[model] fix mllama cross mask

Former-commit-id: 88f087c8b9cb22fa4f4e4f867ea3d71dd8606a98
hoshi-hiyouga 2024-11-27 00:07:54 +08:00 committed by GitHub
commit 08ca40876a
2 changed files with 9 additions and 2 deletions

src/llamafactory/data/collator.py

@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
 import torch
+import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
@@ -98,6 +99,12 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 feature["token_type_ids"] = token_type_ids[i]

         features: Dict[str, "torch.Tensor"] = super().__call__(features)
+        if "cross_attention_mask" in mm_inputs:  # for mllama inputs
+            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
+            seq_len = features["input_ids"].size(1)
+            orig_len = cross_attention_mask.size(1)
+            mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
+
         features.update(mm_inputs)
         if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
             features = features.data  # use default_collate() instead of BatchEncoding.to()
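
For context, the new branch pads the mllama cross-attention mask along its sequence dimension so it matches `input_ids` after collation. Below is a minimal sketch of the padding semantics, assuming the mask has shape (batch_size, seq_len, max_num_images, max_num_tiles) as noted in the second file's comment; the concrete sizes are illustrative, not from the commit:

```python
# Illustrative only: the shapes are made up; the pad-tuple layout is the point.
import torch
import torch.nn.functional as F

# Assumed mllama mask layout: (batch_size, seq_len, max_num_images, max_num_tiles)
cross_attention_mask = torch.ones(2, 10, 1, 4)

seq_len = 16  # length of input_ids after the collator pads the batch
orig_len = cross_attention_mask.size(1)  # 10

# F.pad reads the tuple from the last dim backwards in (left, right) pairs:
# (0, 0) for dim 3, (0, 0) for dim 2, (0, seq_len - orig_len) for dim 1.
padded = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))

assert padded.shape == (2, 16, 1, 4)
assert padded[:, orig_len:].sum() == 0  # padded positions attend to no tiles
```

Since `F.pad` fills with zeros by default, the newly padded positions attend to no image tiles, which is the desired behavior for pure padding tokens.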

src/llamafactory/data/mm_plugin.py

@@ -241,7 +241,7 @@ class BasePlugin:
             videos: a list of video inputs, shape (num_videos,)
             imglens: number of images in each sample, shape (batch_size,)
             vidlens: number of videos in each sample, shape (batch_size,)
-            batch_ids: input ids of samples, shape (batch_size, seq_len)
+            batch_ids: token ids of input samples, shape (batch_size, seq_len)
             processor: a processor for pre-processing images and videos
         """
         self._validate_input(images, videos)
@@ -760,7 +760,7 @@ class MllamaPlugin(BasePlugin):
                 max_num_tiles=max_image_tiles,
                 length=max(len(input_ids) for input_ids in batch_ids),
             )
-        )
+        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
         return mm_inputs
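
Putting the two hunks together: `MllamaPlugin` builds the dense cross-attention mask with `length = max(len(input_ids) for input_ids in batch_ids)`, i.e., the longest raw sample, but `DataCollatorForSeq2Seq` can pad `input_ids` further (for example via `pad_to_multiple_of`), leaving the mask's sequence dimension short. Here is a self-contained sketch of that mismatch and the fix; the sample lengths and `pad_to_multiple_of=8` are assumptions for illustration:

```python
import torch
import torch.nn.functional as F

# Assumed raw sample lengths; the plugin sizes the mask to the longest one.
lengths = [10, 13]
mask_len = max(lengths)  # 13, the plugin's `length` argument
cross_attention_mask = torch.ones(len(lengths), mask_len, 1, 4)

# The collator may pad input_ids beyond that, e.g. pad_to_multiple_of=8 -> 16.
pad_to_multiple_of = 8
seq_len = ((mask_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of

# Without the fix the mask stays 13 wide while input_ids are 16 wide, so the
# model's cross-attention indexing fails; the new F.pad call closes the gap.
cross_attention_mask = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - mask_len))
assert cross_attention_mask.size(1) == seq_len
```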