Test rendering models for R2N2

Summary: Adding a render function for R2N2.

Reviewed By: nikhilaravi

Differential Revision: D22230228

fbshipit-source-id: a9f588ddcba15bb5d8be1401f68d730e810b4251
Author: Luya Gao (committed by Facebook GitHub Bot)
Date: 2020-07-14 14:52:21 -07:00
Commit: 5636eb6152 (parent 49b4ce1acc)
9 changed files with 134 additions and 9 deletions
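For orientation, here is a hedged sketch of how the new render function is meant to be called; the dataset paths are placeholders for local copies of the data, and the keyword arguments mirror the tests added below:

# Hedged usage sketch of the new R2N2.render API, mirroring the tests below.
# The three dataset paths are placeholders, not real locations.
import torch
from pytorch3d.datasets import R2N2
from pytorch3d.renderer import (
    OpenGLPerspectiveCameras,
    RasterizationSettings,
    look_at_view_transform,
)

device = torch.device("cuda:0")
r2n2_dataset = R2N2(
    "train", "/path/to/ShapeNetCore.v1", "/path/to/R2N2", "/path/to/splits.json"
)

# Models can be selected by dataset indices, by model_ids, or by categories.
R, T = look_at_view_transform(1.0, 1.0, 90)
images = r2n2_dataset.render(
    idxs=[0, 1, 2],
    device=device,
    cameras=OpenGLPerspectiveCameras(R=R, T=T, device=device),
    raster_settings=RasterizationSettings(image_size=512),
)
print(images.shape)  # expected: one RGBA image per model, e.g. (3, 512, 512, 4)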

pytorch3d/datasets/r2n2/r2n2.py

@@ -64,6 +64,7 @@ class R2N2(ShapeNetBase):
                 continue
             synset_set.add(synset)
+            self.synset_starts[synset] = len(self.synset_ids)
             models = split_dict[synset].keys()
             for model in models:
                 # Examine if the given model is present in the ShapeNetCore path.
@@ -78,6 +79,7 @@ class R2N2(ShapeNetBase):
                     continue
                 self.synset_ids.append(synset)
                 self.model_ids.append(model)
+            self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]
         # Examine if all the synsets in the standard R2N2 mapping are present.
         # Update self.synset_inv so that it only includes the loaded categories.
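The two added lines record, per synset, the start offset and length of its models inside self.synset_ids / self.model_ids, which is the bookkeeping that category-based selection in render needs. A standalone toy illustration of the invariant (the split dict here is hypothetical):

# Toy illustration of the synset_starts / synset_lens bookkeeping above:
# after loading, each synset's models occupy a contiguous slice of model_ids,
# so a category can be sampled without scanning the whole list.
synset_starts, synset_lens = {}, {}
synset_ids, model_ids = [], []

splits = {"chair": ["c1", "c2"], "lamp": ["l1"]}  # hypothetical split dict
for synset, models in splits.items():
    synset_starts[synset] = len(synset_ids)
    for model in models:
        synset_ids.append(synset)
        model_ids.append(model)
    synset_lens[synset] = len(synset_ids) - synset_starts[synset]

start, length = synset_starts["lamp"], synset_lens["lamp"]
assert model_ids[start : start + length] == ["l1"]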

pytorch3d/datasets/shapenet_base.py

@@ -34,7 +34,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
         self.synset_starts = {}
         self.synset_lens = {}
         self.shapenet_dir = ""
-        self.model_dir = ""
+        self.model_dir = "model.obj"

     def __len__(self):
         """

[6 binary files added under tests/data/: the reference renders test_r2n2_render_by_idxs_and_ids_{0,1,2}.png and test_r2n2_render_by_categories_{0,1,2}.png, 2.0–3.7 KiB each; images not shown.]

tests/test_r2n2.py

@@ -5,10 +5,19 @@ Sanity checks for loading R2N2.
 import json
 import os
 import unittest
+from pathlib import Path

+import numpy as np
 import torch
-from common_testing import TestCaseMixin
+from common_testing import TestCaseMixin, load_rgb_image
+from PIL import Image
 from pytorch3d.datasets import R2N2, collate_batched_meshes
+from pytorch3d.renderer import (
+    OpenGLPerspectiveCameras,
+    PointLights,
+    RasterizationSettings,
+    look_at_view_transform,
+)
 from torch.utils.data import DataLoader

@@ -17,6 +26,9 @@ R2N2_PATH = None
 SHAPENET_PATH = None
 SPLITS_PATH = None
+DEBUG = False
+
+DATA_DIR = Path(__file__).resolve().parent / "data"


 class TestR2N2(TestCaseMixin, unittest.TestCase):
     def setUp(self):
@@ -44,16 +56,14 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
     def test_load_R2N2(self):
         """
-        Test loading the train split of R2N2. Check the loaded dataset return items
-        of the correct shapes and types.
+        Test the loaded train split of R2N2 return items of the correct shapes and types.
         """
         # Load dataset in the train split.
-        split = "train"
-        r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)

         # Check total number of objects in the dataset is correct.
         with open(SPLITS_PATH) as splits:
-            split_dict = json.load(splits)[split]
+            split_dict = json.load(splits)["train"]
         model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
         self.assertEqual(len(r2n2_dataset), sum(model_nums))
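The count check above relies on the layout of the splits file: split name → synset id → model id. A minimal hypothetical excerpt (ids and per-model values are illustrative only; the test uses only the model keys):

# Hypothetical excerpt of the splits json layout the count check iterates:
# split name -> synset id -> model id -> per-model data (unused here).
splits_example = {
    "train": {
        "03001627": {"model_a": {}, "model_b": {}},  # e.g. chair synset
        "03636649": {"model_c": {}},  # e.g. lamp synset
    }
}
split_dict = splits_example["train"]
model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
assert sum(model_nums) == 3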
@ -75,8 +85,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
the correct shapes and types are returned. the correct shapes and types are returned.
""" """
# Load dataset in the train split. # Load dataset in the train split.
split = "train" r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
# Randomly retrieve several objects from the dataset and collate them. # Randomly retrieve several objects from the dataset and collate them.
collated_meshes = collate_batched_meshes( collated_meshes = collate_batched_meshes(
@ -109,3 +118,117 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
self.assertEqual(len(object_batch["label"]), batch_size) self.assertEqual(len(object_batch["label"]), batch_size)
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size) self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size) self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
+
+    def test_catch_render_arg_errors(self):
+        """
+        Test rendering R2N2 with an invalid model_id, category or index, and
+        catch corresponding errors.
+        """
+        # Load dataset in the train split.
+        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+
+        # Try loading with an invalid model_id and catch error.
+        with self.assertRaises(ValueError) as err:
+            r2n2_dataset.render(model_ids=["lamp0"])
+        self.assertTrue("not found in the loaded dataset" in str(err.exception))
+
+        # Try loading with an index out of bounds and catch error.
+        with self.assertRaises(IndexError) as err:
+            r2n2_dataset.render(idxs=[1000000])
+        self.assertTrue("are out of bounds" in str(err.exception))
+
+    def test_render_r2n2(self):
+        """
+        Test rendering objects from R2N2 selected both by indices and model_ids.
+        """
+        # Set up device and seed for random selections.
+        device = torch.device("cuda:0")
+        torch.manual_seed(39)
+
+        # Load dataset in the train split.
+        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+
+        # Render first three models in the dataset.
+        R, T = look_at_view_transform(1.0, 1.0, 90)
+        cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
+        raster_settings = RasterizationSettings(image_size=512)
+        lights = PointLights(
+            location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
+            # TODO: debug the source of the discrepancy in two images when rendering on GPU.
+            diffuse_color=((0, 0, 0),),
+            specular_color=((0, 0, 0),),
+            device=device,
+        )
+
+        r2n2_by_idxs = r2n2_dataset.render(
+            idxs=list(range(3)),
+            device=device,
+            cameras=cameras,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+        # Check that there are three images in the batch.
+        self.assertEqual(r2n2_by_idxs.shape[0], 3)
+
+        # Compare the rendered models to the reference images.
+        for idx in range(3):
+            r2n2_by_idxs_rgb = r2n2_by_idxs[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray((r2n2_by_idxs_rgb.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / ("DEBUG_r2n2_render_by_idxs_%s.png" % idx)
+                )
+            image_ref = load_rgb_image(
+                "test_r2n2_render_by_idxs_and_ids_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(r2n2_by_idxs_rgb, image_ref, atol=0.05)
+
+        # Render the same models but by model_ids this time.
+        r2n2_by_model_ids = r2n2_dataset.render(
+            model_ids=[
+                "1a4a8592046253ab5ff61a3a2a0e2484",
+                "1a04dcce7027357ab540cc4083acfa57",
+                "1a9d0480b74d782698f5bccb3529a48d",
+            ],
+            device=device,
+            cameras=cameras,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+
+        # Compare the rendered models to the reference images.
+        for idx in range(3):
+            r2n2_by_model_ids_rgb = r2n2_by_model_ids[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray(
+                    (r2n2_by_model_ids_rgb.numpy() * 255).astype(np.uint8)
+                ).save(DATA_DIR / ("DEBUG_r2n2_render_by_model_ids_%s.png" % idx))
+            image_ref = load_rgb_image(
+                "test_r2n2_render_by_idxs_and_ids_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(r2n2_by_model_ids_rgb, image_ref, atol=0.05)
+
+        ###############################
+        # Test rendering by categories
+        ###############################
+
+        # Render a mixture of categories.
+        categories = ["chair", "lamp"]
+        mixed_objs = r2n2_dataset.render(
+            categories=categories,
+            sample_nums=[1, 2],
+            device=device,
+            cameras=cameras,
+            raster_settings=raster_settings,
+            lights=lights,
+        )
+        # Compare the rendered models to the reference images.
+        for idx in range(3):
+            mixed_rgb = mixed_objs[idx, ..., :3].squeeze().cpu()
+            if DEBUG:
+                Image.fromarray((mixed_rgb.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / ("DEBUG_r2n2_render_by_categories_%s.png" % idx)
+                )
+            image_ref = load_rgb_image(
+                "test_r2n2_render_by_categories_%s.png" % idx, DATA_DIR
+            )
+            self.assertClose(mixed_rgb, image_ref, atol=0.05)
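For reference, a hedged sketch of the standard PyTorch3D mesh-rendering pipeline that a call like r2n2_dataset.render(...) plausibly wraps; this is not the function's actual body:

# Hedged sketch of the standard PyTorch3D pipeline a render call like the
# ones above plausibly wraps internally (not the actual implementation).
from pytorch3d.renderer import (
    HardPhongShader,
    MeshRasterizer,
    MeshRenderer,
    OpenGLPerspectiveCameras,
    PointLights,
    RasterizationSettings,
    look_at_view_transform,
)

def render_meshes(meshes, device):
    # Camera at distance 1.0, elevation 1.0, azimuth 90, as in the test.
    R, T = look_at_view_transform(1.0, 1.0, 90)
    cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=RasterizationSettings(image_size=512),
        ),
        shader=HardPhongShader(
            device=device,
            cameras=cameras,
            lights=PointLights(device=device),
        ),
    )
    return renderer(meshes)  # (N, 512, 512, 4) RGBA images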