mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
CO3Dv2 multi-category extension
Summary: Allows loading of multiple categories. Multiple categories are provided in a comma-separated list of category names. Reviewed By: bottler, shapovalov Differential Revision: D40803297 fbshipit-source-id: 863938be3aa6ffefe9e563aede4a2e9e66aeeaa8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c54e048666
commit
e4a3298149
@@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
|
||||
test_on_train: false
|
||||
only_test_set: false
|
||||
load_eval_batches: true
|
||||
num_load_workers: 4
|
||||
n_known_frames_for_test: 0
|
||||
dataset_class_type: JsonIndexDataset
|
||||
path_manager_factory_class_type: PathManagerFactory
|
||||
|
||||
@@ -11,17 +11,20 @@ from dataclasses import dataclass
|
||||
from itertools import product
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
|
||||
DoublePoolBatchSampler,
|
||||
)
|
||||
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
|
||||
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockFrameAnnotation:
|
||||
frame_number: int
|
||||
sequence_name: str = "sequence"
|
||||
frame_timestamp: float = 0.0
|
||||
|
||||
|
||||
@@ -41,6 +44,9 @@ class MockDataset(DatasetBase):
|
||||
self.frame_annots = [
|
||||
{"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq
|
||||
]
|
||||
for seq_name, idx in self._seq_to_idx.items():
|
||||
for i in idx:
|
||||
self.frame_annots[i]["frame_annotation"].sequence_name = seq_name
|
||||
|
||||
def get_frame_numbers_and_timestamps(self, idxs):
|
||||
out = []
|
||||
@@ -51,6 +57,16 @@ class MockDataset(DatasetBase):
|
||||
)
|
||||
return out
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
fa = self.frame_annots[index]["frame_annotation"]
|
||||
fd = FrameData(
|
||||
sequence_name=fa.sequence_name,
|
||||
sequence_category="default_category",
|
||||
frame_number=torch.LongTensor([fa.frame_number]),
|
||||
frame_timestamp=torch.LongTensor([fa.frame_timestamp]),
|
||||
)
|
||||
return fd
|
||||
|
||||
|
||||
class TestSceneBatchSampler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
||||
@@ -41,22 +41,73 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
|
||||
categories = ["A", "B"]
|
||||
subset_name = "test"
|
||||
eval_batch_size = 5
|
||||
n_frames = 8 * 3
|
||||
n_sequences = 5
|
||||
n_eval_batches = 10
|
||||
with tempfile.TemporaryDirectory() as tmpd:
|
||||
_make_random_json_dataset_map_provider_v2_data(
|
||||
tmpd,
|
||||
categories,
|
||||
eval_batch_size=eval_batch_size,
|
||||
n_frames=n_frames,
|
||||
n_sequences=n_sequences,
|
||||
n_eval_batches=n_eval_batches,
|
||||
)
|
||||
for n_known_frames_for_test in [0, 2]:
|
||||
for category in categories:
|
||||
dataset_provider = JsonIndexDatasetMapProviderV2(
|
||||
dataset_providers = {
|
||||
category: JsonIndexDatasetMapProviderV2(
|
||||
category=category,
|
||||
subset_name="test",
|
||||
dataset_root=tmpd,
|
||||
n_known_frames_for_test=n_known_frames_for_test,
|
||||
)
|
||||
for category in [*categories, ",".join(sorted(categories))]
|
||||
}
|
||||
for category, dataset_provider in dataset_providers.items():
|
||||
dataset_map = dataset_provider.get_dataset_map()
|
||||
for set_ in ["train", "val", "test"]:
|
||||
dataset = getattr(dataset_map, set_)
|
||||
|
||||
cat2seq = dataset.category_to_sequence_names()
|
||||
self.assertEqual(",".join(sorted(cat2seq.keys())), category)
|
||||
|
||||
if not (n_known_frames_for_test != 0 and set_ == "test"):
|
||||
# check the lengths only in case we do not have the
|
||||
# n_known_frames_for_test set
|
||||
expected_dataset_len = n_frames * n_sequences // 3
|
||||
if "," in category:
|
||||
# multicategory json index dataset, sum the lengths of
|
||||
# category-specific ones
|
||||
expected_dataset_len = sum(
|
||||
len(
|
||||
getattr(
|
||||
dataset_providers[c].get_dataset_map(), set_
|
||||
)
|
||||
)
|
||||
for c in categories
|
||||
)
|
||||
self.assertEqual(
|
||||
sum(len(s) for s in cat2seq.values()),
|
||||
n_sequences * len(categories),
|
||||
)
|
||||
self.assertEqual(len(cat2seq), len(categories))
|
||||
else:
|
||||
self.assertEqual(
|
||||
len(cat2seq[category]),
|
||||
n_sequences,
|
||||
)
|
||||
self.assertEqual(len(cat2seq), 1)
|
||||
self.assertEqual(len(dataset), expected_dataset_len)
|
||||
|
||||
if set_ == "test":
|
||||
# check the number of eval batches
|
||||
expected_n_eval_batches = n_eval_batches
|
||||
if "," in category:
|
||||
expected_n_eval_batches *= len(categories)
|
||||
self.assertTrue(
|
||||
len(dataset.get_eval_batches())
|
||||
== expected_n_eval_batches
|
||||
)
|
||||
if set_ in ["train", "val"]:
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
getattr(dataset_map, set_),
|
||||
@@ -80,6 +131,7 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
|
||||
dataset_provider.get_category_to_subset_name_list()
|
||||
)
|
||||
category_to_subset_list_ = {c: [subset_name] for c in categories}
|
||||
|
||||
self.assertTrue(category_to_subset_list == category_to_subset_list_)
|
||||
|
||||
|
||||
@@ -88,6 +140,7 @@ def _make_random_json_dataset_map_provider_v2_data(
|
||||
categories: List[str],
|
||||
n_frames: int = 8,
|
||||
n_sequences: int = 5,
|
||||
n_eval_batches: int = 10,
|
||||
H: int = 50,
|
||||
W: int = 30,
|
||||
subset_name: str = "test",
|
||||
@@ -100,7 +153,7 @@ def _make_random_json_dataset_map_provider_v2_data(
|
||||
sequence_annotations = []
|
||||
frame_index = []
|
||||
for seq_i in range(n_sequences):
|
||||
seq_name = str(seq_i)
|
||||
seq_name = category + str(seq_i)
|
||||
for i in range(n_frames):
|
||||
# generate and store image
|
||||
imdir = os.path.join(root, category, seq_name, "images")
|
||||
@@ -165,7 +218,8 @@ def _make_random_json_dataset_map_provider_v2_data(
|
||||
json.dump(set_list, f)
|
||||
|
||||
eval_batches = [
|
||||
random.sample(test_frame_index, eval_batch_size) for _ in range(10)
|
||||
random.sample(test_frame_index, eval_batch_size)
|
||||
for _ in range(n_eval_batches)
|
||||
]
|
||||
|
||||
eval_b_dir = os.path.join(root, category, "eval_batches")
|
||||
|
||||
Reference in New Issue
Block a user