From 49b4ce1accac7d180920aeeb95ba7091f02c4c3c Mon Sep 17 00:00:00 2001
From: Luya Gao
Date: Tue, 14 Jul 2020 14:52:21 -0700
Subject: [PATCH] R2N2 skeleton

Summary: Skeleton of R2N2 that for now only returns verts and faces extracted from ShapeNetCore v1.

Reviewed By: nikhilaravi

Differential Revision: D22203656

fbshipit-source-id: 00db6ac76bfdb76fdbc77a2087c34a3f0ff01e6a
---
 pytorch3d/datasets/__init__.py                |   1 +
 pytorch3d/datasets/r2n2/__init__.py           |   6 +
 pytorch3d/datasets/r2n2/r2n2.py               | 118 ++++++++++++++++++
 pytorch3d/datasets/r2n2/r2n2_synset_dict.json |  15 +++
 tests/test_r2n2.py                            | 111 ++++++++++++++++
 5 files changed, 251 insertions(+)
 create mode 100644 pytorch3d/datasets/r2n2/__init__.py
 create mode 100644 pytorch3d/datasets/r2n2/r2n2.py
 create mode 100644 pytorch3d/datasets/r2n2/r2n2_synset_dict.json
 create mode 100644 tests/test_r2n2.py

diff --git a/pytorch3d/datasets/__init__.py b/pytorch3d/datasets/__init__.py
index 04da090a..78679b6f 100644
--- a/pytorch3d/datasets/__init__.py
+++ b/pytorch3d/datasets/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
+from .r2n2 import R2N2
 from .shapenet import ShapeNetCore
 from .utils import collate_batched_meshes
 
