BUAADreamer 175b56bced add multimodal LLM BLIP-2 and InstructBLIP
Former-commit-id: 4dcb11eab7bbeac866043d2a7c748b8d06fbd243
2024-04-23 18:45:43 +08:00

70 lines
2.6 KiB
Python

import json
import os
from dataclasses import dataclass
import torch
from torch.utils.data import Dataset as Dataset_torch
from datasets import Dataset
from PIL import Image
from transformers import AutoProcessor
class ImageCaptioningDataset(Dataset_torch):
    """Map-style dataset pairing an image with a prompt/target text pair.

    Each record of ``dataset`` is expected to provide:

    * ``"image"``: a file name, resolved relative to ``image_path``;
    * ``"conversations"``: a sequence whose first element's ``"value"`` is
      the prompt and whose second element's ``"value"`` is the target text.

    Text is returned *untokenized*; tokenization and padding are left to the
    batch collator.

    Args:
        dataset: Indexable collection of records with the layout above.
        image_path: Directory that the ``"image"`` file names are relative to.
        processor: Callable invoked as ``processor(image, return_tensors="pt")``
            that returns a mapping of tensors.
    """

    def __init__(self, dataset: "Dataset", image_path: str, processor: "AutoProcessor"):
        self.processor = processor
        self.dataset = dataset
        self.image_path = image_path

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load and preprocess the example at ``idx``.

        Returns:
            A dict with ``"input_ids"`` (raw prompt string), ``"labels"``
            (raw target string) and every entry produced by the processor
            for the image, each with its leading axis squeezed away.
        """
        source = self.dataset[idx]
        image_file = os.path.join(self.image_path, source["image"])

        # Image.open() is lazy and keeps the file handle open until the image
        # object is discarded. convert() forces the pixel data to be read and
        # returns an independent image, so the handle can be closed right away.
        # Normalising to RGB also keeps channel counts uniform across
        # grayscale / palette / RGBA files so batches can be stacked.
        with Image.open(image_file) as raw_image:
            image = raw_image.convert("RGB")

        convs = source["conversations"]
        prompt = convs[0]["value"]
        label = convs[1]["value"]

        image_inputs = self.processor(image, return_tensors="pt")

        inputs = {
            "input_ids": prompt,
            "labels": label,
        }
        for key, value in image_inputs.items():
            # Drop only the leading axis (assumed to be the size-1 batch axis
            # the processor adds for a single image - TODO confirm for other
            # processors). A bare squeeze() would also remove any other
            # size-1 axis.
            inputs[key] = value.squeeze(0)
        return inputs
@dataclass
class DataCollatorForVis2Seq:
    """Collate per-example dicts into a padded tensor batch.

    Every example is expected to be a dict that may contain:

    * ``"pixel_values"``: an image tensor; stacked along a new batch axis;
    * ``"input_ids"``: the *raw prompt string*; tokenized here into
      ``input_ids`` / ``attention_mask`` (and, when ``use_qformer`` is set,
      ``qformer_input_ids`` / ``qformer_attention_mask``);
    * ``"labels"``: the *raw target string*; tokenized here into ``labels``.

    Any other key is ignored. The keys are taken from the first example.

    Attributes:
        processor: Object exposing ``tokenizer`` and, when ``use_qformer`` is
            true, ``qformer_tokenizer``.
        use_qformer: Also tokenize the prompt with ``processor.qformer_tokenizer``.
        max_length: Fixed length every text sequence is padded/truncated to.
        label_pad_token_id: Value written to padded label positions.
    """

    processor: "AutoProcessor"
    use_qformer: bool = False
    max_length: int = 512
    label_pad_token_id: int = -100

    def _tokenize(self, tokenizer, texts):
        """Tokenize ``texts`` to fixed-length ``max_length`` PyTorch tensors."""
        # truncation=True is required with padding="max_length": without it a
        # single over-long text yields ragged rows and tensor conversion fails.
        return tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __call__(self, features, return_tensors=None):
        """Build one batch from a list of example dicts.

        Args:
            features: Non-empty list of example dicts (see class docstring).
            return_tensors: Accepted for call compatibility; ignored.
                PyTorch tensors are always produced.

        Returns:
            Dict of batched tensors.
        """
        processed_batch = {}
        for key in features[0]:
            if key == "pixel_values":
                processed_batch[key] = torch.stack([example[key] for example in features])
            elif key == "input_ids":
                prompts = [example[key] for example in features]
                text_inputs = self._tokenize(self.processor.tokenizer, prompts)
                processed_batch["input_ids"] = text_inputs["input_ids"]
                processed_batch["attention_mask"] = text_inputs["attention_mask"]
                if self.use_qformer:
                    qformer_text_inputs = self._tokenize(self.processor.qformer_tokenizer, prompts)
                    processed_batch["qformer_input_ids"] = qformer_text_inputs["input_ids"]
                    processed_batch["qformer_attention_mask"] = qformer_text_inputs["attention_mask"]
            elif key == "labels":
                targets = [example[key] for example in features]
                text_inputs = self._tokenize(self.processor.tokenizer, targets)
                # Mask padding so it does not contribute to the loss. The
                # attention mask is used rather than comparing against the pad
                # token id, which may coincide with a real token such as EOS.
                processed_batch["labels"] = text_inputs["input_ids"].masked_fill(
                    text_inputs["attention_mask"] == 0, self.label_pad_token_id
                )
        return processed_batch