Return R2N2 renderings

Summary: R2N2 returns R2N2's own renderings of ShapeNetCore models.

Reviewed By: nikhilaravi

Differential Revision: D22266988

fbshipit-source-id: 36e67bd06c6459773e6e5f654259166b579be36a
This commit is contained in:
Luya Gao 2020-07-14 14:52:21 -07:00 committed by Facebook GitHub Bot
parent 5636eb6152
commit dc08c30583
3 changed files with 91 additions and 8 deletions

View File

@ -4,8 +4,11 @@ import json
import warnings
from os import path
from pathlib import Path
from typing import Dict
from typing import Dict, List, Optional
import numpy as np
import torch
from PIL import Image
from pytorch3d.datasets.shapenet_base import ShapeNetBase
from pytorch3d.io import load_obj
@ -21,18 +24,29 @@ class R2N2(ShapeNetBase):
voxelized models.
"""
def __init__(self, split, shapenet_dir, r2n2_dir, splits_file):
def __init__(
self,
split: str,
shapenet_dir,
r2n2_dir,
splits_file,
return_all_views: bool = True,
):
"""
Store each object's synset id and models id the given directories.
Args:
split (str): One of (train, val, test).
shapenet_dir (path): Path to ShapeNet core v1.
r2n2_dir (path): Path to the R2N2 dataset.
splits_file (path): File containing the train/val/test splits.
return_all_views (bool): Indicator of whether or not to return all 24 views. If set
to False, one of the 24 views would be randomly selected and returned.
"""
super().__init__()
self.shapenet_dir = shapenet_dir
self.r2n2_dir = r2n2_dir
self.return_all_views = return_all_views
# Examine if split is valid.
if split not in ["train", "val", "test"]:
raise ValueError("split has to be one of (train, val, test).")
@ -48,6 +62,16 @@ class R2N2(ShapeNetBase):
with open(splits_file) as splits:
split_dict = json.load(splits)[split]
self.return_images = True
# Check if the folder containing R2N2 renderings is included in r2n2_dir.
if not path.isdir(path.join(r2n2_dir, "ShapeNetRendering")):
self.return_images = False
msg = (
"ShapeNetRendering not found in %s. R2N2 renderings will "
"be skipped when returning models."
) % (r2n2_dir)
warnings.warn(msg)
synset_set = set()
for synset in split_dict.keys():
# Examine if the given synset is present in the ShapeNetCore dataset
@ -95,12 +119,15 @@ class R2N2(ShapeNetBase):
) % (shapenet_dir, ", ".join(synset_not_present))
warnings.warn(msg)
def __getitem__(self, idx: int) -> Dict:
def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
"""
Read a model by the given index.
Args:
idx: The idx of the model to be retrieved in the dataset.
model_idx: The idx of the model to be retrieved in the dataset.
view_idx: List of indices of the view to be returned. Each index needs to be
between 0 and 23, inclusive. If an invalid index is supplied, view_idx will be
ignored and views will be sampled according to self.return_all_views.
Returns:
dictionary with following keys:
@ -109,12 +136,51 @@ class R2N2(ShapeNetBase):
- synset_id (str): synset id.
- model_id (str): model id.
- label (str): synset label.
- images: FloatTensor of shape (V, H, W, C), where V is number of views
returned. Returns a batch of the renderings of the models from the R2N2 dataset.
"""
model = self._get_item_ids(idx)
if type(model_idx) is tuple:
model_idx, view_idxs = model_idx
model = self._get_item_ids(model_idx)
model_path = path.join(
self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"
)
model["verts"], faces, _ = load_obj(model_path)
model["faces"] = faces.verts_idx
model["label"] = self.synset_dict[model["synset_id"]]
model["images"] = None
# Retrieve R2N2's renderings if required.
if self.return_images:
ranges = (
range(24) if self.return_all_views else torch.randint(24, (1,)).tolist()
)
if view_idxs is not None and any(idx < 0 or idx > 23 for idx in view_idxs):
msg = (
"One of the indicies in view_idxs is out of range. "
"Index needs to be between 0 and 23, inclusive. "
"Now sampling according to self.return_all_views."
)
warnings.warn(msg)
elif view_idxs is not None:
ranges = view_idxs
rendering_path = path.join(
self.r2n2_dir,
"ShapeNetRendering",
model["synset_id"],
model["model_id"],
"rendering",
)
images = []
for i in ranges:
# Read image.
image_path = path.join(rendering_path, "%02d.png" % i)
raw_img = Image.open(image_path)
image = torch.from_numpy(np.array(raw_img) / 255.0)[..., :3]
images.append(image.to(dtype=torch.float32))
model["images"] = torch.stack(images)
return model

View File

@ -2,6 +2,7 @@
from typing import Dict, List
import torch
from pytorch3d.structures import Meshes
@ -32,4 +33,10 @@ def collate_batched_meshes(batch: List[Dict]):
verts=collated_dict["verts"], faces=collated_dict["faces"]
)
# If collate_batched_meshes receives R2N2 items, stack the batches of
# views of each model into a new batch of shape (N, V, H, W, 3) where
# V is the number of views.
if "images" in collated_dict:
collated_dict["images"] = torch.stack(collated_dict["images"])
return collated_dict

View File

@ -56,7 +56,8 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
def test_load_R2N2(self):
"""
Test the loaded train split of R2N2 return items of the correct shapes and types.
Test the loaded train split of R2N2 return items of the correct shapes and types. Also
check the first image returned is correct.
"""
# Load dataset in the train split.
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
@ -68,8 +69,9 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
self.assertEqual(len(r2n2_dataset), sum(model_nums))
# Randomly retrieve an object from the dataset.
rand_obj = r2n2_dataset[torch.randint(len(r2n2_dataset), (1,))]
# Check that data type and shape of the item returned by __getitem__ are correct.
rand_idx = torch.randint(len(r2n2_dataset), (1,))
rand_obj = r2n2_dataset[rand_idx]
# Check that verts and faces returned by __getitem__ have the correct shapes and types.
verts, faces = rand_obj["verts"], rand_obj["faces"]
self.assertTrue(verts.dtype == torch.float32)
self.assertTrue(faces.dtype == torch.int64)
@ -78,6 +80,13 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
self.assertEqual(faces.ndim, 2)
self.assertEqual(faces.shape[-1], 3)
# Check that image batch returned by __getitem__ has the correct shape.
self.assertEqual(rand_obj["images"].shape[0], 24)
self.assertEqual(rand_obj["images"].shape[1], 137)
self.assertEqual(rand_obj["images"].shape[2], 137)
self.assertEqual(rand_obj["images"].shape[-1], 3)
self.assertEqual(r2n2_dataset[rand_idx, [21]]["images"].shape[0], 1)
def test_collate_models(self):
"""
Test collate_batched_meshes returns items of the correct shapes and types.
@ -118,6 +127,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
self.assertEqual(len(object_batch["label"]), batch_size)
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
self.assertEqual(object_batch["images"].shape[0], batch_size)
def test_catch_render_arg_errors(self):
"""