commit 771cc80294 (parent ae1f528df3)
Author: fzc8578
Date: 2025-01-10 23:29:06 +08:00
2 changed files with 41 additions and 17 deletions

@@ -19,8 +19,8 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
import torch
-from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
from transformers import DataCollatorForSeq2Seq
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
@@ -106,7 +106,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
batch_vidlens.append(len(videos))
batch_input_ids.append(feature["input_ids"])
-if self.processor is not None and sum(batch_imglens) == 0: # avoid process hanging in zero3/fsdp case
+if (
+    self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
+): # avoid process hanging in zero3/fsdp case
fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
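
Note on the widened guard (my reading, not part of the diff): the old condition only looked at image counts, so a batch that carried videos but no images still had a fake image injected next to the real video inputs. The new condition falls back to the dummy image only when the batch contains neither images nor videos, while keeping the original purpose of making every rank feed the vision tower under ZeRO-3/FSDP. A minimal standalone sketch of just that decision (the self.processor is not None check is dropped and needs_fake_image is a hypothetical name):

def needs_fake_image(batch_imglens, batch_vidlens):
    # Inject the dummy image only for purely textual batches, so the processor
    # still produces visual tensors and no rank skips the vision forward pass.
    return sum(batch_imglens) == 0 and sum(batch_vidlens) == 0

assert needs_fake_image([0, 0], [0, 0])      # text-only batch -> add the fake image
assert not needs_fake_image([0, 0], [1, 0])  # video-only batch -> keep the real inputs untouched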
@@ -157,10 +159,14 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
if "image_bound" in features: # for minicpmv inputs
features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
features["input_ids"] = pad(features["input_ids"],)
features["input_ids"] = pad(
features["input_ids"],
)
features["position_ids"] = pad(features["position_ids"])
features["labels"] = pad(features["labels"], padding_value=-100)
features["attention_mask"] = pad(features["attention_mask"],)
features["attention_mask"] = pad(
features["attention_mask"],
)
new_features = {}
new_features.update({"data": features})
new_features.update(features)
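
The pad helper called above is not defined in this hunk; given the pad_sequence import kept at the top of the file, it presumably right-pads a list of per-sample 1-D tensors into a single batch tensor, roughly as in the sketch below (an assumption about the helper, not the committed code):

import torch
from torch.nn.utils.rnn import pad_sequence

def pad(sequences, padding_value=0):
    # Assumed behaviour: right-pad variable-length 1-D tensors into one
    # (batch_size, max_len) tensor, filling with padding_value.
    return pad_sequence(sequences, batch_first=True, padding_value=padding_value)

input_ids = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
print(pad(input_ids).shape)  # torch.Size([2, 3])

As for the final update pair: nesting the padded features under a "data" key while also copying them to the top level appears to serve both the minicpmv-style model forward (which consumes "data") and the standard Trainer code paths (which expect input_ids, labels and attention_mask at the top level).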