mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Test rendering models for R2N2
Summary: Adding a render function for R2N2. Reviewed By: nikhilaravi Differential Revision: D22230228 fbshipit-source-id: a9f588ddcba15bb5d8be1401f68d730e810b4251
This commit is contained in:
parent
49b4ce1acc
commit
5636eb6152
@ -64,6 +64,7 @@ class R2N2(ShapeNetBase):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
synset_set.add(synset)
|
synset_set.add(synset)
|
||||||
|
self.synset_starts[synset] = len(self.synset_ids)
|
||||||
models = split_dict[synset].keys()
|
models = split_dict[synset].keys()
|
||||||
for model in models:
|
for model in models:
|
||||||
# Examine if the given model is present in the ShapeNetCore path.
|
# Examine if the given model is present in the ShapeNetCore path.
|
||||||
@ -78,6 +79,7 @@ class R2N2(ShapeNetBase):
|
|||||||
continue
|
continue
|
||||||
self.synset_ids.append(synset)
|
self.synset_ids.append(synset)
|
||||||
self.model_ids.append(model)
|
self.model_ids.append(model)
|
||||||
|
self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]
|
||||||
|
|
||||||
# Examine if all the synsets in the standard R2N2 mapping are present.
|
# Examine if all the synsets in the standard R2N2 mapping are present.
|
||||||
# Update self.synset_inv so that it only includes the loaded categories.
|
# Update self.synset_inv so that it only includes the loaded categories.
|
||||||
|
@ -34,7 +34,7 @@ class ShapeNetBase(torch.utils.data.Dataset):
|
|||||||
self.synset_starts = {}
|
self.synset_starts = {}
|
||||||
self.synset_lens = {}
|
self.synset_lens = {}
|
||||||
self.shapenet_dir = ""
|
self.shapenet_dir = ""
|
||||||
self.model_dir = ""
|
self.model_dir = "model.obj"
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""
|
"""
|
||||||
|
BIN
tests/data/test_r2n2_render_by_categories_0.png
Normal file
BIN
tests/data/test_r2n2_render_by_categories_0.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.2 KiB |
BIN
tests/data/test_r2n2_render_by_categories_1.png
Normal file
BIN
tests/data/test_r2n2_render_by_categories_1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.7 KiB |
BIN
tests/data/test_r2n2_render_by_categories_2.png
Normal file
BIN
tests/data/test_r2n2_render_by_categories_2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.3 KiB |
BIN
tests/data/test_r2n2_render_by_idxs_and_ids_0.png
Normal file
BIN
tests/data/test_r2n2_render_by_idxs_and_ids_0.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.1 KiB |
BIN
tests/data/test_r2n2_render_by_idxs_and_ids_1.png
Normal file
BIN
tests/data/test_r2n2_render_by_idxs_and_ids_1.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.0 KiB |
BIN
tests/data/test_r2n2_render_by_idxs_and_ids_2.png
Normal file
BIN
tests/data/test_r2n2_render_by_idxs_and_ids_2.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.0 KiB |
@ -5,10 +5,19 @@ Sanity checks for loading R2N2.
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin, load_rgb_image
|
||||||
|
from PIL import Image
|
||||||
from pytorch3d.datasets import R2N2, collate_batched_meshes
|
from pytorch3d.datasets import R2N2, collate_batched_meshes
|
||||||
|
from pytorch3d.renderer import (
|
||||||
|
OpenGLPerspectiveCameras,
|
||||||
|
PointLights,
|
||||||
|
RasterizationSettings,
|
||||||
|
look_at_view_transform,
|
||||||
|
)
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
|
||||||
@ -17,6 +26,9 @@ R2N2_PATH = None
|
|||||||
SHAPENET_PATH = None
|
SHAPENET_PATH = None
|
||||||
SPLITS_PATH = None
|
SPLITS_PATH = None
|
||||||
|
|
||||||
|
DEBUG = False
|
||||||
|
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||||
|
|
||||||
|
|
||||||
class TestR2N2(TestCaseMixin, unittest.TestCase):
|
class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -44,16 +56,14 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_load_R2N2(self):
|
def test_load_R2N2(self):
|
||||||
"""
|
"""
|
||||||
Test loading the train split of R2N2. Check the loaded dataset return items
|
Test the loaded train split of R2N2 return items of the correct shapes and types.
|
||||||
of the correct shapes and types.
|
|
||||||
"""
|
"""
|
||||||
# Load dataset in the train split.
|
# Load dataset in the train split.
|
||||||
split = "train"
|
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||||
r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
|
||||||
|
|
||||||
# Check total number of objects in the dataset is correct.
|
# Check total number of objects in the dataset is correct.
|
||||||
with open(SPLITS_PATH) as splits:
|
with open(SPLITS_PATH) as splits:
|
||||||
split_dict = json.load(splits)[split]
|
split_dict = json.load(splits)["train"]
|
||||||
model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
|
model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
|
||||||
self.assertEqual(len(r2n2_dataset), sum(model_nums))
|
self.assertEqual(len(r2n2_dataset), sum(model_nums))
|
||||||
|
|
||||||
@ -75,8 +85,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
|||||||
the correct shapes and types are returned.
|
the correct shapes and types are returned.
|
||||||
"""
|
"""
|
||||||
# Load dataset in the train split.
|
# Load dataset in the train split.
|
||||||
split = "train"
|
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||||
r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
|
||||||
|
|
||||||
# Randomly retrieve several objects from the dataset and collate them.
|
# Randomly retrieve several objects from the dataset and collate them.
|
||||||
collated_meshes = collate_batched_meshes(
|
collated_meshes = collate_batched_meshes(
|
||||||
@ -109,3 +118,117 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertEqual(len(object_batch["label"]), batch_size)
|
self.assertEqual(len(object_batch["label"]), batch_size)
|
||||||
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
|
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
|
||||||
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
|
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
|
||||||
|
|
||||||
|
def test_catch_render_arg_errors(self):
|
||||||
|
"""
|
||||||
|
Test rendering R2N2 with an invalid model_id, category or index, and
|
||||||
|
catch corresponding errors.
|
||||||
|
"""
|
||||||
|
# Load dataset in the train split.
|
||||||
|
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||||
|
|
||||||
|
# Try loading with an invalid model_id and catch error.
|
||||||
|
with self.assertRaises(ValueError) as err:
|
||||||
|
r2n2_dataset.render(model_ids=["lamp0"])
|
||||||
|
self.assertTrue("not found in the loaded dataset" in str(err.exception))
|
||||||
|
|
||||||
|
# Try loading with an index out of bounds and catch error.
|
||||||
|
with self.assertRaises(IndexError) as err:
|
||||||
|
r2n2_dataset.render(idxs=[1000000])
|
||||||
|
self.assertTrue("are out of bounds" in str(err.exception))
|
||||||
|
|
||||||
|
def test_render_r2n2(self):
|
||||||
|
"""
|
||||||
|
Test rendering objects from R2N2 selected both by indices and model_ids.
|
||||||
|
"""
|
||||||
|
# Set up device and seed for random selections.
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
torch.manual_seed(39)
|
||||||
|
|
||||||
|
# Load dataset in the train split.
|
||||||
|
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||||
|
|
||||||
|
# Render first three models in the dataset.
|
||||||
|
R, T = look_at_view_transform(1.0, 1.0, 90)
|
||||||
|
cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
|
||||||
|
raster_settings = RasterizationSettings(image_size=512)
|
||||||
|
lights = PointLights(
|
||||||
|
location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
|
||||||
|
# TODO: debug the source of the discrepancy in two images when rendering on GPU.
|
||||||
|
diffuse_color=((0, 0, 0),),
|
||||||
|
specular_color=((0, 0, 0),),
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
r2n2_by_idxs = r2n2_dataset.render(
|
||||||
|
idxs=list(range(3)),
|
||||||
|
device=device,
|
||||||
|
cameras=cameras,
|
||||||
|
raster_settings=raster_settings,
|
||||||
|
lights=lights,
|
||||||
|
)
|
||||||
|
# Check that there are three images in the batch.
|
||||||
|
self.assertEqual(r2n2_by_idxs.shape[0], 3)
|
||||||
|
|
||||||
|
# Compare the rendered models to the reference images.
|
||||||
|
for idx in range(3):
|
||||||
|
r2n2_by_idxs_rgb = r2n2_by_idxs[idx, ..., :3].squeeze().cpu()
|
||||||
|
if DEBUG:
|
||||||
|
Image.fromarray((r2n2_by_idxs_rgb.numpy() * 255).astype(np.uint8)).save(
|
||||||
|
DATA_DIR / ("DEBUG_r2n2_render_by_idxs_%s.png" % idx)
|
||||||
|
)
|
||||||
|
image_ref = load_rgb_image(
|
||||||
|
"test_r2n2_render_by_idxs_and_ids_%s.png" % idx, DATA_DIR
|
||||||
|
)
|
||||||
|
self.assertClose(r2n2_by_idxs_rgb, image_ref, atol=0.05)
|
||||||
|
|
||||||
|
# Render the same models but by model_ids this time.
|
||||||
|
r2n2_by_model_ids = r2n2_dataset.render(
|
||||||
|
model_ids=[
|
||||||
|
"1a4a8592046253ab5ff61a3a2a0e2484",
|
||||||
|
"1a04dcce7027357ab540cc4083acfa57",
|
||||||
|
"1a9d0480b74d782698f5bccb3529a48d",
|
||||||
|
],
|
||||||
|
device=device,
|
||||||
|
cameras=cameras,
|
||||||
|
raster_settings=raster_settings,
|
||||||
|
lights=lights,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compare the rendered models to the reference images.
|
||||||
|
for idx in range(3):
|
||||||
|
r2n2_by_model_ids_rgb = r2n2_by_model_ids[idx, ..., :3].squeeze().cpu()
|
||||||
|
if DEBUG:
|
||||||
|
Image.fromarray(
|
||||||
|
(r2n2_by_model_ids_rgb.numpy() * 255).astype(np.uint8)
|
||||||
|
).save(DATA_DIR / ("DEBUG_r2n2_render_by_model_ids_%s.png" % idx))
|
||||||
|
image_ref = load_rgb_image(
|
||||||
|
"test_r2n2_render_by_idxs_and_ids_%s.png" % idx, DATA_DIR
|
||||||
|
)
|
||||||
|
self.assertClose(r2n2_by_model_ids_rgb, image_ref, atol=0.05)
|
||||||
|
|
||||||
|
###############################
|
||||||
|
# Test rendering by categories
|
||||||
|
###############################
|
||||||
|
|
||||||
|
# Render a mixture of categories.
|
||||||
|
categories = ["chair", "lamp"]
|
||||||
|
mixed_objs = r2n2_dataset.render(
|
||||||
|
categories=categories,
|
||||||
|
sample_nums=[1, 2],
|
||||||
|
device=device,
|
||||||
|
cameras=cameras,
|
||||||
|
raster_settings=raster_settings,
|
||||||
|
lights=lights,
|
||||||
|
)
|
||||||
|
# Compare the rendered models to the reference images.
|
||||||
|
for idx in range(3):
|
||||||
|
mixed_rgb = mixed_objs[idx, ..., :3].squeeze().cpu()
|
||||||
|
if DEBUG:
|
||||||
|
Image.fromarray((mixed_rgb.numpy() * 255).astype(np.uint8)).save(
|
||||||
|
DATA_DIR / ("DEBUG_r2n2_render_by_categories_%s.png" % idx)
|
||||||
|
)
|
||||||
|
image_ref = load_rgb_image(
|
||||||
|
"test_r2n2_render_by_categories_%s.png" % idx, DATA_DIR
|
||||||
|
)
|
||||||
|
self.assertClose(mixed_rgb, image_ref, atol=0.05)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user