mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
JsonIndexDatasetProviderV2
Summary: A new version of json index dataset provider supporting CO3Dv2 Reviewed By: shapovalov Differential Revision: D37690918 fbshipit-source-id: bf2d5fc9d0f1220259e08661dafc69cdbe6b7f94
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4300030d7a
commit
e8390d3500
@@ -42,6 +42,47 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||
sort_frames: false
|
||||
path_manager_factory_PathManagerFactory_args:
|
||||
silence_logs: true
|
||||
dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
|
||||
category: ???
|
||||
subset_name: ???
|
||||
dataset_root: ''
|
||||
test_on_train: false
|
||||
only_test_set: false
|
||||
load_eval_batches: true
|
||||
dataset_class_type: JsonIndexDataset
|
||||
path_manager_factory_class_type: PathManagerFactory
|
||||
dataset_JsonIndexDataset_args:
|
||||
path_manager: null
|
||||
frame_annotations_file: ''
|
||||
sequence_annotations_file: ''
|
||||
subset_lists_file: ''
|
||||
subsets: null
|
||||
limit_to: 0
|
||||
limit_sequences_to: 0
|
||||
pick_sequence: []
|
||||
exclude_sequence: []
|
||||
limit_category_to: []
|
||||
dataset_root: ''
|
||||
load_images: true
|
||||
load_depths: true
|
||||
load_depth_masks: true
|
||||
load_masks: true
|
||||
load_point_clouds: false
|
||||
max_points: 0
|
||||
mask_images: false
|
||||
mask_depths: false
|
||||
image_height: 800
|
||||
image_width: 800
|
||||
box_crop: true
|
||||
box_crop_mask_thr: 0.4
|
||||
box_crop_context: 0.3
|
||||
remove_empty_masks: true
|
||||
n_frames_per_sequence: -1
|
||||
seed: 0
|
||||
sort_frames: false
|
||||
eval_batches: null
|
||||
path_manager_factory_PathManagerFactory_args:
|
||||
silence_logs: true
|
||||
dataset_map_provider_LlffDatasetMapProvider_args:
|
||||
base_dir: ???
|
||||
object_name: ???
|
||||
|
||||
155
tests/implicitron/test_json_index_dataset_provider_v2.py
Normal file
155
tests/implicitron/test_json_index_dataset_provider_v2.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from PIL import Image
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
|
||||
JsonIndexDatasetMapProviderV2,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.types import (
|
||||
dump_dataclass_jgzip,
|
||||
FrameAnnotation,
|
||||
ImageAnnotation,
|
||||
MaskAnnotation,
|
||||
SequenceAnnotation,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||
|
||||
|
||||
class TestJsonIndexDatasetProviderV2(unittest.TestCase):
|
||||
def test_random_dataset(self):
|
||||
# store random frame annotations
|
||||
expand_args_fields(JsonIndexDatasetMapProviderV2)
|
||||
categories = ["A", "B"]
|
||||
subset_name = "test"
|
||||
with tempfile.TemporaryDirectory() as tmpd:
|
||||
_make_random_json_dataset_map_provider_v2_data(tmpd, categories)
|
||||
for category in categories:
|
||||
dataset_provider = JsonIndexDatasetMapProviderV2(
|
||||
category=category,
|
||||
subset_name="test",
|
||||
dataset_root=tmpd,
|
||||
)
|
||||
dataset_map = dataset_provider.get_dataset_map()
|
||||
for set_ in ["train", "val", "test"]:
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
getattr(dataset_map, set_),
|
||||
batch_size=3,
|
||||
shuffle=True,
|
||||
collate_fn=FrameData.collate,
|
||||
)
|
||||
for _ in dataloader:
|
||||
pass
|
||||
category_to_subset_list = (
|
||||
dataset_provider.get_category_to_subset_name_list()
|
||||
)
|
||||
category_to_subset_list_ = {c: [subset_name] for c in categories}
|
||||
self.assertTrue(category_to_subset_list == category_to_subset_list_)
|
||||
|
||||
|
||||
def _make_random_json_dataset_map_provider_v2_data(
|
||||
root: str,
|
||||
categories: List[str],
|
||||
n_frames: int = 8,
|
||||
n_sequences: int = 5,
|
||||
H: int = 50,
|
||||
W: int = 30,
|
||||
subset_name: str = "test",
|
||||
):
|
||||
os.makedirs(root, exist_ok=True)
|
||||
category_to_subset_list = {}
|
||||
for category in categories:
|
||||
frame_annotations = []
|
||||
sequence_annotations = []
|
||||
frame_index = []
|
||||
for seq_i in range(n_sequences):
|
||||
seq_name = str(seq_i)
|
||||
for i in range(n_frames):
|
||||
# generate and store image
|
||||
imdir = os.path.join(root, category, seq_name, "images")
|
||||
os.makedirs(imdir, exist_ok=True)
|
||||
img_path = os.path.join(imdir, f"frame{i:05d}.jpg")
|
||||
img = torch.rand(3, H, W)
|
||||
torchvision.utils.save_image(img, img_path)
|
||||
|
||||
# generate and store mask
|
||||
maskdir = os.path.join(root, category, seq_name, "masks")
|
||||
os.makedirs(maskdir, exist_ok=True)
|
||||
mask_path = os.path.join(maskdir, f"frame{i:05d}.png")
|
||||
mask = np.zeros((H, W))
|
||||
mask[H // 2 :, W // 2 :] = 1
|
||||
Image.fromarray((mask * 255.0).astype(np.uint8), mode="L",).convert(
|
||||
"L"
|
||||
).save(mask_path)
|
||||
|
||||
fa = FrameAnnotation(
|
||||
sequence_name=seq_name,
|
||||
frame_number=i,
|
||||
frame_timestamp=float(i),
|
||||
image=ImageAnnotation(
|
||||
path=img_path.replace(os.path.normpath(root) + "/", ""),
|
||||
size=list(img.shape[-2:]),
|
||||
),
|
||||
mask=MaskAnnotation(
|
||||
path=mask_path.replace(os.path.normpath(root) + "/", ""),
|
||||
mass=mask.sum().item(),
|
||||
),
|
||||
)
|
||||
frame_annotations.append(fa)
|
||||
frame_index.append((seq_name, i, fa.image.path))
|
||||
|
||||
sequence_annotations.append(
|
||||
SequenceAnnotation(
|
||||
sequence_name=seq_name,
|
||||
category=category,
|
||||
)
|
||||
)
|
||||
|
||||
dump_dataclass_jgzip(
|
||||
os.path.join(root, category, "frame_annotations.jgz"),
|
||||
frame_annotations,
|
||||
)
|
||||
dump_dataclass_jgzip(
|
||||
os.path.join(root, category, "sequence_annotations.jgz"),
|
||||
sequence_annotations,
|
||||
)
|
||||
|
||||
test_frame_index = frame_index[2::3]
|
||||
|
||||
set_list = {
|
||||
"train": frame_index[0::3],
|
||||
"val": frame_index[1::3],
|
||||
"test": test_frame_index,
|
||||
}
|
||||
set_lists_dir = os.path.join(root, category, "set_lists")
|
||||
os.makedirs(set_lists_dir, exist_ok=True)
|
||||
set_list_file = os.path.join(set_lists_dir, f"set_lists_{subset_name}.json")
|
||||
with open(set_list_file, "w") as f:
|
||||
json.dump(set_list, f)
|
||||
|
||||
eval_batches = [random.sample(test_frame_index, 5) for _ in range(10)]
|
||||
eval_b_dir = os.path.join(root, category, "eval_batches")
|
||||
os.makedirs(eval_b_dir, exist_ok=True)
|
||||
eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")
|
||||
with open(eval_b_file, "w") as f:
|
||||
json.dump(eval_batches, f)
|
||||
|
||||
category_to_subset_list[category] = [subset_name]
|
||||
|
||||
with open(os.path.join(root, "category_to_subset_name_list.json"), "w") as f:
|
||||
json.dump(category_to_subset_list, f)
|
||||
Reference in New Issue
Block a user