pytorch3d/tests/test_rendering_meshes.py
Georgia Gkioxari 487d4d6607 point mesh distances
Summary:
Implementation of point to mesh distances. The current diff contains two types:
(a) Point to Edge
(b) Point to Face

```

Benchmark                                       Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
POINT_MESH_EDGE_4_100_300_5000_cuda:0                2745            3138            183
POINT_MESH_EDGE_4_100_300_10000_cuda:0               4408            4499            114
POINT_MESH_EDGE_4_100_3000_5000_cuda:0               4978            5070            101
POINT_MESH_EDGE_4_100_3000_10000_cuda:0              9076            9187             56
POINT_MESH_EDGE_4_1000_300_5000_cuda:0               1411            1487            355
POINT_MESH_EDGE_4_1000_300_10000_cuda:0              4829            5030            104
POINT_MESH_EDGE_4_1000_3000_5000_cuda:0              7539            7620             67
POINT_MESH_EDGE_4_1000_3000_10000_cuda:0            12088           12272             42
POINT_MESH_EDGE_8_100_300_5000_cuda:0                3106            3222            161
POINT_MESH_EDGE_8_100_300_10000_cuda:0               8561            8648             59
POINT_MESH_EDGE_8_100_3000_5000_cuda:0               6932            7021             73
POINT_MESH_EDGE_8_100_3000_10000_cuda:0             24032           24176             21
POINT_MESH_EDGE_8_1000_300_5000_cuda:0               5272            5399             95
POINT_MESH_EDGE_8_1000_300_10000_cuda:0             11348           11430             45
POINT_MESH_EDGE_8_1000_3000_5000_cuda:0             17478           17683             29
POINT_MESH_EDGE_8_1000_3000_10000_cuda:0            25961           26236             20
POINT_MESH_EDGE_16_100_300_5000_cuda:0               8244            8323             61
POINT_MESH_EDGE_16_100_300_10000_cuda:0             18018           18071             28
POINT_MESH_EDGE_16_100_3000_5000_cuda:0             19428           19544             26
POINT_MESH_EDGE_16_100_3000_10000_cuda:0            44967           45135             12
POINT_MESH_EDGE_16_1000_300_5000_cuda:0              7825            7937             64
POINT_MESH_EDGE_16_1000_300_10000_cuda:0            18504           18571             28
POINT_MESH_EDGE_16_1000_3000_5000_cuda:0            65805           66132              8
POINT_MESH_EDGE_16_1000_3000_10000_cuda:0           90885           91089              6
--------------------------------------------------------------------------------

Benchmark                                       Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
POINT_MESH_FACE_4_100_300_5000_cuda:0                1561            1685            321
POINT_MESH_FACE_4_100_300_10000_cuda:0               2818            2954            178
POINT_MESH_FACE_4_100_3000_5000_cuda:0              15893           16018             32
POINT_MESH_FACE_4_100_3000_10000_cuda:0             16350           16439             31
POINT_MESH_FACE_4_1000_300_5000_cuda:0               3179            3278            158
POINT_MESH_FACE_4_1000_300_10000_cuda:0              2353            2436            213
POINT_MESH_FACE_4_1000_3000_5000_cuda:0             16262           16336             31
POINT_MESH_FACE_4_1000_3000_10000_cuda:0             9334            9448             54
POINT_MESH_FACE_8_100_300_5000_cuda:0                4377            4493            115
POINT_MESH_FACE_8_100_300_10000_cuda:0               9728            9822             52
POINT_MESH_FACE_8_100_3000_5000_cuda:0              26428           26544             19
POINT_MESH_FACE_8_100_3000_10000_cuda:0             42238           43031             12
POINT_MESH_FACE_8_1000_300_5000_cuda:0               3891            3982            129
POINT_MESH_FACE_8_1000_300_10000_cuda:0              5363            5429             94
POINT_MESH_FACE_8_1000_3000_5000_cuda:0             20998           21084             24
POINT_MESH_FACE_8_1000_3000_10000_cuda:0            39711           39897             13
POINT_MESH_FACE_16_100_300_5000_cuda:0               5955            6001             84
POINT_MESH_FACE_16_100_300_10000_cuda:0             12082           12144             42
POINT_MESH_FACE_16_100_3000_5000_cuda:0             44996           45176             12
POINT_MESH_FACE_16_100_3000_10000_cuda:0            73042           73197              7
POINT_MESH_FACE_16_1000_300_5000_cuda:0              8292            8374             61
POINT_MESH_FACE_16_1000_300_10000_cuda:0            19442           19506             26
POINT_MESH_FACE_16_1000_3000_5000_cuda:0            36059           36194             14
POINT_MESH_FACE_16_1000_3000_10000_cuda:0           64644           64822              8
--------------------------------------------------------------------------------
```

