Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 06:12:50 +08:00)
fix mllama cross_mask
Former-commit-id: 598c22e43f3f10a335933339cc612744c4835eb0
parent e0325b1123
commit 006022cadd
@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence

 import torch
+import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq

@@ -98,6 +99,12 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 feature["token_type_ids"] = token_type_ids[i]

         features: Dict[str, "torch.Tensor"] = super().__call__(features)
+        if "cross_attention_mask" in mm_inputs:  # for mllama inputs
+            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
+            seq_len = features["input_ids"].size(1)
+            orig_len = cross_attention_mask.size(1)
+            mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
+
         features.update(mm_inputs)
         if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
             features = features.data  # use default_collate() instead of BatchEncoding.to()

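The block added above pads the mllama cross-attention mask so it stays aligned with input_ids after the base collator pads the batch to its final sequence length. A minimal sketch of the same F.pad call on a dummy 4-D mask (all shapes and sizes below are illustrative, not taken from the repository):

import torch
import torch.nn.functional as F

# Dummy mask with mllama's layout: (batch_size, orig_len, max_num_images, max_num_tiles).
cross_attention_mask = torch.ones(2, 10, 1, 4)

seq_len = 16  # length of input_ids after the collator padded the batch
orig_len = cross_attention_mask.size(1)

# F.pad reads the pad tuple from the last dimension backwards:
# (0, 0) tiles dim, (0, 0) images dim, (0, seq_len - orig_len) sequence dim.
padded = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
print(padded.shape)  # torch.Size([2, 16, 1, 4]); padded positions are zeros (not attended)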
@@ -241,7 +241,7 @@ class BasePlugin:
             videos: a list of video inputs, shape (num_videos,)
             imglens: number of images in each sample, shape (batch_size,)
             vidlens: number of videos in each sample, shape (batch_size,)
-            batch_ids: input ids of samples, shape (batch_size, seq_len)
+            batch_ids: token ids of input samples, shape (batch_size, seq_len)
             processor: a processor for pre-processing images and videos
         """
         self._validate_input(images, videos)
@@ -760,7 +760,7 @@ class MllamaPlugin(BasePlugin):
                     max_num_tiles=max_image_tiles,
                     length=max(len(input_ids) for input_ids in batch_ids),
                 )
-            )
+            )  # shape: (batch_size, length, max_num_images, max_num_tiles)
         return mm_inputs

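The shape comment added above documents the dense mask handed to the collator: one row per text token, with a 0/1 entry for each (image, tile) pair that token may cross-attend to. A rough, hypothetical densification of a toy sparse mask that only mirrors this output shape (the helper name and its arguments are made up for illustration; the repository relies on transformers' own conversion utility):

import numpy as np

def toy_dense_cross_mask(token_spans, num_tiles, max_num_tiles, length):
    # Hypothetical sketch: token_spans[i][j] = (start, end) token range of sample i
    # that attends to image j. Returns (batch_size, length, max_num_images, max_num_tiles).
    batch_size = len(token_spans)
    max_num_images = max(len(spans) for spans in token_spans)
    mask = np.zeros((batch_size, length, max_num_images, max_num_tiles), dtype=np.int64)
    for i, spans in enumerate(token_spans):
        for j, (start, end) in enumerate(spans):
            mask[i, start:end, j, : num_tiles[i][j]] = 1
    return mask

# One sample, one image split into 4 tiles, visible from token 3 to the end of a 12-token sequence.
dense = toy_dense_cross_mask([[(3, 12)]], num_tiles=[[4]], max_num_tiles=4, length=12)
print(dense.shape)  # (1, 12, 1, 4)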