Test R2N2 loads correct numbers of instances

Summary:
Sample/get all views at the loading phase instead of the returning phase;
Load only the views listed in the split instead of all 24 views;
Test that the number of views loaded for each category is correct.
(A short usage sketch follows the commit metadata below.)

Reviewed By: nikhilaravi

Differential Revision: D22631414

fbshipit-source-id: 1c5ce99fe2bdf6618c1aa0b69bb6899473376bc2
Luya Gao, 2020-07-23 10:15:50 -07:00 (committed by Facebook GitHub Bot)
parent 7cb9d8ea86
commit 483e538dae
5 changed files with 96 additions and 45 deletions
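
For orientation, a minimal usage sketch of the dataset after this change. The paths below are placeholders, and the (model_idx, view_idxs) indexing follows the updated tests.

from pytorch3d.datasets import R2N2

SHAPENET_PATH = "/path/to/ShapeNetCore.v1"  # placeholder path
R2N2_PATH = "/path/to/R2N2"                 # placeholder path
SPLITS_PATH = "/path/to/splits.json"        # placeholder path

# Load every view listed in the "test" split for each model.
dataset = R2N2("test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_all_views=True)

# Index with a (model_idx, view_idxs) tuple to retrieve specific loaded views.
sample = dataset[39, [21]]
print(sample["images"].shape)  # e.g. torch.Size([1, 137, 137, 3])

# With return_all_views=False, one view per model is sampled at load time.
single_view_dataset = R2N2(
    "test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_all_views=False
)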


@@ -11,6 +11,7 @@ import torch
from PIL import Image
from pytorch3d.datasets.shapenet_base import ShapeNetBase
from pytorch3d.io import load_obj
from tabulate import tabulate
SYNSET_DICT_DIR = Path(__file__).resolve().parent
@@ -21,7 +22,8 @@ class R2N2(ShapeNetBase):
This class loads the R2N2 dataset from a given directory into a Dataset object.
The R2N2 dataset contains 13 categories that are a subset of the ShapeNetCore v.1
dataset. The R2N2 dataset also contains its own 24 renderings of each object and
voxelized models.
voxelized models. Most of the models have all 24 views in the same split, but there
are eight of them that divide their views between train and test splits.
"""
def __init__(
@@ -40,13 +42,13 @@ class R2N2(ShapeNetBase):
shapenet_dir (path): Path to ShapeNet core v1.
r2n2_dir (path): Path to the R2N2 dataset.
splits_file (path): File containing the train/val/test splits.
return_all_views (bool): Indicator of whether or not to return all 24 views. If set
to False, one of the 24 views would be randomly selected and returned.
return_all_views (bool): Indicator of whether or not to load all the views in
the split. If set to False, one of the views in the split will be randomly
selected and loaded.
"""
super().__init__()
self.shapenet_dir = shapenet_dir
self.r2n2_dir = r2n2_dir
self.return_all_views = return_all_views
# Examine if split is valid.
if split not in ["train", "val", "test"]:
raise ValueError("split has to be one of (train, val, test).")
@@ -73,6 +75,10 @@ class R2N2(ShapeNetBase):
warnings.warn(msg)
synset_set = set()
# Store lists of views of each model in a list.
self.views_per_model_list = []
# Store tuples of synset label and total number of views in each category in a list.
synset_num_instances = []
for synset in split_dict.keys():
# Examine if the given synset is present in the ShapeNetCore dataset
# and is also part of the standard R2N2 dataset.
@@ -88,9 +94,10 @@ class R2N2(ShapeNetBase):
continue
synset_set.add(synset)
self.synset_starts[synset] = len(self.synset_ids)
models = split_dict[synset].keys()
for model in models:
self.synset_start_idxs[synset] = len(self.synset_ids)
# Start counting total number of views in the current category.
synset_view_count = 0
for model in split_dict[synset]:
# Examine if the given model is present in the ShapeNetCore path.
shapenet_path = path.join(shapenet_dir, synset, model)
if not path.isdir(shapenet_path):
@@ -103,13 +110,28 @@ class R2N2(ShapeNetBase):
continue
self.synset_ids.append(synset)
self.model_ids.append(model)
self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]
model_views = split_dict[synset][model]
# Randomly select a view index if return_all_views is set to False.
if not return_all_views:
rand_idx = torch.randint(len(model_views), (1,))
model_views = [model_views[rand_idx]]
self.views_per_model_list.append(model_views)
synset_view_count += len(model_views)
synset_num_instances.append((self.synset_dict[synset], synset_view_count))
model_count = len(self.synset_ids) - self.synset_start_idxs[synset]
self.synset_num_models[synset] = model_count
headers = ["category", "#instances"]
synset_num_instances.append(("total", sum(n for _, n in synset_num_instances)))
print(
tabulate(synset_num_instances, headers, numalign="left", stralign="center")
)
# Examine if all the synsets in the standard R2N2 mapping are present.
# Update self.synset_inv so that it only includes the loaded categories.
synset_not_present = [
self.synset_inv.pop(self.synset_dict[synset])
for synset in self.synset_dict.keys()
for synset in self.synset_dict
if synset not in synset_set
]
if len(synset_not_present) > 0:
@@ -126,8 +148,9 @@ class R2N2(ShapeNetBase):
Args:
model_idx: The idx of the model to be retrieved in the dataset.
view_idx: List of indices of the view to be returned. Each index needs to be
between 0 and 23, inclusive. If an invalid index is supplied, view_idx will be
ignored and views will be sampled according to self.return_all_views.
contained in the loaded split (always between 0 and 23, inclusive). If
an invalid index is supplied, view_idx will be ignored and all the loaded
views will be returned.
Returns:
dictionary with following keys:
@@ -139,8 +162,31 @@ class R2N2(ShapeNetBase):
- images: FloatTensor of shape (V, H, W, C), where V is number of views
returned. Returns a batch of the renderings of the models from the R2N2 dataset.
"""
if type(model_idx) is tuple:
if isinstance(model_idx, tuple):
model_idx, view_idxs = model_idx
if view_idxs is not None:
if isinstance(view_idxs, int):
view_idxs = [view_idxs]
if not isinstance(view_idxs, list) and not torch.is_tensor(view_idxs):
raise TypeError(
"view_idxs is of type %s but it needs to be a list."
% type(view_idxs)
)
model_views = self.views_per_model_list[model_idx]
if view_idxs is not None and any(
idx not in self.views_per_model_list[model_idx] for idx in view_idxs
):
msg = """At least one of the indices in view_idxs is not available.
Specified view of the model needs to be contained in the
loaded split. If return_all_views is set to False, only one
random view is loaded. Try accessing the specified view(s)
after loading the dataset with self.return_all_views set to True.
Now returning all view(s) in the loaded dataset."""
warnings.warn(msg)
elif view_idxs is not None:
model_views = view_idxs
model = self._get_item_ids(model_idx)
model_path = path.join(
self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"
@@ -152,19 +198,6 @@ class R2N2(ShapeNetBase):
model["images"] = None
# Retrieve R2N2's renderings if required.
if self.return_images:
ranges = (
range(24) if self.return_all_views else torch.randint(24, (1,)).tolist()
)
if view_idxs is not None and any(idx < 0 or idx > 23 for idx in view_idxs):
msg = (
"One of the indicies in view_idxs is out of range. "
"Index needs to be between 0 and 23, inclusive. "
"Now sampling according to self.return_all_views."
)
warnings.warn(msg)
elif view_idxs is not None:
ranges = view_idxs
rendering_path = path.join(
self.r2n2_dir,
"ShapeNetRendering",
@@ -174,7 +207,7 @@ class R2N2(ShapeNetBase):
)
images = []
for i in ranges:
for i in model_views:
# Read image.
image_path = path.join(rendering_path, "%02d.png" % i)
raw_img = Image.open(image_path)

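A standalone sketch of the per-category summary table printed by the new tabulate call above; the category names and counts here are made-up illustration values.

from tabulate import tabulate

# Hypothetical per-category view counts, for illustration only.
synset_num_instances = [("chair", 1355), ("table", 842)]
synset_num_instances.append(("total", sum(n for _, n in synset_num_instances)))
headers = ["category", "#instances"]
# Prints a two-column table with one row per category plus a "total" row.
print(tabulate(synset_num_instances, headers, numalign="left", stralign="center"))
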

@@ -100,7 +100,7 @@ class ShapeNetCore(ShapeNetBase):
# Each grandchildren directory of data_dir contains an object, and the name
# of the directory is the object's model_id.
for synset in synset_set:
self.synset_starts[synset] = len(self.synset_ids)
self.synset_start_idxs[synset] = len(self.synset_ids)
for model in os.listdir(path.join(data_dir, synset)):
if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
msg = (
@@ -111,7 +111,8 @@ class ShapeNetCore(ShapeNetBase):
continue
self.synset_ids.append(synset)
self.model_ids.append(model)
self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]
model_count = len(self.synset_ids) - self.synset_start_idxs[synset]
self.synset_num_models[synset] = model_count
def __getitem__(self, idx: int) -> Dict:
"""

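To show how the renamed bookkeeping is meant to be used, a hedged sketch that slices one category's models out of the flat id list; it assumes the dataset built in the sketch near the top of this page, and 03001627 is the ShapeNet chair synset.

# Assumes `dataset` from the usage sketch above.
start = dataset.synset_start_idxs["03001627"]
count = dataset.synset_num_models["03001627"]
chair_model_ids = dataset.model_ids[start : start + count]
print(len(chair_model_ids) == count)  # True
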

@@ -31,8 +31,8 @@ class ShapeNetBase(torch.utils.data.Dataset):
self.synset_ids = []
self.model_ids = []
self.synset_inv = {}
self.synset_starts = {}
self.synset_lens = {}
self.synset_start_idxs = {}
self.synset_num_models = {}
self.shapenet_dir = ""
self.model_dir = "model.obj"
@@ -227,9 +227,9 @@ class ShapeNetBase(torch.utils.data.Dataset):
category: category synset of the category to be sampled from. If not
specified, sample from all models in the loaded dataset.
"""
start = self.synset_starts[category] if category is not None else 0
start = self.synset_start_idxs[category] if category is not None else 0
range_len = (
self.synset_lens[category] if category is not None else self.__len__()
self.synset_num_models[category] if category is not None else self.__len__()
)
replacement = sample_num > range_len
sampled_idxs = (

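The sampling call itself is truncated in the hunk above; the following is a hedged, standalone sketch (not the library's actual code) of drawing indices from a category's range with or without replacement.

import torch

start, range_len, sample_num = 0, 100, 8  # hypothetical values
replacement = sample_num > range_len      # only sample with replacement when needed
if replacement:
    sampled_idxs = start + torch.randint(range_len, (sample_num,))
else:
    sampled_idxs = start + torch.randperm(range_len)[:sample_num]
print(sampled_idxs)
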

@@ -17,6 +17,7 @@ def collate_batched_meshes(batch: List[Dict]):
Args:
batch: List of dictionaries containing information about objects
in the dataset.
Returns:
collated_dict: Dictionary of collated lists. If batch contains both
verts and faces, a collated mesh batch is also returned.

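collate_batched_meshes is the collate_fn intended for these datasets; a brief usage sketch with a DataLoader, assuming the dataset from the first sketch.

from torch.utils.data import DataLoader
from pytorch3d.datasets import collate_batched_meshes

# Batch samples from the dataset; per the docstring above, verts and faces
# are also collated into a mesh batch when both are present.
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_batched_meshes)
batch = next(iter(loader))
print(batch.keys())
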

@@ -59,20 +59,30 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
Test that the loaded split of R2N2 returns items of the correct shapes and types. Also
check that the first image returned is correct.
"""
# Load dataset in the train split.
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
# Load dataset in the test split.
r2n2_dataset = R2N2("test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
# Check total number of objects in the dataset is correct.
with open(SPLITS_PATH) as splits:
split_dict = json.load(splits)["train"]
model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
split_dict = json.load(splits)["test"]
model_nums = [len(split_dict[synset]) for synset in split_dict]
self.assertEqual(len(r2n2_dataset), sum(model_nums))
# Randomly retrieve an object from the dataset.
rand_idx = torch.randint(len(r2n2_dataset), (1,))
rand_obj = r2n2_dataset[rand_idx]
# Check the numbers of loaded instances for each category are correct.
for synset in split_dict:
split_synset_nums = sum(
len(split_dict[synset][model]) for model in split_dict[synset]
)
idx_start = r2n2_dataset.synset_start_idxs[synset]
idx_end = idx_start + r2n2_dataset.synset_num_models[synset]
synset_views_list = r2n2_dataset.views_per_model_list[idx_start:idx_end]
loaded_synset_views = sum(len(views) for views in synset_views_list)
self.assertEqual(loaded_synset_views, split_synset_nums)
# Retrieve an object from the dataset.
r2n2_obj = r2n2_dataset[39]
# Check that verts and faces returned by __getitem__ have the correct shapes and types.
verts, faces = rand_obj["verts"], rand_obj["faces"]
verts, faces = r2n2_obj["verts"], r2n2_obj["faces"]
self.assertTrue(verts.dtype == torch.float32)
self.assertTrue(faces.dtype == torch.int64)
self.assertEqual(verts.ndim, 2)
@@ -81,11 +91,17 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
self.assertEqual(faces.shape[-1], 3)
# Check that image batch returned by __getitem__ has the correct shape.
self.assertEqual(rand_obj["images"].shape[0], 24)
self.assertEqual(rand_obj["images"].shape[1], 137)
self.assertEqual(rand_obj["images"].shape[2], 137)
self.assertEqual(rand_obj["images"].shape[-1], 3)
self.assertEqual(r2n2_dataset[rand_idx, [21]]["images"].shape[0], 1)
self.assertEqual(r2n2_obj["images"].shape[0], 24)
self.assertEqual(r2n2_obj["images"].shape[1], 137)
self.assertEqual(r2n2_obj["images"].shape[2], 137)
self.assertEqual(r2n2_obj["images"].shape[-1], 3)
self.assertEqual(r2n2_dataset[39, [21]]["images"].shape[0], 1)
self.assertEqual(r2n2_dataset[39, torch.tensor([12, 21])]["images"].shape[0], 2)
# Check models with total view counts less than 24 return image batches
# of the correct shapes.
self.assertEqual(r2n2_dataset[635]["images"].shape[0], 5)
self.assertEqual(r2n2_dataset[8369]["images"].shape[0], 10)
def test_collate_models(self):
"""