mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Test rendering models for R2N2
Summary: Adding a render function for R2N2. Reviewed By: nikhilaravi Differential Revision: D22230228 fbshipit-source-id: a9f588ddcba15bb5d8be1401f68d730e810b4251
This commit is contained in:
parent
49b4ce1acc
commit
5636eb6152
@ -64,6 +64,7 @@ class R2N2(ShapeNetBase):
|
||||
continue
|
||||
|
||||
synset_set.add(synset)
|
||||
self.synset_starts[synset] = len(self.synset_ids)
|
||||
models = split_dict[synset].keys()
|
||||
for model in models:
|
||||
# Examine if the given model is present in the ShapeNetCore path.
|
||||
@ -78,6 +79,7 @@ class R2N2(ShapeNetBase):
|
||||
continue
|
||||
self.synset_ids.append(synset)
|
||||
self.model_ids.append(model)
|
||||
self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]
|
||||
|
||||
# Examine if all the synsets in the standard R2N2 mapping are present.
|
||||
# Update self.synset_inv so that it only includes the loaded categories.
|
||||
|
@ -34,7 +34,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
|
||||
self.synset_starts = {}
|
||||
self.synset_lens = {}
|
||||
self.shapenet_dir = ""
|
||||
self.model_dir = ""
|
||||
self.model_dir = "model.obj"
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
|
BIN
tests/data/test_r2n2_render_by_categories_0.png
Normal file
BIN
tests/data/test_r2n2_render_by_categories_0.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.2 KiB |
BIN
tests/data/test_r2n2_render_by_categories_1.png
Normal file
BIN
tests/data/test_r2n2_render_by_categories_1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.7 KiB |
BIN
tests/data/test_r2n2_render_by_categories_2.png
Normal file
BIN
tests/data/test_r2n2_render_by_categories_2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.3 KiB |
BIN
tests/data/test_r2n2_render_by_idxs_and_ids_0.png
Normal file
BIN
tests/data/test_r2n2_render_by_idxs_and_ids_0.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.1 KiB |
BIN
tests/data/test_r2n2_render_by_idxs_and_ids_1.png
Normal file
BIN
tests/data/test_r2n2_render_by_idxs_and_ids_1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.0 KiB |
BIN
tests/data/test_r2n2_render_by_idxs_and_ids_2.png
Normal file
BIN
tests/data/test_r2n2_render_by_idxs_and_ids_2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.0 KiB |
@ -5,10 +5,19 @@ Sanity checks for loading R2N2.
|
||||
import json
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from common_testing import TestCaseMixin, load_rgb_image
|
||||
from PIL import Image
|
||||
from pytorch3d.datasets import R2N2, collate_batched_meshes
|
||||
from pytorch3d.renderer import (
|
||||
OpenGLPerspectiveCameras,
|
||||
PointLights,
|
||||
RasterizationSettings,
|
||||
look_at_view_transform,
|
||||
)
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
@ -17,6 +26,9 @@ R2N2_PATH = None
|
||||
SHAPENET_PATH = None
|
||||
SPLITS_PATH = None
|
||||
|
||||
DEBUG = False
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
|
||||
|
||||
class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self):
|
||||
@ -44,16 +56,14 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
def test_load_R2N2(self):
|
||||
"""
|
||||
Test loading the train split of R2N2. Check the loaded dataset return items
|
||||
of the correct shapes and types.
|
||||
Test the loaded train split of R2N2 return items of the correct shapes and types.
|
||||
"""
|
||||
# Load dataset in the train split.
|
||||
split = "train"
|
||||
r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||
|
||||
# Check total number of objects in the dataset is correct.
|
||||
with open(SPLITS_PATH) as splits:
|
||||
split_dict = json.load(splits)[split]
|
||||
split_dict = json.load(splits)["train"]
|
||||
model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
|
||||
self.assertEqual(len(r2n2_dataset), sum(model_nums))
|
||||
|
||||
@ -75,8 +85,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
the correct shapes and types are returned.
|
||||
"""
|
||||
# Load dataset in the train split.
|
||||
split = "train"
|
||||
r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||
|
||||
# Randomly retrieve several objects from the dataset and collate them.
|
||||
collated_meshes = collate_batched_meshes(
|
||||
@ -109,3 +118,117 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(len(object_batch["label"]), batch_size)
|
||||
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
|
||||
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
|
||||
|
||||
def test_catch_render_arg_errors(self):
|
||||
"""
|
||||
Test rendering R2N2 with an invalid model_id, category or index, and
|
||||
catch corresponding errors.
|
||||
"""
|
||||
# Load dataset in the train split.
|
||||
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||
|
||||
# Try loading with an invalid model_id and catch error.
|
||||
with self.assertRaises(ValueError) as err:
|
||||
r2n2_dataset.render(model_ids=["lamp0"])
|
||||
self.assertTrue("not found in the loaded dataset" in str(err.exception))
|
||||
|
||||
# Try loading with an index out of bounds and catch error.
|
||||
with self.assertRaises(IndexError) as err:
|
||||
r2n2_dataset.render(idxs=[1000000])
|
||||
self.assertTrue("are out of bounds" in str(err.exception))
|
||||
|
||||
def test_render_r2n2(self):
|
||||
"""
|
||||
Test rendering objects from R2N2 selected both by indices and model_ids.
|
||||
"""
|
||||
# Set up device and seed for random selections.
|
||||
device = torch.device("cuda:0")
|
||||
torch.manual_seed(39)
|
||||
|
||||
# Load dataset in the train split.
|
||||
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||
|
||||
# Render first three models in the dataset.
|
||||
R, T = look_at_view_transform(1.0, 1.0, 90)
|
||||
cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
|
||||
raster_settings = RasterizationSettings(image_size=512)
|
||||
lights = PointLights(
|
||||
location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
|
||||
# TODO: debug the source of the discrepancy in two images when rendering on GPU.
|
||||
diffuse_color=((0, 0, 0),),
|
||||
specular_color=((0, 0, 0),),
|
||||
device=device,
|
||||
)
|
||||
|
||||
r2n2_by_idxs = r2n2_dataset.render(
|
||||
idxs=list(range(3)),
|
||||
device=device,
|
||||
cameras=cameras,
|
||||
raster_settings=raster_settings,
|
||||
lights=lights,
|
||||
)
|
||||
# Check that there are three images in the batch.
|
||||
self.assertEqual(r2n2_by_idxs.shape[0], 3)
|
||||
|
||||
# Compare the rendered models to the reference images.
|
||||
for idx in range(3):
|
||||
r2n2_by_idxs_rgb = r2n2_by_idxs[idx, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
Image.fromarray((r2n2_by_idxs_rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / ("DEBUG_r2n2_render_by_idxs_%s.png" % idx)
|
||||
)
|
||||
image_ref = load_rgb_image(
|
||||
"test_r2n2_render_by_idxs_and_ids_%s.png" % idx, DATA_DIR
|
||||
)
|
||||
self.assertClose(r2n2_by_idxs_rgb, image_ref, atol=0.05)
|
||||
|
||||
# Render the same models but by model_ids this time.
|
||||
r2n2_by_model_ids = r2n2_dataset.render(
|
||||
model_ids=[
|
||||
"1a4a8592046253ab5ff61a3a2a0e2484",
|
||||
"1a04dcce7027357ab540cc4083acfa57",
|
||||
"1a9d0480b74d782698f5bccb3529a48d",
|
||||
],
|
||||
device=device,
|
||||
cameras=cameras,
|
||||
raster_settings=raster_settings,
|
||||
lights=lights,
|
||||
)
|
||||
|
||||
# Compare the rendered models to the reference images.
|
||||
for idx in range(3):
|
||||
r2n2_by_model_ids_rgb = r2n2_by_model_ids[idx, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
Image.fromarray(
|
||||
(r2n2_by_model_ids_rgb.numpy() * 255).astype(np.uint8)
|
||||
).save(DATA_DIR / ("DEBUG_r2n2_render_by_model_ids_%s.png" % idx))
|
||||
image_ref = load_rgb_image(
|
||||
"test_r2n2_render_by_idxs_and_ids_%s.png" % idx, DATA_DIR
|
||||
)
|
||||
self.assertClose(r2n2_by_model_ids_rgb, image_ref, atol=0.05)
|
||||
|
||||
###############################
|
||||
# Test rendering by categories
|
||||
###############################
|
||||
|
||||
# Render a mixture of categories.
|
||||
categories = ["chair", "lamp"]
|
||||
mixed_objs = r2n2_dataset.render(
|
||||
categories=categories,
|
||||
sample_nums=[1, 2],
|
||||
device=device,
|
||||
cameras=cameras,
|
||||
raster_settings=raster_settings,
|
||||
lights=lights,
|
||||
)
|
||||
# Compare the rendered models to the reference images.
|
||||
for idx in range(3):
|
||||
mixed_rgb = mixed_objs[idx, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
Image.fromarray((mixed_rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / ("DEBUG_r2n2_render_by_categories_%s.png" % idx)
|
||||
)
|
||||
image_ref = load_rgb_image(
|
||||
"test_r2n2_render_by_categories_%s.png" % idx, DATA_DIR
|
||||
)
|
||||
self.assertClose(mixed_rgb, image_ref, atol=0.05)
|
||||
|
Loading…
x
Reference in New Issue
Block a user