mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 19:30:36 +08:00
add some
This commit is contained in:
@@ -19,8 +19,8 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||
@@ -106,7 +106,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
batch_vidlens.append(len(videos))
|
||||
batch_input_ids.append(feature["input_ids"])
|
||||
|
||||
if self.processor is not None and sum(batch_imglens) == 0: # avoid process hanging in zero3/fsdp case
|
||||
if (
|
||||
self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
|
||||
): # avoid process hanging in zero3/fsdp case
|
||||
fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
|
||||
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
|
||||
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
|
||||
@@ -157,10 +159,14 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
|
||||
if "image_bound" in features: # for minicpmv inputs
|
||||
features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
|
||||
features["input_ids"] = pad(features["input_ids"],)
|
||||
features["input_ids"] = pad(
|
||||
features["input_ids"],
|
||||
)
|
||||
features["position_ids"] = pad(features["position_ids"])
|
||||
features["labels"] = pad(features["labels"], padding_value=-100)
|
||||
features["attention_mask"] = pad(features["attention_mask"],)
|
||||
features["attention_mask"] = pad(
|
||||
features["attention_mask"],
|
||||
)
|
||||
new_features = {}
|
||||
new_features.update({"data": features})
|
||||
new_features.update(features)
|
||||
|
||||
Reference in New Issue
Block a user