mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 23:30:35 +08:00
CO3Dv2 multi-category extension
Summary: Allows loading of multiple categories. Multiple categories are provided in a comma-separated list of category names. Reviewed By: bottler, shapovalov Differential Revision: D40803297 fbshipit-source-id: 863938be3aa6ffefe9e563aede4a2e9e66aeeaa8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c54e048666
commit
e4a3298149
@@ -41,22 +41,73 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
|
||||
categories = ["A", "B"]
|
||||
subset_name = "test"
|
||||
eval_batch_size = 5
|
||||
n_frames = 8 * 3
|
||||
n_sequences = 5
|
||||
n_eval_batches = 10
|
||||
with tempfile.TemporaryDirectory() as tmpd:
|
||||
_make_random_json_dataset_map_provider_v2_data(
|
||||
tmpd,
|
||||
categories,
|
||||
eval_batch_size=eval_batch_size,
|
||||
n_frames=n_frames,
|
||||
n_sequences=n_sequences,
|
||||
n_eval_batches=n_eval_batches,
|
||||
)
|
||||
for n_known_frames_for_test in [0, 2]:
|
||||
for category in categories:
|
||||
dataset_provider = JsonIndexDatasetMapProviderV2(
|
||||
dataset_providers = {
|
||||
category: JsonIndexDatasetMapProviderV2(
|
||||
category=category,
|
||||
subset_name="test",
|
||||
dataset_root=tmpd,
|
||||
n_known_frames_for_test=n_known_frames_for_test,
|
||||
)
|
||||
for category in [*categories, ",".join(sorted(categories))]
|
||||
}
|
||||
for category, dataset_provider in dataset_providers.items():
|
||||
dataset_map = dataset_provider.get_dataset_map()
|
||||
for set_ in ["train", "val", "test"]:
|
||||
dataset = getattr(dataset_map, set_)
|
||||
|
||||
cat2seq = dataset.category_to_sequence_names()
|
||||
self.assertEqual(",".join(sorted(cat2seq.keys())), category)
|
||||
|
||||
if not (n_known_frames_for_test != 0 and set_ == "test"):
|
||||
# check the lengths only in case we do not have the
|
||||
# n_known_frames_for_test set
|
||||
expected_dataset_len = n_frames * n_sequences // 3
|
||||
if "," in category:
|
||||
# multicategory json index dataset, sum the lengths of
|
||||
# category-specific ones
|
||||
expected_dataset_len = sum(
|
||||
len(
|
||||
getattr(
|
||||
dataset_providers[c].get_dataset_map(), set_
|
||||
)
|
||||
)
|
||||
for c in categories
|
||||
)
|
||||
self.assertEqual(
|
||||
sum(len(s) for s in cat2seq.values()),
|
||||
n_sequences * len(categories),
|
||||
)
|
||||
self.assertEqual(len(cat2seq), len(categories))
|
||||
else:
|
||||
self.assertEqual(
|
||||
len(cat2seq[category]),
|
||||
n_sequences,
|
||||
)
|
||||
self.assertEqual(len(cat2seq), 1)
|
||||
self.assertEqual(len(dataset), expected_dataset_len)
|
||||
|
||||
if set_ == "test":
|
||||
# check the number of eval batches
|
||||
expected_n_eval_batches = n_eval_batches
|
||||
if "," in category:
|
||||
expected_n_eval_batches *= len(categories)
|
||||
self.assertTrue(
|
||||
len(dataset.get_eval_batches())
|
||||
== expected_n_eval_batches
|
||||
)
|
||||
if set_ in ["train", "val"]:
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
getattr(dataset_map, set_),
|
||||
@@ -80,6 +131,7 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
|
||||
dataset_provider.get_category_to_subset_name_list()
|
||||
)
|
||||
category_to_subset_list_ = {c: [subset_name] for c in categories}
|
||||
|
||||
self.assertTrue(category_to_subset_list == category_to_subset_list_)
|
||||
|
||||
|
||||
@@ -88,6 +140,7 @@ def _make_random_json_dataset_map_provider_v2_data(
|
||||
categories: List[str],
|
||||
n_frames: int = 8,
|
||||
n_sequences: int = 5,
|
||||
n_eval_batches: int = 10,
|
||||
H: int = 50,
|
||||
W: int = 30,
|
||||
subset_name: str = "test",
|
||||
@@ -100,7 +153,7 @@ def _make_random_json_dataset_map_provider_v2_data(
|
||||
sequence_annotations = []
|
||||
frame_index = []
|
||||
for seq_i in range(n_sequences):
|
||||
seq_name = str(seq_i)
|
||||
seq_name = category + str(seq_i)
|
||||
for i in range(n_frames):
|
||||
# generate and store image
|
||||
imdir = os.path.join(root, category, seq_name, "images")
|
||||
@@ -165,7 +218,8 @@ def _make_random_json_dataset_map_provider_v2_data(
|
||||
json.dump(set_list, f)
|
||||
|
||||
eval_batches = [
|
||||
random.sample(test_frame_index, eval_batch_size) for _ in range(10)
|
||||
random.sample(test_frame_index, eval_batch_size)
|
||||
for _ in range(n_eval_batches)
|
||||
]
|
||||
|
||||
eval_b_dir = os.path.join(root, category, "eval_batches")
|
||||
|
||||
Reference in New Issue
Block a user