diff --git a/pytorch3d/datasets/r2n2/r2n2.py b/pytorch3d/datasets/r2n2/r2n2.py index 553736a5..46865b05 100644 --- a/pytorch3d/datasets/r2n2/r2n2.py +++ b/pytorch3d/datasets/r2n2/r2n2.py @@ -11,6 +11,7 @@ import torch from PIL import Image from pytorch3d.datasets.shapenet_base import ShapeNetBase from pytorch3d.io import load_obj +from tabulate import tabulate SYNSET_DICT_DIR = Path(__file__).resolve().parent @@ -21,7 +22,8 @@ class R2N2(ShapeNetBase): This class loads the R2N2 dataset from a given directory into a Dataset object. The R2N2 dataset contains 13 categories that are a subset of the ShapeNetCore v.1 dataset. The R2N2 dataset also contains its own 24 renderings of each object and - voxelized models. + voxelized models. Most of the models have all 24 views in the same split, but there + are eight of them that divide their views between train and test splits. """ def __init__( @@ -40,13 +42,13 @@ class R2N2(ShapeNetBase): shapenet_dir (path): Path to ShapeNet core v1. r2n2_dir (path): Path to the R2N2 dataset. splits_file (path): File containing the train/val/test splits. - return_all_views (bool): Indicator of whether or not to return all 24 views. If set - to False, one of the 24 views would be randomly selected and returned. + return_all_views (bool): Indicator of whether or not to load all the views in + the split. If set to False, one of the views in the split will be randomly + selected and loaded. """ super().__init__() self.shapenet_dir = shapenet_dir self.r2n2_dir = r2n2_dir - self.return_all_views = return_all_views # Examine if split is valid. if split not in ["train", "val", "test"]: raise ValueError("split has to be one of (train, val, test).") @@ -73,6 +75,10 @@ class R2N2(ShapeNetBase): warnings.warn(msg) synset_set = set() + # Store lists of views of each model in a list. + self.views_per_model_list = [] + # Store tuples of synset label and total number of views in each category in a list. + synset_num_instances = [] for synset in split_dict.keys(): # Examine if the given synset is present in the ShapeNetCore dataset # and is also part of the standard R2N2 dataset. @@ -88,9 +94,10 @@ class R2N2(ShapeNetBase): continue synset_set.add(synset) - self.synset_starts[synset] = len(self.synset_ids) - models = split_dict[synset].keys() - for model in models: + self.synset_start_idxs[synset] = len(self.synset_ids) + # Start counting total number of views in the current category. + synset_view_count = 0 + for model in split_dict[synset]: # Examine if the given model is present in the ShapeNetCore path. shapenet_path = path.join(shapenet_dir, synset, model) if not path.isdir(shapenet_path): @@ -103,13 +110,28 @@ class R2N2(ShapeNetBase): continue self.synset_ids.append(synset) self.model_ids.append(model) - self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset] + + model_views = split_dict[synset][model] + # Randomly select a view index if return_all_views set to False. + if not return_all_views: + rand_idx = torch.randint(len(model_views), (1,)) + model_views = [model_views[rand_idx]] + self.views_per_model_list.append(model_views) + synset_view_count += len(model_views) + synset_num_instances.append((self.synset_dict[synset], synset_view_count)) + model_count = len(self.synset_ids) - self.synset_start_idxs[synset] + self.synset_num_models[synset] = model_count + headers = ["category", "#instances"] + synset_num_instances.append(("total", sum(n for _, n in synset_num_instances))) + print( + tabulate(synset_num_instances, headers, numalign="left", stralign="center") + ) # Examine if all the synsets in the standard R2N2 mapping are present. # Update self.synset_inv so that it only includes the loaded categories. synset_not_present = [ self.synset_inv.pop(self.synset_dict[synset]) - for synset in self.synset_dict.keys() + for synset in self.synset_dict if synset not in synset_set ] if len(synset_not_present) > 0: @@ -126,8 +148,9 @@ class R2N2(ShapeNetBase): Args: model_idx: The idx of the model to be retrieved in the dataset. view_idx: List of indices of the view to be returned. Each index needs to be - between 0 and 23, inclusive. If an invalid index is supplied, view_idx will be - ignored and views will be sampled according to self.return_all_views. + contained in the loaded split (always between 0 and 23, inclusive). If + an invalid index is supplied, view_idx will be ignored and all the loaded + views will be returned. Returns: dictionary with following keys: @@ -139,8 +162,31 @@ class R2N2(ShapeNetBase): - images: FloatTensor of shape (V, H, W, C), where V is number of views returned. Returns a batch of the renderings of the models from the R2N2 dataset. """ - if type(model_idx) is tuple: + if isinstance(model_idx, tuple): model_idx, view_idxs = model_idx + if view_idxs is not None: + if isinstance(view_idxs, int): + view_idxs = [view_idxs] + if not isinstance(view_idxs, list) and not torch.is_tensor(view_idxs): + raise TypeError( + "view_idxs is of type %s but it needs to be a list." + % type(view_idxs) + ) + + model_views = self.views_per_model_list[model_idx] + if view_idxs is not None and any( + idx not in self.views_per_model_list[model_idx] for idx in view_idxs + ): + msg = """At least one of the indices in view_idxs is not available. + Specified view of the model needs to be contained in the + loaded split. If return_all_views is set to False, only one + random view is loaded. Try accessing the specified view(s) + after loading the dataset with self.return_all_views set to True. + Now returning all view(s) in the loaded dataset.""" + warnings.warn(msg) + elif view_idxs is not None: + model_views = view_idxs + model = self._get_item_ids(model_idx) model_path = path.join( self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj" @@ -152,19 +198,6 @@ class R2N2(ShapeNetBase): model["images"] = None # Retrieve R2N2's renderings if required. if self.return_images: - ranges = ( - range(24) if self.return_all_views else torch.randint(24, (1,)).tolist() - ) - if view_idxs is not None and any(idx < 0 or idx > 23 for idx in view_idxs): - msg = ( - "One of the indicies in view_idxs is out of range. " - "Index needs to be between 0 and 23, inclusive. " - "Now sampling according to self.return_all_views." - ) - warnings.warn(msg) - elif view_idxs is not None: - ranges = view_idxs - rendering_path = path.join( self.r2n2_dir, "ShapeNetRendering", @@ -174,7 +207,7 @@ class R2N2(ShapeNetBase): ) images = [] - for i in ranges: + for i in model_views: # Read image. image_path = path.join(rendering_path, "%02d.png" % i) raw_img = Image.open(image_path) diff --git a/pytorch3d/datasets/shapenet/shapenet_core.py b/pytorch3d/datasets/shapenet/shapenet_core.py index 79852975..0d799ca2 100644 --- a/pytorch3d/datasets/shapenet/shapenet_core.py +++ b/pytorch3d/datasets/shapenet/shapenet_core.py @@ -100,7 +100,7 @@ class ShapeNetCore(ShapeNetBase): # Each grandchildren directory of data_dir contains an object, and the name # of the directory is the object's model_id. for synset in synset_set: - self.synset_starts[synset] = len(self.synset_ids) + self.synset_start_idxs[synset] = len(self.synset_ids) for model in os.listdir(path.join(data_dir, synset)): if not path.exists(path.join(data_dir, synset, model, self.model_dir)): msg = ( @@ -111,7 +111,8 @@ class ShapeNetCore(ShapeNetBase): continue self.synset_ids.append(synset) self.model_ids.append(model) - self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset] + model_count = len(self.synset_ids) - self.synset_start_idxs[synset] + self.synset_num_models[synset] = model_count def __getitem__(self, idx: int) -> Dict: """ diff --git a/pytorch3d/datasets/shapenet_base.py b/pytorch3d/datasets/shapenet_base.py index d305894f..85c4ce68 100644 --- a/pytorch3d/datasets/shapenet_base.py +++ b/pytorch3d/datasets/shapenet_base.py @@ -31,8 +31,8 @@ class ShapeNetBase(torch.utils.data.Dataset): self.synset_ids = [] self.model_ids = [] self.synset_inv = {} - self.synset_starts = {} - self.synset_lens = {} + self.synset_start_idxs = {} + self.synset_num_models = {} self.shapenet_dir = "" self.model_dir = "model.obj" @@ -227,9 +227,9 @@ class ShapeNetBase(torch.utils.data.Dataset): category: category synset of the category to be sampled from. If not specified, sample from all models in the loaded dataset. """ - start = self.synset_starts[category] if category is not None else 0 + start = self.synset_start_idxs[category] if category is not None else 0 range_len = ( - self.synset_lens[category] if category is not None else self.__len__() + self.synset_num_models[category] if category is not None else self.__len__() ) replacement = sample_num > range_len sampled_idxs = ( diff --git a/pytorch3d/datasets/utils.py b/pytorch3d/datasets/utils.py index 922ff780..5c2f4bc0 100644 --- a/pytorch3d/datasets/utils.py +++ b/pytorch3d/datasets/utils.py @@ -17,6 +17,7 @@ def collate_batched_meshes(batch: List[Dict]): Args: batch: List of dictionaries containing information about objects in the dataset. + Returns: collated_dict: Dictionary of collated lists. If batch contains both verts and faces, a collated mesh batch is also returned. diff --git a/tests/test_r2n2.py b/tests/test_r2n2.py index 1de4222d..9b0253db 100644 --- a/tests/test_r2n2.py +++ b/tests/test_r2n2.py @@ -59,20 +59,30 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): Test the loaded train split of R2N2 return items of the correct shapes and types. Also check the first image returned is correct. """ - # Load dataset in the train split. - r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH) + # Load dataset in the test split. + r2n2_dataset = R2N2("test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH) # Check total number of objects in the dataset is correct. with open(SPLITS_PATH) as splits: - split_dict = json.load(splits)["train"] - model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()] + split_dict = json.load(splits)["test"] + model_nums = [len(split_dict[synset]) for synset in split_dict] self.assertEqual(len(r2n2_dataset), sum(model_nums)) - # Randomly retrieve an object from the dataset. - rand_idx = torch.randint(len(r2n2_dataset), (1,)) - rand_obj = r2n2_dataset[rand_idx] + # Check the numbers of loaded instances for each category are correct. + for synset in split_dict: + split_synset_nums = sum( + len(split_dict[synset][model]) for model in split_dict[synset] + ) + idx_start = r2n2_dataset.synset_start_idxs[synset] + idx_end = idx_start + r2n2_dataset.synset_num_models[synset] + synset_views_list = r2n2_dataset.views_per_model_list[idx_start:idx_end] + loaded_synset_views = sum(len(views) for views in synset_views_list) + self.assertEqual(loaded_synset_views, split_synset_nums) + + # Retrieve an object from the dataset. + r2n2_obj = r2n2_dataset[39] # Check that verts and faces returned by __getitem__ have the correct shapes and types. - verts, faces = rand_obj["verts"], rand_obj["faces"] + verts, faces = r2n2_obj["verts"], r2n2_obj["faces"] self.assertTrue(verts.dtype == torch.float32) self.assertTrue(faces.dtype == torch.int64) self.assertEqual(verts.ndim, 2) @@ -81,11 +91,17 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): self.assertEqual(faces.shape[-1], 3) # Check that image batch returned by __getitem__ has the correct shape. - self.assertEqual(rand_obj["images"].shape[0], 24) - self.assertEqual(rand_obj["images"].shape[1], 137) - self.assertEqual(rand_obj["images"].shape[2], 137) - self.assertEqual(rand_obj["images"].shape[-1], 3) - self.assertEqual(r2n2_dataset[rand_idx, [21]]["images"].shape[0], 1) + self.assertEqual(r2n2_obj["images"].shape[0], 24) + self.assertEqual(r2n2_obj["images"].shape[1], 137) + self.assertEqual(r2n2_obj["images"].shape[2], 137) + self.assertEqual(r2n2_obj["images"].shape[-1], 3) + self.assertEqual(r2n2_dataset[39, [21]]["images"].shape[0], 1) + self.assertEqual(r2n2_dataset[39, torch.tensor([12, 21])]["images"].shape[0], 2) + + # Check models with total view counts less than 24 return image batches + # of the correct shapes. + self.assertEqual(r2n2_dataset[635]["images"].shape[0], 5) + self.assertEqual(r2n2_dataset[8369]["images"].shape[0], 10) def test_collate_models(self): """