Test rendering models for R2N2

Summary: Adding a render function for R2N2.

Reviewed By: nikhilaravi

Differential Revision: D22230228

fbshipit-source-id: a9f588ddcba15bb5d8be1401f68d730e810b4251
This commit is contained in:
Luya Gao 2020-07-14 14:52:21 -07:00 committed by Facebook GitHub Bot
parent 49b4ce1acc
commit 5636eb6152
9 changed files with 134 additions and 9 deletions

View File

@ -64,6 +64,7 @@ class R2N2(ShapeNetBase):
continue
synset_set.add(synset)
self.synset_starts[synset] = len(self.synset_ids)
models = split_dict[synset].keys()
for model in models:
# Examine if the given model is present in the ShapeNetCore path.
@ -78,6 +79,7 @@ class R2N2(ShapeNetBase):
continue
self.synset_ids.append(synset)
self.model_ids.append(model)
self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]
# Examine if all the synsets in the standard R2N2 mapping are present.
# Update self.synset_inv so that it only includes the loaded categories.

View File

@ -34,7 +34,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
self.synset_starts = {}
self.synset_lens = {}
self.shapenet_dir = ""
self.model_dir = ""
self.model_dir = "model.obj"
def __len__(self):
"""

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 KiB

View File

@ -5,10 +5,19 @@ Sanity checks for loading R2N2.
import json
import os
import unittest
from pathlib import Path
import numpy as np
import torch
from common_testing import TestCaseMixin
from common_testing import TestCaseMixin, load_rgb_image
from PIL import Image
from pytorch3d.datasets import R2N2, collate_batched_meshes
from pytorch3d.renderer import (
OpenGLPerspectiveCameras,
PointLights,
RasterizationSettings,
look_at_view_transform,
)
from torch.utils.data import DataLoader
@ -17,6 +26,9 @@ R2N2_PATH = None
SHAPENET_PATH = None
SPLITS_PATH = None
DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"
class TestR2N2(TestCaseMixin, unittest.TestCase):
def setUp(self):
@ -44,16 +56,14 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
def test_load_R2N2(self):
"""
Test loading the train split of R2N2. Check the loaded dataset return items
of the correct shapes and types.
Test the loaded train split of R2N2 return items of the correct shapes and types.
"""
# Load dataset in the train split.
split = "train"
r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
# Check total number of objects in the dataset is correct.
with open(SPLITS_PATH) as splits:
split_dict = json.load(splits)[split]
split_dict = json.load(splits)["train"]
model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
self.assertEqual(len(r2n2_dataset), sum(model_nums))
@ -75,8 +85,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
the correct shapes and types are returned.
"""
# Load dataset in the train split.
split = "train"
r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
# Randomly retrieve several objects from the dataset and collate them.
collated_meshes = collate_batched_meshes(
@ -109,3 +118,117 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
self.assertEqual(len(object_batch["label"]), batch_size)
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
def test_catch_render_arg_errors(self):
"""
Test rendering R2N2 with an invalid model_id, category or index, and
catch corresponding errors.
"""
# Load dataset in the train split.
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
# Try loading with an invalid model_id and catch error.
with self.assertRaises(ValueError) as err:
r2n2_dataset.render(model_ids=["lamp0"])
self.assertTrue("not found in the loaded dataset" in str(err.exception))
# Try loading with an index out of bounds and catch error.
with self.assertRaises(IndexError) as err:
r2n2_dataset.render(idxs=[1000000])
self.assertTrue("are out of bounds" in str(err.exception))
def test_render_r2n2(self):
"""
Test rendering objects from R2N2 selected both by indices and model_ids.
"""
# Set up device and seed for random selections.
device = torch.device("cuda:0")
torch.manual_seed(39)
# Load dataset in the train split.
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
# Render first three models in the dataset.
R, T = look_at_view_transform(1.0, 1.0, 90)
cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
raster_settings = RasterizationSettings(image_size=512)
lights = PointLights(
location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
# TODO: debug the source of the discrepancy in two images when rendering on GPU.
diffuse_color=((0, 0, 0),),
specular_color=((0, 0, 0),),
device=device,
)
r2n2_by_idxs = r2n2_dataset.render(
idxs=list(range(3)),
device=device,
cameras=cameras,
raster_settings=raster_settings,
lights=lights,
)
# Check that there are three images in the batch.
self.assertEqual(r2n2_by_idxs.shape[0], 3)
# Compare the rendered models to the reference images.
for idx in range(3):
r2n2_by_idxs_rgb = r2n2_by_idxs[idx, ..., :3].squeeze().cpu()
if DEBUG:
Image.fromarray((r2n2_by_idxs_rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / ("DEBUG_r2n2_render_by_idxs_%s.png" % idx)
)
image_ref = load_rgb_image(
"test_r2n2_render_by_idxs_and_ids_%s.png" % idx, DATA_DIR
)
self.assertClose(r2n2_by_idxs_rgb, image_ref, atol=0.05)
# Render the same models but by model_ids this time.
r2n2_by_model_ids = r2n2_dataset.render(
model_ids=[
"1a4a8592046253ab5ff61a3a2a0e2484",
"1a04dcce7027357ab540cc4083acfa57",
"1a9d0480b74d782698f5bccb3529a48d",
],
device=device,
cameras=cameras,
raster_settings=raster_settings,
lights=lights,
)
# Compare the rendered models to the reference images.
for idx in range(3):
r2n2_by_model_ids_rgb = r2n2_by_model_ids[idx, ..., :3].squeeze().cpu()
if DEBUG:
Image.fromarray(
(r2n2_by_model_ids_rgb.numpy() * 255).astype(np.uint8)
).save(DATA_DIR / ("DEBUG_r2n2_render_by_model_ids_%s.png" % idx))
image_ref = load_rgb_image(
"test_r2n2_render_by_idxs_and_ids_%s.png" % idx, DATA_DIR
)
self.assertClose(r2n2_by_model_ids_rgb, image_ref, atol=0.05)
###############################
# Test rendering by categories
###############################
# Render a mixture of categories.
categories = ["chair", "lamp"]
mixed_objs = r2n2_dataset.render(
categories=categories,
sample_nums=[1, 2],
device=device,
cameras=cameras,
raster_settings=raster_settings,
lights=lights,
)
# Compare the rendered models to the reference images.
for idx in range(3):
mixed_rgb = mixed_objs[idx, ..., :3].squeeze().cpu()
if DEBUG:
Image.fromarray((mixed_rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / ("DEBUG_r2n2_render_by_categories_%s.png" % idx)
)
image_ref = load_rgb_image(
"test_r2n2_render_by_categories_%s.png" % idx, DATA_DIR
)
self.assertClose(mixed_rgb, image_ref, atol=0.05)