CO3Dv2 multi-category extension

Summary:
Allows loading multiple categories at once.
The categories are passed as a single comma-separated string of category names (e.g. `category="A,B"`).

Reviewed By: bottler, shapovalov

Differential Revision: D40803297

fbshipit-source-id: 863938be3aa6ffefe9e563aede4a2e9e66aeeaa8
This commit is contained in:
David Novotny
2022-11-02 13:55:25 -07:00
committed by Facebook GitHub Bot
parent c54e048666
commit e4a3298149
9 changed files with 272 additions and 25 deletions

View File

@@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
test_on_train: false
only_test_set: false
load_eval_batches: true
num_load_workers: 4
n_known_frames_for_test: 0
dataset_class_type: JsonIndexDataset
path_manager_factory_class_type: PathManagerFactory

View File

@@ -11,17 +11,20 @@ from dataclasses import dataclass
from itertools import product
import numpy as np
import torch
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
DoublePoolBatchSampler,
)
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
# Minimal stand-in for a frame annotation; MockDataset stores one instance per
# frame under the "frame_annotation" key of each entry in `frame_annots`.
@dataclass
class MockFrameAnnotation:
# Frame number; the only required field (no default).
frame_number: int
# Name of the owning sequence. Defaults to "sequence"; MockDataset overwrites
# it with the sequence names found in its `_seq_to_idx` mapping.
sequence_name: str = "sequence"
# Frame timestamp; defaults to 0.0.
frame_timestamp: float = 0.0
@@ -41,6 +44,9 @@ class MockDataset(DatasetBase):
self.frame_annots = [
{"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq
]
for seq_name, idx in self._seq_to_idx.items():
for i in idx:
self.frame_annots[i]["frame_annotation"].sequence_name = seq_name
def get_frame_numbers_and_timestamps(self, idxs):
out = []
@@ -51,6 +57,16 @@ class MockDataset(DatasetBase):
)
return out
# Return a FrameData built from the mock annotation stored at `index`.
# Only sequence_name, sequence_category, frame_number and frame_timestamp
# are passed to the FrameData constructor.
def __getitem__(self, index: int):
fa = self.frame_annots[index]["frame_annotation"]
fd = FrameData(
sequence_name=fa.sequence_name,
# Every mock frame reports the same fixed category string.
sequence_category="default_category",
# Frame number and timestamp are each wrapped in a one-element LongTensor.
frame_number=torch.LongTensor([fa.frame_number]),
# NOTE(review): MockFrameAnnotation declares frame_timestamp as a float, yet it
# is packed into an integer (Long) tensor here - confirm this is intended.
frame_timestamp=torch.LongTensor([fa.frame_timestamp]),
)
return fd
class TestSceneBatchSampler(unittest.TestCase):
def setUp(self):

View File

@@ -41,22 +41,73 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
categories = ["A", "B"]
subset_name = "test"
eval_batch_size = 5
n_frames = 8 * 3
n_sequences = 5
n_eval_batches = 10
with tempfile.TemporaryDirectory() as tmpd:
_make_random_json_dataset_map_provider_v2_data(
tmpd,
categories,
eval_batch_size=eval_batch_size,
n_frames=n_frames,
n_sequences=n_sequences,
n_eval_batches=n_eval_batches,
)
for n_known_frames_for_test in [0, 2]:
for category in categories:
dataset_provider = JsonIndexDatasetMapProviderV2(
dataset_providers = {
category: JsonIndexDatasetMapProviderV2(
category=category,
subset_name="test",
dataset_root=tmpd,
n_known_frames_for_test=n_known_frames_for_test,
)
for category in [*categories, ",".join(sorted(categories))]
}
for category, dataset_provider in dataset_providers.items():
dataset_map = dataset_provider.get_dataset_map()
for set_ in ["train", "val", "test"]:
dataset = getattr(dataset_map, set_)
cat2seq = dataset.category_to_sequence_names()
self.assertEqual(",".join(sorted(cat2seq.keys())), category)
if not (n_known_frames_for_test != 0 and set_ == "test"):
# check the lengths only in case we do not have the
# n_known_frames_for_test set
expected_dataset_len = n_frames * n_sequences // 3
if "," in category:
# multicategory json index dataset, sum the lengths of
# category-specific ones
expected_dataset_len = sum(
len(
getattr(
dataset_providers[c].get_dataset_map(), set_
)
)
for c in categories
)
self.assertEqual(
sum(len(s) for s in cat2seq.values()),
n_sequences * len(categories),
)
self.assertEqual(len(cat2seq), len(categories))
else:
self.assertEqual(
len(cat2seq[category]),
n_sequences,
)
self.assertEqual(len(cat2seq), 1)
self.assertEqual(len(dataset), expected_dataset_len)
if set_ == "test":
# check the number of eval batches
expected_n_eval_batches = n_eval_batches
if "," in category:
expected_n_eval_batches *= len(categories)
self.assertTrue(
len(dataset.get_eval_batches())
== expected_n_eval_batches
)
if set_ in ["train", "val"]:
dataloader = torch.utils.data.DataLoader(
getattr(dataset_map, set_),
@@ -80,6 +131,7 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
dataset_provider.get_category_to_subset_name_list()
)
category_to_subset_list_ = {c: [subset_name] for c in categories}
self.assertTrue(category_to_subset_list == category_to_subset_list_)
@@ -88,6 +140,7 @@ def _make_random_json_dataset_map_provider_v2_data(
categories: List[str],
n_frames: int = 8,
n_sequences: int = 5,
n_eval_batches: int = 10,
H: int = 50,
W: int = 30,
subset_name: str = "test",
@@ -100,7 +153,7 @@ def _make_random_json_dataset_map_provider_v2_data(
sequence_annotations = []
frame_index = []
for seq_i in range(n_sequences):
seq_name = str(seq_i)
seq_name = category + str(seq_i)
for i in range(n_frames):
# generate and store image
imdir = os.path.join(root, category, seq_name, "images")
@@ -165,7 +218,8 @@ def _make_random_json_dataset_map_provider_v2_data(
json.dump(set_list, f)
eval_batches = [
random.sample(test_frame_index, eval_batch_size) for _ in range(10)
random.sample(test_frame_index, eval_batch_size)
for _ in range(n_eval_batches)
]
eval_b_dir = os.path.join(root, category, "eval_batches")