R2N2 skeleton

Summary: Skeleton of R2N2 that for now only returns verts and faces extracted from ShapeNetCore v1.

Reviewed By: nikhilaravi

Differential Revision: D22203656

fbshipit-source-id: 00db6ac76bfdb76fdbc77a2087c34a3f0ff01e6a
Author: Luya Gao 2020-07-14 14:52:21 -07:00 (committed by Facebook GitHub Bot)
parent 22d8c3337a
commit 49b4ce1acc
5 changed files with 251 additions and 0 deletions
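For context, a minimal usage sketch of the dataset class added in this diff; the directory paths are placeholders, not real locations:

    from torch.utils.data import DataLoader
    from pytorch3d.datasets import R2N2, collate_batched_meshes

    # Placeholder paths; point these at local copies of ShapeNetCore v1,
    # the R2N2 dataset, and the R2N2 train/val/test splits file.
    r2n2_dataset = R2N2("train", "/data/ShapeNetCore.v1", "/data/r2n2", "/data/splits.json")
    model = r2n2_dataset[0]  # dict with verts, faces, synset_id, model_id, label
    loader = DataLoader(r2n2_dataset, batch_size=8, collate_fn=collate_batched_meshes)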

pytorch3d/datasets/__init__.py

@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .r2n2 import R2N2
from .shapenet import ShapeNetCore
from .utils import collate_batched_meshes

pytorch3d/datasets/r2n2/__init__.py

@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .r2n2 import R2N2
__all__ = [k for k in globals().keys() if not k.startswith("_")]

pytorch3d/datasets/r2n2/r2n2.py

@@ -0,0 +1,118 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import json
import warnings
from os import path
from pathlib import Path
from typing import Dict
from pytorch3d.datasets.shapenet_base import ShapeNetBase
from pytorch3d.io import load_obj
SYNSET_DICT_DIR = Path(__file__).resolve().parent
class R2N2(ShapeNetBase):
"""
This class loads the R2N2 dataset from a given directory into a Dataset object.
The R2N2 dataset contains 13 categories that are a subset of the ShapeNetCore v.1
dataset. The R2N2 dataset also contains its own 24 renderings of each object and
voxelized models.
"""
def __init__(self, split, shapenet_dir, r2n2_dir, splits_file):
"""
        Store each object's synset id and model id from the given directories.
Args:
split (str): One of (train, val, test).
shapenet_dir (path): Path to ShapeNet core v1.
r2n2_dir (path): Path to the R2N2 dataset.
splits_file (path): File containing the train/val/test splits.
"""
super().__init__()
self.shapenet_dir = shapenet_dir
self.r2n2_dir = r2n2_dir
# Examine if split is valid.
if split not in ["train", "val", "test"]:
raise ValueError("split has to be one of (train, val, test).")
# Synset dictionary mapping synset offsets in R2N2 to corresponding labels.
with open(
path.join(SYNSET_DICT_DIR, "r2n2_synset_dict.json"), "r"
) as read_dict:
self.synset_dict = json.load(read_dict)
        # Inverse dictionary mapping synset labels to corresponding offsets.
self.synset_inv = {label: offset for offset, label in self.synset_dict.items()}
# Store synset and model ids of objects mentioned in the splits_file.
with open(splits_file) as splits:
split_dict = json.load(splits)[split]
synset_set = set()
for synset in split_dict.keys():
# Examine if the given synset is present in the ShapeNetCore dataset
# and is also part of the standard R2N2 dataset.
if not (
path.isdir(path.join(shapenet_dir, synset))
and synset in self.synset_dict
):
msg = (
"Synset category %s from the splits file is either not "
"present in %s or not part of the standard R2N2 dataset."
) % (synset, shapenet_dir)
warnings.warn(msg)
continue
synset_set.add(synset)
models = split_dict[synset].keys()
for model in models:
# Examine if the given model is present in the ShapeNetCore path.
shapenet_path = path.join(shapenet_dir, synset, model)
if not path.isdir(shapenet_path):
msg = "Model %s from category %s is not present in %s." % (
model,
synset,
shapenet_dir,
)
warnings.warn(msg)
continue
self.synset_ids.append(synset)
self.model_ids.append(model)
# Examine if all the synsets in the standard R2N2 mapping are present.
# Update self.synset_inv so that it only includes the loaded categories.
synset_not_present = [
self.synset_inv.pop(self.synset_dict[synset])
for synset in self.synset_dict.keys()
if synset not in synset_set
]
if len(synset_not_present) > 0:
            msg = (
                "The following categories are included in R2N2's "
                "official mapping but not found in the dataset location %s: %s"
            ) % (shapenet_dir, ", ".join(synset_not_present))
warnings.warn(msg)
def __getitem__(self, idx: int) -> Dict:
"""
Read a model by the given index.
Args:
            idx: The index of the model to be retrieved in the dataset.
        Returns:
            a dictionary with the following keys:
- verts: FloatTensor of shape (V, 3).
- faces: faces.verts_idx, LongTensor of shape (F, 3).
- synset_id (str): synset id.
- model_id (str): model id.
- label (str): synset label.
"""
model = self._get_item_ids(idx)
model_path = path.join(
self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"
)
model["verts"], faces, _ = load_obj(model_path)
model["faces"] = faces.verts_idx
model["label"] = self.synset_dict[model["synset_id"]]
return model
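Not part of the diff, but as a sketch of how the returned tensors compose with the rest of PyTorch3D, the verts/faces pair from __getitem__ can be wrapped in a Meshes structure; this is effectively the packing that collate_batched_meshes performs for a whole batch:

    from pytorch3d.structures import Meshes

    # Given an R2N2 instance constructed as in the sketch above:
    model = r2n2_dataset[0]
    # Meshes expects lists of per-mesh tensors: verts (V, 3) and faces (F, 3).
    mesh = Meshes(verts=[model["verts"]], faces=[model["faces"]])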

pytorch3d/datasets/r2n2/r2n2_synset_dict.json

@@ -0,0 +1,15 @@
{
"04256520": "sofa",
"02933112": "cabinet",
"02828884": "bench",
"03001627": "chair",
"03211117": "display",
"04090263": "rifle",
"03691459": "loudspeaker",
"03636649": "lamp",
"04401088": "telephone",
"02691156": "airplane",
"04379243": "table",
"02958343": "car",
"04530566": "watercraft"
}
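The constructor inverts this offset-to-label mapping so that categories can also be looked up by name; a standalone sketch of that inversion:

    import json

    # Load the mapping shipped next to r2n2.py (offset -> label, as above).
    with open("r2n2_synset_dict.json") as f:
        synset_dict = json.load(f)

    synset_inv = {label: offset for offset, label in synset_dict.items()}
    assert synset_inv["chair"] == "03001627"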

tests/test_r2n2.py

@@ -0,0 +1,111 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
Sanity checks for loading R2N2.
"""
import json
import os
import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.datasets import R2N2, collate_batched_meshes
from torch.utils.data import DataLoader
# Set these paths in order to run the tests.
R2N2_PATH = None
SHAPENET_PATH = None
SPLITS_PATH = None
class TestR2N2(TestCaseMixin, unittest.TestCase):
def setUp(self):
"""
        Check if the data paths are given; otherwise skip the tests.
"""
if SHAPENET_PATH is None or not os.path.exists(SHAPENET_PATH):
url = "https://www.shapenet.org/"
msg = (
"ShapeNet data not found, download from %s, update "
"SHAPENET_PATH at the top of the file, and rerun."
)
self.skipTest(msg % url)
if R2N2_PATH is None or not os.path.exists(R2N2_PATH):
url = "http://3d-r2n2.stanford.edu/"
msg = (
"R2N2 data not found, download from %s, update "
"R2N2_PATH at the top of the file, and rerun."
)
self.skipTest(msg % url)
if SPLITS_PATH is None or not os.path.exists(SPLITS_PATH):
msg = """Splits file not found, update SPLITS_PATH at the top
of the file, and rerun."""
self.skipTest(msg)
def test_load_R2N2(self):
"""
        Test loading the train split of R2N2. Check that the loaded dataset returns items
of the correct shapes and types.
"""
# Load dataset in the train split.
split = "train"
r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
        # Check that the total number of objects in the dataset is correct.
with open(SPLITS_PATH) as splits:
split_dict = json.load(splits)[split]
model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
self.assertEqual(len(r2n2_dataset), sum(model_nums))
# Randomly retrieve an object from the dataset.
rand_obj = r2n2_dataset[torch.randint(len(r2n2_dataset), (1,))]
# Check that data type and shape of the item returned by __getitem__ are correct.
verts, faces = rand_obj["verts"], rand_obj["faces"]
self.assertTrue(verts.dtype == torch.float32)
self.assertTrue(faces.dtype == torch.int64)
self.assertEqual(verts.ndim, 2)
self.assertEqual(verts.shape[-1], 3)
self.assertEqual(faces.ndim, 2)
self.assertEqual(faces.shape[-1], 3)
def test_collate_models(self):
"""
Test collate_batched_meshes returns items of the correct shapes and types.
Check that when collate_batched_meshes is passed to Dataloader, batches of
the correct shapes and types are returned.
"""
# Load dataset in the train split.
split = "train"
r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
# Randomly retrieve several objects from the dataset and collate them.
collated_meshes = collate_batched_meshes(
[r2n2_dataset[idx] for idx in torch.randint(len(r2n2_dataset), (6,))]
)
# Check the collated verts and faces have the correct shapes.
verts, faces = collated_meshes["verts"], collated_meshes["faces"]
self.assertEqual(len(verts), 6)
self.assertEqual(len(faces), 6)
self.assertEqual(verts[0].shape[-1], 3)
self.assertEqual(faces[0].shape[-1], 3)
# Check the collated mesh has the correct shape.
mesh = collated_meshes["mesh"]
self.assertEqual(mesh.verts_padded().shape[0], 6)
self.assertEqual(mesh.verts_padded().shape[-1], 3)
self.assertEqual(mesh.faces_padded().shape[0], 6)
self.assertEqual(mesh.faces_padded().shape[-1], 3)
# Pass the custom collate_fn function to DataLoader and check elements
# in batch have the correct shape.
batch_size = 12
r2n2_loader = DataLoader(
r2n2_dataset, batch_size=batch_size, collate_fn=collate_batched_meshes
)
it = iter(r2n2_loader)
object_batch = next(it)
self.assertEqual(len(object_batch["synset_id"]), batch_size)
self.assertEqual(len(object_batch["model_id"]), batch_size)
self.assertEqual(len(object_batch["label"]), batch_size)
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)