Mirror of https://github.com/hiyouga/LLaMA-Factory.git
fix mllama cross_mask
Former-commit-id: 598c22e43f3f10a335933339cc612744c4835eb0
parent e0325b1123
commit 006022cadd
@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
 
 import torch
+import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
 
 
@@ -98,6 +99,12 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 feature["token_type_ids"] = token_type_ids[i]
 
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
+        if "cross_attention_mask" in mm_inputs:  # for mllama inputs
+            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
+            seq_len = features["input_ids"].size(1)
+            orig_len = cross_attention_mask.size(1)
+            mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))
+
         features.update(mm_inputs)
         if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
             features = features.data  # use default_collate() instead of BatchEncoding.to()
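The core of the fix is the F.pad call: torch.nn.functional.pad takes padding sizes for the last dimension first, so the 6-tuple (0, 0, 0, 0, 0, seq_len - orig_len) leaves the last two dimensions untouched and right-pads dim 1, the sequence dimension of the 4D mask, with zeros. A minimal sketch of that behavior, using made-up toy shapes rather than anything from the repository:

import torch
import torch.nn.functional as F

# Toy shapes for illustration: (batch_size, orig_len, max_num_images, max_num_tiles)
cross_attention_mask = torch.ones(2, 10, 1, 4)

# Suppose the collator padded input_ids further, e.g. via pad_to_multiple_of.
seq_len = 16
orig_len = cross_attention_mask.size(1)

# Pads are given as (last-dim left, last-dim right, dim-2 left, dim-2 right, dim-1 left, dim-1 right),
# so only dim 1 (the sequence dimension) is right-padded with zeros.
padded = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))

print(padded.shape)                      # torch.Size([2, 16, 1, 4])
print(padded[:, orig_len:].abs().sum())  # tensor(0.) -- padded positions attend to no tiles

Without this padding, the mask keeps the length it was built with while the collated input_ids may grow longer, leaving the two tensors with mismatched sequence dimensions.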
@@ -241,7 +241,7 @@ class BasePlugin:
             videos: a list of video inputs, shape (num_videos,)
             imglens: number of images in each sample, shape (batch_size,)
             vidlens: number of videos in each sample, shape (batch_size,)
-            batch_ids: input ids of samples, shape (batch_size, seq_len)
+            batch_ids: token ids of input samples, shape (batch_size, seq_len)
             processor: a processor for pre-processing images and videos
         """
         self._validate_input(images, videos)
@@ -760,7 +760,7 @@ class MllamaPlugin(BasePlugin):
                 max_num_tiles=max_image_tiles,
                 length=max(len(input_ids) for input_ids in batch_ids),
             )
-        )
+        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
         return mm_inputs
 
 
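The shape noted in the new comment ties back to the collator change above: the plugin builds the dense mask with length = max(len(input_ids) for input_ids in batch_ids), i.e. the longest raw sample, while the collator may pad input_ids beyond that, which is exactly the gap the F.pad call closes. A rough illustration of that invariant (all names and sizes below are made up for the example, not taken from the plugin):

import torch
import torch.nn.functional as F

batch_ids = [[1] * 7, [1] * 10]              # made-up token id lists
length = max(len(ids) for ids in batch_ids)  # what the plugin uses for the dense mask

# Stand-in dense mask with the annotated shape (batch_size, length, max_num_images, max_num_tiles).
cross_attention_mask = torch.zeros(len(batch_ids), length, 1, 4)

# The collator may pad input_ids past `length` (e.g. to a multiple of 8).
seq_len = 16
input_ids = torch.zeros(len(batch_ids), seq_len, dtype=torch.long)

# The fix keeps the mask's sequence dimension in sync with the padded input_ids.
cross_attention_mask = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - cross_attention_mask.size(1)))
assert cross_attention_mask.size(1) == input_ids.size(1)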