Former-commit-id: ae1f528df31194fe37a123ba1e5a4cd263a61602
fzc8578 2025-01-10 21:25:32 +08:00
parent 994049380d
commit bcbe37ff52
2 changed files with 10 additions and 22 deletions

View File

@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
 import torch
+from torch.nn.utils.rnn import pad_sequence
 import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
@@ -36,6 +37,10 @@ if TYPE_CHECKING:
     from .template import Template

+def pad(seq, padding_value=0):
+    return pad_sequence(seq, batch_first=True, padding_value=padding_value)
+
+
 def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
     r"""
     Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
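The hunk above only includes the docstring of prepare_4d_attention_mask as unchanged context. As a rough, independent sketch of the expansion that docstring describes (not the function's actual body), assuming each value in attention_mask_with_indices is a packed-sequence index and 0 marks padding:

import torch

def expand_packed_mask(attention_mask_with_indices: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # (batch_size, seq_len) -> (batch_size, 1, seq_len, seq_len)
    bsz, seq_len = attention_mask_with_indices.size()
    min_value = torch.finfo(dtype).min
    idx = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    # a position may attend to another only if both carry the same nonzero
    # packed-sequence index (never to padding) and the target is not in the future
    same_seq = idx.eq(idx.transpose(-1, -2)) & idx.ne(0)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=idx.device))
    mask = torch.full((bsz, 1, seq_len, seq_len), min_value, dtype=dtype, device=idx.device)
    return mask.masked_fill(same_seq & causal, 0)

# e.g. two sequences of lengths 2 and 3 packed into one row, plus one padding slot:
# expand_packed_mask(torch.tensor([[1, 1, 2, 2, 2, 0]]), torch.float32)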
@@ -151,7 +156,11 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features = features.data  # use default_collate() instead of BatchEncoding.to()

         if "image_bound" in features:  # for minicpmv inputs
-            features = self.template.mm_plugin.pad_data(features)
+            features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
+            features["input_ids"] = pad(features["input_ids"],)
+            features["position_ids"] = pad(features["position_ids"])
+            features["labels"] = pad(features["labels"], padding_value=-100)
+            features["attention_mask"] = pad(features["attention_mask"],)
             new_features = {}
             new_features.update({"data": features})
             new_features.update(features)
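For illustration, a standalone sketch (not taken from the repository) of how the collator-level pad() added above behaves on variable-length minicpmv-style features: input_ids, attention_mask, and position_ids are padded with 0, while labels are padded with -100 so the loss ignores the padded positions. The feature values below are made-up toy tensors.

import torch
from torch.nn.utils.rnn import pad_sequence

def pad(seq, padding_value=0):
    return pad_sequence(seq, batch_first=True, padding_value=padding_value)

features = {
    "input_ids": [torch.tensor([101, 7592, 102]), torch.tensor([101, 102])],
    "labels": [torch.tensor([-100, 7592, 102]), torch.tensor([-100, 102])],
    "attention_mask": [torch.tensor([1, 1, 1]), torch.tensor([1, 1])],
}
# position_ids are rebuilt per sample before padding, mirroring the added collator code
features["position_ids"] = [torch.arange(ids.size(0)).long() for ids in features["input_ids"]]

features["input_ids"] = pad(features["input_ids"])
features["position_ids"] = pad(features["position_ids"])
features["labels"] = pad(features["labels"], padding_value=-100)
features["attention_mask"] = pad(features["attention_mask"])

print(features["input_ids"])       # tensor([[ 101, 7592,  102], [ 101,  102,    0]])
print(features["labels"][1])       # tensor([-100,  102, -100])
print(features["attention_mask"])  # tensor([[1, 1, 1], [1, 1, 0]])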

View File

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict
 import numpy as np
 import torch
-from torch.nn.utils.rnn import pad_sequence
 from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override
@@ -350,26 +349,6 @@ class CpmOPlugin(BasePlugin):

         return mm_inputs

-    def trim_and_pad(self, seq, padding_value=0):
-        return pad_sequence(seq, batch_first=True, padding_value=padding_value)
-
-    def pad_data(self, features):
-        features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
-        features["input_ids"] = self.trim_and_pad(
-            features["input_ids"],
-        )
-        features["position_ids"] = self.trim_and_pad(
-            features["position_ids"],
-        )
-        features["labels"] = self.trim_and_pad(
-            features["labels"],
-            padding_value=-100,
-        )
-        features["attention_mask"] = self.trim_and_pad(
-            features["attention_mask"],
-        )
-        return features
-
     @override
     def get_mm_inputs(
         self,