diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index a3f9dfd1..5aea4c61 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
 
 import torch
+from torch.nn.utils.rnn import pad_sequence
 import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
 
@@ -36,6 +37,10 @@ if TYPE_CHECKING:
     from .template import Template
 
 
+def pad(seq, padding_value=0):
+    return pad_sequence(seq, batch_first=True, padding_value=padding_value)
+
+
 def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
     r"""
     Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
@@ -151,7 +156,11 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             features = features.data  # use default_collate() instead of BatchEncoding.to()
 
         if "image_bound" in features:  # for minicpmv inputs
-            features = self.template.mm_plugin.pad_data(features)
+            features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
+            features["input_ids"] = pad(features["input_ids"],)
+            features["position_ids"] = pad(features["position_ids"])
+            features["labels"] = pad(features["labels"], padding_value=-100)
+            features["attention_mask"] = pad(features["attention_mask"],)
             new_features = {}
             new_features.update({"data": features})
             new_features.update(features)
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 3a102d59..a40f78fc 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict
 
 import numpy as np
 import torch
-from torch.nn.utils.rnn import pad_sequence
 from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override
 
@@ -350,26 +349,6 @@ class CpmOPlugin(BasePlugin):
 
         return mm_inputs
 
-    def trim_and_pad(self, seq, padding_value=0):
-        return pad_sequence(seq, batch_first=True, padding_value=padding_value)
-
-    def pad_data(self, features):
-        features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
-        features["input_ids"] = self.trim_and_pad(
-            features["input_ids"],
-        )
-        features["position_ids"] = self.trim_and_pad(
-            features["position_ids"],
-        )
-        features["labels"] = self.trim_and_pad(
-            features["labels"],
-            padding_value=-100,
-        )
-        features["attention_mask"] = self.trim_and_pad(
-            features["attention_mask"],
-        )
-        return features
-
     @override
     def get_mm_inputs(
         self,