diff --git a/pytorch3d/datasets/r2n2/__init__.py b/pytorch3d/datasets/r2n2/__init__.py
new file mode 100644
index 00000000..a98d7b00
--- /dev/null
+++ b/pytorch3d/datasets/r2n2/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+from .r2n2 import R2N2
+
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
diff --git a/pytorch3d/datasets/r2n2/r2n2.py b/pytorch3d/datasets/r2n2/r2n2.py
new file mode 100644
index 00000000..1305e478
--- /dev/null
+++ b/pytorch3d/datasets/r2n2/r2n2.py
@@ -0,0 +1,118 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+import json
+import warnings
+from os import path
+from pathlib import Path
+from typing import Dict
+
+from pytorch3d.datasets.shapenet_base import ShapeNetBase
+from pytorch3d.io import load_obj
+
+
+SYNSET_DICT_DIR = Path(__file__).resolve().parent
+
+
+class R2N2(ShapeNetBase):
+    """
+    This class loads the R2N2 dataset from a given directory into a Dataset object.
+    The R2N2 dataset contains 13 categories that are a subset of the ShapeNetCore v.1
+    dataset. The R2N2 dataset also contains its own 24 renderings of each object and
+    voxelized models.
+    """
+
+    def __init__(self, split, shapenet_dir, r2n2_dir, splits_file):
+        """
+        Store each object's synset id and model id from the given directories.
+        Args:
+            split (str): One of (train, val, test).
+            shapenet_dir (path): Path to ShapeNet core v1.
+            r2n2_dir (path): Path to the R2N2 dataset.
+            splits_file (path): File containing the train/val/test splits.
+        Raises:
+            ValueError: If split is not one of (train, val, test).
+        """
+        super().__init__()
+        self.shapenet_dir = shapenet_dir
+        self.r2n2_dir = r2n2_dir
+        # Examine if split is valid.
+        if split not in ["train", "val", "test"]:
+            raise ValueError("split has to be one of (train, val, test).")
+        # Synset dictionary mapping synset offsets in R2N2 to corresponding labels.
+        with open(path.join(SYNSET_DICT_DIR, "r2n2_synset_dict.json")) as read_dict:
+            self.synset_dict = json.load(read_dict)
+        # Inverse dictionary mapping synset labels to corresponding offsets.
+        self.synset_inv = {label: offset for offset, label in self.synset_dict.items()}
+
+        # Store synset and model ids of objects mentioned in the splits_file.
+        with open(splits_file) as splits:
+            split_dict = json.load(splits)[split]
+
+        synset_set = set()
+        for synset in split_dict.keys():
+            # Examine if the given synset is present in the ShapeNetCore dataset
+            # and is also part of the standard R2N2 dataset.
+            if not (
+                path.isdir(path.join(shapenet_dir, synset))
+                and synset in self.synset_dict
+            ):
+                msg = (
+                    "Synset category %s from the splits file is either not "
+                    "present in %s or not part of the standard R2N2 dataset."
+                ) % (synset, shapenet_dir)
+                warnings.warn(msg)
+                continue
+
+            synset_set.add(synset)
+            models = split_dict[synset].keys()
+            for model in models:
+                # Examine if the given model is present in the ShapeNetCore path.
+                shapenet_path = path.join(shapenet_dir, synset, model)
+                if not path.isdir(shapenet_path):
+                    msg = "Model %s from category %s is not present in %s." % (
+                        model,
+                        synset,
+                        shapenet_dir,
+                    )
+                    warnings.warn(msg)
+                    continue
+                self.synset_ids.append(synset)
+                self.model_ids.append(model)
+
+        # Examine if all the synsets in the standard R2N2 mapping are present.
+        # Update self.synset_inv so that it only includes the loaded categories.
+        synset_not_present = []
+        for synset, label in self.synset_dict.items():
+            if synset not in synset_set:
+                # pop() returns the synset offset of the category being dropped.
+                synset_not_present.append(self.synset_inv.pop(label))
+        if synset_not_present:
+            msg = (
+                "The following categories are included in R2N2's "
+                "official mapping but not found in the dataset location %s: %s"
+            ) % (shapenet_dir, ", ".join(synset_not_present))
+            warnings.warn(msg)
+
+    def __getitem__(self, idx: int) -> Dict:
+        """
+        Read a model by the given index.
+
+        Args:
+            idx: The idx of the model to be retrieved in the dataset.
+
+        Returns:
+            dictionary with following keys:
+            - verts: FloatTensor of shape (V, 3).
+            - faces: faces.verts_idx, LongTensor of shape (F, 3).
+            - synset_id (str): synset id.
+            - model_id (str): model id.
+            - label (str): synset label.
+        """
+        model = self._get_item_ids(idx)
+        model_path = path.join(
+            self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"
+        )
+        model["verts"], faces, _ = load_obj(model_path)
+        model["faces"] = faces.verts_idx
+        model["label"] = self.synset_dict[model["synset_id"]]
+        return model
diff --git a/pytorch3d/datasets/r2n2/r2n2_synset_dict.json b/pytorch3d/datasets/r2n2/r2n2_synset_dict.json
new file mode 100644
index 00000000..b8cbae58
--- /dev/null
+++ b/pytorch3d/datasets/r2n2/r2n2_synset_dict.json
@@ -0,0 +1,15 @@
+{
+    "04256520": "sofa",
+    "02933112": "cabinet",
+    "02828884": "bench",
+    "03001627": "chair",
+    "03211117": "display",
+    "04090263": "rifle",
+    "03691459": "loudspeaker",
+    "03636649": "lamp",
+    "04401088": "telephone",
+    "02691156": "airplane",
+    "04379243": "table",
+    "02958343": "car",
+    "04530566": "watercraft"
+}
diff --git a/tests/test_r2n2.py b/tests/test_r2n2.py
new file mode 100644
index 00000000..8cd2ed8c
--- /dev/null
+++ b/tests/test_r2n2.py
@@ -0,0 +1,111 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+"""
+Sanity checks for loading R2N2.
+"""
+import json
+import os
+import unittest
+
+import torch
+from common_testing import TestCaseMixin
+from pytorch3d.datasets import R2N2, collate_batched_meshes
+from torch.utils.data import DataLoader
+
+
+# Set these paths in order to run the tests.
+R2N2_PATH = None
+SHAPENET_PATH = None
+SPLITS_PATH = None
+
+
+class TestR2N2(TestCaseMixin, unittest.TestCase):
+    def setUp(self):
+        """
+        Check if the data paths are given otherwise skip tests.
+        """
+        if SHAPENET_PATH is None or not os.path.exists(SHAPENET_PATH):
+            url = "https://www.shapenet.org/"
+            msg = (
+                "ShapeNet data not found, download from %s, update "
+                "SHAPENET_PATH at the top of the file, and rerun."
+            )
+            self.skipTest(msg % url)
+        if R2N2_PATH is None or not os.path.exists(R2N2_PATH):
+            url = "http://3d-r2n2.stanford.edu/"
+            msg = (
+                "R2N2 data not found, download from %s, update "
+                "R2N2_PATH at the top of the file, and rerun."
+            )
+            self.skipTest(msg % url)
+        if SPLITS_PATH is None or not os.path.exists(SPLITS_PATH):
+            msg = """Splits file not found, update SPLITS_PATH at the top
+                of the file, and rerun."""
+            self.skipTest(msg)
+
+    def test_load_R2N2(self):
+        """
+        Test loading the train split of R2N2. Check the loaded dataset return items
+        of the correct shapes and types.
+        """
+        # Load dataset in the train split.
+        split = "train"
+        r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+
+        # Check total number of objects in the dataset is correct.
+        with open(SPLITS_PATH) as splits:
+            split_dict = json.load(splits)[split]
+        model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
+        self.assertEqual(len(r2n2_dataset), sum(model_nums))
+
+        # Randomly retrieve an object from the dataset.
+        rand_obj = r2n2_dataset[torch.randint(len(r2n2_dataset), (1,))]
+        # Check that data type and shape of the item returned by __getitem__ are correct.
+        verts, faces = rand_obj["verts"], rand_obj["faces"]
+        self.assertTrue(verts.dtype == torch.float32)
+        self.assertTrue(faces.dtype == torch.int64)
+        self.assertEqual(verts.ndim, 2)
+        self.assertEqual(verts.shape[-1], 3)
+        self.assertEqual(faces.ndim, 2)
+        self.assertEqual(faces.shape[-1], 3)
+
+    def test_collate_models(self):
+        """
+        Test collate_batched_meshes returns items of the correct shapes and types.
+        Check that when collate_batched_meshes is passed to Dataloader, batches of
+        the correct shapes and types are returned.
+        """
+        # Load dataset in the train split.
+        split = "train"
+        r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+
+        # Randomly retrieve several objects from the dataset and collate them.
+        collated_meshes = collate_batched_meshes(
+            [r2n2_dataset[idx] for idx in torch.randint(len(r2n2_dataset), (6,))]
+        )
+        # Check the collated verts and faces have the correct shapes.
+        verts, faces = collated_meshes["verts"], collated_meshes["faces"]
+        self.assertEqual(len(verts), 6)
+        self.assertEqual(len(faces), 6)
+        self.assertEqual(verts[0].shape[-1], 3)
+        self.assertEqual(faces[0].shape[-1], 3)
+
+        # Check the collated mesh has the correct shape.
+        mesh = collated_meshes["mesh"]
+        self.assertEqual(mesh.verts_padded().shape[0], 6)
+        self.assertEqual(mesh.verts_padded().shape[-1], 3)
+        self.assertEqual(mesh.faces_padded().shape[0], 6)
+        self.assertEqual(mesh.faces_padded().shape[-1], 3)
+
+        # Pass the custom collate_fn function to DataLoader and check elements
+        # in batch have the correct shape.
+        batch_size = 12
+        r2n2_loader = DataLoader(
+            r2n2_dataset, batch_size=batch_size, collate_fn=collate_batched_meshes
+        )
+        it = iter(r2n2_loader)
+        object_batch = next(it)
+        self.assertEqual(len(object_batch["synset_id"]), batch_size)
+        self.assertEqual(len(object_batch["model_id"]), batch_size)
+        self.assertEqual(len(object_batch["label"]), batch_size)
+        self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
+        self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)