fix inputs

This commit is contained in:
hiyouga
2024-11-23 18:25:45 +00:00
parent b1e43e56db
commit 446441fdb0
14 changed files with 148 additions and 95 deletions

View File

@@ -79,7 +79,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
processor: Optional["ProcessorMixin"] = None
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
for feature in features:
images = feature.pop("images", None) or []
videos = feature.pop("videos", None) or []
@@ -87,10 +87,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
batch_videos.extend(videos)
batch_imglens.append(len(images))
batch_vidlens.append(len(videos))
batch_seqlens.append(len(feature["input_ids"]))
batch_input_ids.append(feature["input_ids"])
mm_inputs = self.template.mm_plugin.get_mm_inputs(
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
)
if "token_type_ids" in mm_inputs:
token_type_ids = mm_inputs.pop("token_type_ids")