mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-24 23:02:49 +08:00
import json
import os
from dataclasses import dataclass

import torch
from datasets import Dataset
from PIL import Image
from torch.utils.data import Dataset as Dataset_torch
from transformers import AutoProcessor


class ImageCaptioningDataset(Dataset_torch):
    """Wraps a LLaVA-style `datasets.Dataset` for vision-to-sequence training.

    Each record must provide an "image" filename and a "conversations" list
    whose first turn holds the prompt and second turn holds the target text.
    """

    def __init__(self, dataset: Dataset, image_path: str, processor: AutoProcessor):
        self.processor = processor
        self.dataset = dataset
        self.image_path = image_path

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        source = self.dataset[idx]
        image_id = source["image"]
        image = Image.open(os.path.join(self.image_path, image_id)).convert("RGB")
        convs = source["conversations"]
        prompt = convs[0]["value"]
        label = convs[1]["value"]
        # Only the image is processed here; the raw prompt/label strings are
        # passed through and tokenized batch-wise in the collator below.
        image_inputs = self.processor(images=image, return_tensors="pt")
        # Drop the leading batch dimension added by the processor.
        image_inputs = {k: v.squeeze(0) for k, v in image_inputs.items()}
        inputs = {
            "input_ids": prompt,
            "labels": label,
        }
        inputs.update(image_inputs)
        return inputs
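

# Sketch of a loader (an assumption, not part of the original file): builds the
# HF `datasets.Dataset` that ImageCaptioningDataset expects from a LLaVA-style
# JSON file, where each record carries an "image" filename and a
# "conversations" list of {"from": ..., "value": ...} turns.
def load_caption_dataset(json_path: str) -> Dataset:
    with open(json_path, "r", encoding="utf-8") as f:
        records = json.load(f)
    return Dataset.from_list(records)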


@dataclass
class DataCollatorForVis2Seq:
    """Collates dataset items: stacks image tensors and tokenizes the text fields.

    When `use_qformer` is true, prompts are additionally tokenized with the
    processor's Q-Former tokenizer (as in InstructBLIP-style processors).
    """

    processor: AutoProcessor
    use_qformer: bool = False

    def __call__(self, features, return_tensors=None):
        processed_batch = {}
        for key in features[0].keys():
            if key == "pixel_values":
                # Image tensors share a fixed shape, so a plain stack suffices.
                processed_batch[key] = torch.stack([example[key] for example in features])
            elif key == "input_ids":
                # The dataset stores raw prompt strings under "input_ids";
                # tokenize them here so padding happens per batch.
                text_inputs = self.processor.tokenizer(
                    [example[key] for example in features],
                    padding="max_length",
                    truncation=True,
                    max_length=512,
                    return_tensors="pt",
                )
                processed_batch["input_ids"] = text_inputs["input_ids"]
                processed_batch["attention_mask"] = text_inputs["attention_mask"]
                if self.use_qformer:
                    qformer_text_inputs = self.processor.qformer_tokenizer(
                        [example[key] for example in features],
                        padding="max_length",
                        truncation=True,
                        max_length=512,
                        return_tensors="pt",
                    )
                    processed_batch["qformer_input_ids"] = qformer_text_inputs["input_ids"]
                    processed_batch["qformer_attention_mask"] = qformer_text_inputs["attention_mask"]
            elif key == "labels":
                text_inputs = self.processor.tokenizer(
                    [example[key] for example in features],
                    padding="max_length",
                    truncation=True,
                    max_length=512,
                    return_tensors="pt",
                )
                labels = text_inputs["input_ids"]
                # Mask padding positions so they are ignored by the loss.
                labels[labels == self.processor.tokenizer.pad_token_id] = -100
                processed_batch["labels"] = labels
        return processed_batch
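

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file): one way to
# wire the dataset and collator into a `transformers` Trainer. The checkpoint
# name, JSON file, and image directory are placeholder assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForVision2Seq, Trainer, TrainingArguments

    model_name = "Salesforce/instructblip-flan-t5-xl"  # assumed checkpoint
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForVision2Seq.from_pretrained(model_name)

    train_dataset = ImageCaptioningDataset(
        dataset=load_caption_dataset("data.json"),  # hypothetical paths
        image_path="images",
        processor=processor,
    )
    trainer = Trainer(
        model=model,
        args=TrainingArguments(output_dir="output", per_device_train_batch_size=2),
        train_dataset=train_dataset,
        data_collator=DataCollatorForVis2Seq(processor=processor, use_qformer=True),
    )
    trainer.train()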