Reviewed By: jcjohnson

Differential Revision: D20590462

fbshipit-source-id: 42a39837b514a546ac9471bfaff60eefe7fae829
2020-04-11 00:21:24 -07:00

343 lines
13 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
Sanity checks for output images from the renderer.
"""
import unittest
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import (
BlendParams,
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftSilhouetteShader,
TexturedSoftPhongShader,
)
from pytorch3d.renderer.mesh.texturing import Textures
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere
# Set DEBUG to True to write the images rendered by the tests to disk for
# inspection. Every image saved this way has a file name prefixed with DEBUG_.
DEBUG = False
# Directory named "data" that sits next to this test file; the reference
# images used by the tests are loaded from here.
DATA_DIR = Path(__file__).resolve().parent.joinpath("data")
def load_rgb_image(filename, data_dir=DATA_DIR):
    """
    Load an image file as a float32 tensor restricted to its first three
    channels.

    Args:
        filename: Name of the image file, relative to `data_dir`.
        data_dir: Directory containing the file. Defaults to the module-level
            DATA_DIR.

    Returns:
        A torch.float32 tensor holding the pixel values divided by 255.0,
        sliced to the first three entries of the last dimension (any further
        channel, e.g. alpha, is dropped).
        NOTE(review): presumably the files are 8-bit images, so the values
        end up in [0, 1] - confirm against the reference data.
    """
    with Image.open(data_dir / filename) as pil_image:
        # Materialize the pixels while the file is still open.
        pixels = np.array(pil_image) / 255.0
    rgb = torch.from_numpy(pixels).to(dtype=torch.float32)
    return rgb[..., :3]
class TestRenderingMeshes(unittest.TestCase):
    """
    Compare images produced by the mesh renderer against reference images
    stored in DATA_DIR.

    All tests create their tensors on the "cuda:0" device.
    """

    def test_simple_sphere(self, elevated_camera=False):
        """
        Test output of phong, gouraud and flat shading matches a reference
        image. The light is a PointLights object built with default arguments
        whose location is set to (0, 0, 2). Afterwards the light is moved to
        z = -2 and the phong render is compared to the "dark" reference image.

        Args:
            elevated_camera: Defines whether the camera observing the scene should
                have an elevation of 45 degrees.
        """
        device = torch.device("cuda:0")

        # Init mesh: an ico sphere with white per-vertex colors.
        sphere_mesh = ico_sphere(5, device)
        verts_padded = sphere_mesh.verts_padded()
        faces_padded = sphere_mesh.faces_padded()
        textures = Textures(verts_rgb=torch.ones_like(verts_padded))
        sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)

        # Init rasterizer settings
        if elevated_camera:
            # Elevated and rotated camera
            R, T = look_at_view_transform(dist=2.7, elev=45.0, azim=45.0)
            postfix = "_elevated_camera"
            # If y axis is up, the spot of light should
            # be on the bottom left of the sphere.
        else:
            # No elevation or azimuth rotation
            R, T = look_at_view_transform(2.7, 0.0, 0.0)
            postfix = ""
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

        # Init shader settings
        materials = Materials(device=device)
        lights = PointLights(device=device)
        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]

        raster_settings = RasterizationSettings(
            image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
        )
        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)

        # Test several shaders
        shaders = {
            "phong": HardPhongShader,
            "gouraud": HardGouraudShader,
            "flat": HardFlatShader,
        }
        for name, shader_init in shaders.items():
            shader = shader_init(lights=lights, cameras=cameras, materials=materials)
            renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
            images = renderer(sphere_mesh)
            filename = "simple_sphere_light_%s%s.png" % (name, postfix)
            image_ref = load_rgb_image("test_%s" % filename)
            rgb = images[0, ..., :3].squeeze().cpu()
            if DEBUG:
                # The format string needs a %s placeholder; without one the
                # % operator raises TypeError as soon as DEBUG is enabled.
                filename = "DEBUG_%s" % filename
                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                    DATA_DIR / filename
                )
            self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))

        ########################################################
        # Move the light to z = -2.0 in world space, i.e. to the
        # opposite side of the sphere from its initial position,
        # and compare against the "dark" reference image.
        # NOTE(review): the previous comment said the light moves
        # to the +z axis and that +Z is in, +Y up, +X left for
        # both world and camera space; the code sets z to -2.0,
        # so confirm the axis convention.
        ########################################################
        lights.location[..., 2] = -2.0
        phong_shader = HardPhongShader(
            lights=lights, cameras=cameras, materials=materials
        )
        phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
        images = phong_renderer(sphere_mesh, lights=lights)
        rgb = images[0, ..., :3].squeeze().cpu()
        if DEBUG:
            filename = "DEBUG_simple_sphere_dark%s.png" % postfix
            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                DATA_DIR / filename
            )

        # Load reference image
        image_ref_phong_dark = load_rgb_image("test_simple_sphere_dark%s.png" % postfix)
        self.assertTrue(torch.allclose(rgb, image_ref_phong_dark, atol=0.05))

    def test_simple_sphere_elevated_camera(self):
        """
        Test output of phong, gouraud and flat shading matches a reference
        image, using the same light setup as test_simple_sphere.

        The rendering is performed with a camera that has non-zero elevation.
        """
        self.test_simple_sphere(elevated_camera=True)

    def test_simple_sphere_batched(self):
        """
        Test a mesh with vertex textures can be extended to form a batch, and
        is rendered correctly with Phong, Gouraud and Flat Shaders.
        """
        batch_size = 20
        device = torch.device("cuda:0")

        # Init mesh with vertex textures.
        sphere_meshes = ico_sphere(5, device).extend(batch_size)
        verts_padded = sphere_meshes.verts_padded()
        faces_padded = sphere_meshes.faces_padded()
        textures = Textures(verts_rgb=torch.ones_like(verts_padded))
        sphere_meshes = Meshes(
            verts=verts_padded, faces=faces_padded, textures=textures
        )

        # Init rasterizer settings: one identical camera per batch element.
        dist = torch.tensor([2.7]).repeat(batch_size).to(device)
        elev = torch.zeros_like(dist)
        azim = torch.zeros_like(dist)
        R, T = look_at_view_transform(dist, elev, azim)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        raster_settings = RasterizationSettings(
            image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
        )

        # Init shader settings
        materials = Materials(device=device)
        lights = PointLights(device=device)
        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]

        # Init renderer
        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
        # "phong" must map to HardPhongShader so that the phong reference
        # image is actually checked against a phong render (it previously
        # pointed at HardGouraudShader).
        shaders = {
            "phong": HardPhongShader,
            "gouraud": HardGouraudShader,
            "flat": HardFlatShader,
        }
        for name, shader_init in shaders.items():
            shader = shader_init(lights=lights, cameras=cameras, materials=materials)
            renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
            images = renderer(sphere_meshes)
            image_ref = load_rgb_image("test_simple_sphere_light_%s.png" % name)
            # Every element of the batch must match the single-mesh reference.
            for i in range(batch_size):
                rgb = images[i, ..., :3].squeeze().cpu()
                self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))

    def test_silhouette_with_grad(self):
        """
        Test silhouette blending. Also check that gradient calculation works.
        """
        device = torch.device("cuda:0")
        ref_filename = "test_silhouette.png"
        image_ref_filename = DATA_DIR / ref_filename
        sphere_mesh = ico_sphere(5, device)
        verts, faces = sphere_mesh.get_mesh_verts_faces(0)
        sphere_mesh = Meshes(verts=[verts], faces=[faces])

        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=512,
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
            faces_per_pixel=80,
            bin_size=0,
        )

        # Init rasterizer settings
        R, T = look_at_view_transform(2.7, 0, 0)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

        # Init renderer
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
            shader=SoftSilhouetteShader(blend_params=blend_params),
        )
        images = renderer(sphere_mesh)
        # The silhouette is read from channel index 3 of the rendered image.
        alpha = images[0, ..., 3].squeeze().cpu()
        if DEBUG:
            Image.fromarray((alpha.numpy() * 255).astype(np.uint8)).save(
                DATA_DIR / "DEBUG_silhouette.png"
            )

        with Image.open(image_ref_filename) as raw_image_ref:
            image_ref = torch.from_numpy(np.array(raw_image_ref))
        image_ref = image_ref.to(dtype=torch.float32) / 255.0
        self.assertTrue(torch.allclose(alpha, image_ref, atol=0.055))

        # Check grad exist
        verts.requires_grad = True
        sphere_mesh = Meshes(verts=[verts], faces=[faces])
        images = renderer(sphere_mesh)
        images[0, ...].sum().backward()
        self.assertIsNotNone(verts.grad)

    def test_texture_map(self):
        """
        Test a mesh with a texture map is loaded and rendered correctly.
        The pupils in the eyes of the cow should always be looking to the left.
        """
        device = torch.device("cuda:0")
        # The cow mesh ships with the tutorials. Keep this path in its own
        # local name so that the module-level DATA_DIR (reference images and
        # DEBUG output) is not shadowed.
        tutorial_data_dir = (
            Path(__file__).resolve().parent.parent / "docs/tutorials/data"
        )
        obj_filename = tutorial_data_dir / "cow_mesh/cow.obj"

        # Load mesh + texture
        mesh = load_objs_as_meshes([obj_filename], device=device)

        # Init rasterizer settings
        R, T = look_at_view_transform(2.7, 0, 0)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        raster_settings = RasterizationSettings(
            image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
        )

        # Init shader settings
        materials = Materials(device=device)
        lights = PointLights(device=device)

        # Place light behind the cow in world space. The front of
        # the cow is facing the -z direction.
        lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]

        # Init renderer
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
            shader=TexturedSoftPhongShader(
                lights=lights, cameras=cameras, materials=materials
            ),
        )
        images = renderer(mesh)
        rgb = images[0, ..., :3].squeeze().cpu()

        # Load reference image
        image_ref = load_rgb_image("test_texture_map_back.png")

        if DEBUG:
            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                DATA_DIR / "DEBUG_texture_map_back.png"
            )

        # NOTE some pixels can be flaky and will not lead to
        # `cond1` being true. Add `cond2` and check `cond1 or cond2`
        cond1 = torch.allclose(rgb, image_ref, atol=0.05)
        cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
        self.assertTrue(cond1 or cond2)

        # Check grad exists
        [verts] = mesh.verts_list()
        verts.requires_grad = True
        mesh2 = Meshes(verts=[verts], faces=mesh.faces_list(), textures=mesh.textures)
        images = renderer(mesh2)
        images[0, ...].sum().backward()
        self.assertIsNotNone(verts.grad)

        ##########################################
        # Check rendering of the front of the cow
        ##########################################

        R, T = look_at_view_transform(2.7, 0, 180)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

        # Move light to the front of the cow in world space
        lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
        images = renderer(mesh, cameras=cameras, lights=lights)
        rgb = images[0, ..., :3].squeeze().cpu()

        # Load reference image
        image_ref = load_rgb_image("test_texture_map_front.png")

        if DEBUG:
            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                DATA_DIR / "DEBUG_texture_map_front.png"
            )

        # The front view was previously rendered but never compared with its
        # reference image. Use the same flaky-pixel tolerant check as for the
        # back view: either everything is close, or fewer than 5 values differ.
        cond1 = torch.allclose(rgb, image_ref, atol=0.05)
        cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
        self.assertTrue(cond1 or cond2)

        #################################
        # Add blurring to rasterization
        #################################
        R, T = look_at_view_transform(2.7, 0, 180)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        blend_params = BlendParams(sigma=5e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=512,
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
            faces_per_pixel=100,
            bin_size=0,
        )

        # The new settings are passed per call and override the ones the
        # renderer was constructed with.
        images = renderer(
            mesh.clone(),
            cameras=cameras,
            raster_settings=raster_settings,
            blend_params=blend_params,
        )
        rgb = images[0, ..., :3].squeeze().cpu()

        # Load reference image
        image_ref = load_rgb_image("test_blurry_textured_rendering.png")

        if DEBUG:
            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                DATA_DIR / "DEBUG_blurry_textured_rendering.png"
            )

        self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))