mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
Summary: Implementation of point to mesh distances. The current diff contains two types: (a) point to edge and (b) point to face.
```
Benchmark                                     Avg Time(μs)   Peak Time(μs)  Iterations
--------------------------------------------------------------------------------
POINT_MESH_EDGE_4_100_300_5000_cuda:0                 2745            3138         183
POINT_MESH_EDGE_4_100_300_10000_cuda:0                4408            4499         114
POINT_MESH_EDGE_4_100_3000_5000_cuda:0                4978            5070         101
POINT_MESH_EDGE_4_100_3000_10000_cuda:0               9076            9187          56
POINT_MESH_EDGE_4_1000_300_5000_cuda:0                1411            1487         355
POINT_MESH_EDGE_4_1000_300_10000_cuda:0               4829            5030         104
POINT_MESH_EDGE_4_1000_3000_5000_cuda:0               7539            7620          67
POINT_MESH_EDGE_4_1000_3000_10000_cuda:0             12088           12272          42
POINT_MESH_EDGE_8_100_300_5000_cuda:0                 3106            3222         161
POINT_MESH_EDGE_8_100_300_10000_cuda:0                8561            8648          59
POINT_MESH_EDGE_8_100_3000_5000_cuda:0                6932            7021          73
POINT_MESH_EDGE_8_100_3000_10000_cuda:0              24032           24176          21
POINT_MESH_EDGE_8_1000_300_5000_cuda:0                5272            5399          95
POINT_MESH_EDGE_8_1000_300_10000_cuda:0              11348           11430          45
POINT_MESH_EDGE_8_1000_3000_5000_cuda:0              17478           17683          29
POINT_MESH_EDGE_8_1000_3000_10000_cuda:0             25961           26236          20
POINT_MESH_EDGE_16_100_300_5000_cuda:0                8244            8323          61
POINT_MESH_EDGE_16_100_300_10000_cuda:0              18018           18071          28
POINT_MESH_EDGE_16_100_3000_5000_cuda:0              19428           19544          26
POINT_MESH_EDGE_16_100_3000_10000_cuda:0             44967           45135          12
POINT_MESH_EDGE_16_1000_300_5000_cuda:0               7825            7937          64
POINT_MESH_EDGE_16_1000_300_10000_cuda:0             18504           18571          28
POINT_MESH_EDGE_16_1000_3000_5000_cuda:0             65805           66132           8
POINT_MESH_EDGE_16_1000_3000_10000_cuda:0            90885           91089           6
--------------------------------------------------------------------------------
Benchmark                                     Avg Time(μs)   Peak Time(μs)  Iterations
--------------------------------------------------------------------------------
POINT_MESH_FACE_4_100_300_5000_cuda:0                 1561            1685         321
POINT_MESH_FACE_4_100_300_10000_cuda:0                2818            2954         178
POINT_MESH_FACE_4_100_3000_5000_cuda:0               15893           16018          32
POINT_MESH_FACE_4_100_3000_10000_cuda:0              16350           16439          31
POINT_MESH_FACE_4_1000_300_5000_cuda:0                3179            3278         158
POINT_MESH_FACE_4_1000_300_10000_cuda:0               2353            2436         213
POINT_MESH_FACE_4_1000_3000_5000_cuda:0              16262           16336          31
POINT_MESH_FACE_4_1000_3000_10000_cuda:0              9334            9448          54
POINT_MESH_FACE_8_100_300_5000_cuda:0                 4377            4493         115
POINT_MESH_FACE_8_100_300_10000_cuda:0                9728            9822          52
POINT_MESH_FACE_8_100_3000_5000_cuda:0               26428           26544          19
POINT_MESH_FACE_8_100_3000_10000_cuda:0              42238           43031          12
POINT_MESH_FACE_8_1000_300_5000_cuda:0                3891            3982         129
POINT_MESH_FACE_8_1000_300_10000_cuda:0               5363            5429          94
POINT_MESH_FACE_8_1000_3000_5000_cuda:0              20998           21084          24
POINT_MESH_FACE_8_1000_3000_10000_cuda:0             39711           39897          13
POINT_MESH_FACE_16_100_300_5000_cuda:0                5955            6001          84
POINT_MESH_FACE_16_100_300_10000_cuda:0              12082           12144          42
POINT_MESH_FACE_16_100_3000_5000_cuda:0              44996           45176          12
POINT_MESH_FACE_16_100_3000_10000_cuda:0             73042           73197           7
POINT_MESH_FACE_16_1000_300_5000_cuda:0               8292            8374          61
POINT_MESH_FACE_16_1000_300_10000_cuda:0             19442           19506          26
POINT_MESH_FACE_16_1000_3000_5000_cuda:0             36059           36194          14
POINT_MESH_FACE_16_1000_3000_10000_cuda:0            64644           64822           8
--------------------------------------------------------------------------------
```
Reviewed By: jcjohnson
Differential Revision: D20590462
fbshipit-source-id: 42a39837b514a546ac9471bfaff60eefe7fae829
343 lines
13 KiB
Python
343 lines
13 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
|
|
|
|
"""
|
|
Sanity checks for output images from the renderer.
|
|
"""
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
from pytorch3d.io import load_objs_as_meshes
|
|
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
|
|
from pytorch3d.renderer.lighting import PointLights
|
|
from pytorch3d.renderer.materials import Materials
|
|
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
|
|
from pytorch3d.renderer.mesh.renderer import MeshRenderer
|
|
from pytorch3d.renderer.mesh.shader import (
|
|
BlendParams,
|
|
HardFlatShader,
|
|
HardGouraudShader,
|
|
HardPhongShader,
|
|
SoftSilhouetteShader,
|
|
TexturedSoftPhongShader,
|
|
)
|
|
from pytorch3d.renderer.mesh.texturing import Textures
|
|
from pytorch3d.structures.meshes import Meshes
|
|
from pytorch3d.utils.ico_sphere import ico_sphere
|
|
|
|
|
|
# Flip DEBUG to True to write the images rendered by these tests to disk for
# debugging. Every image saved this way has a file name starting with DEBUG_.
DEBUG = False
# Folder named "data" that sits next to this test file; the reference images
# the tests compare against are loaded from here.
DATA_DIR = Path(__file__).resolve().parent.joinpath("data")
|
|
|
|
|
|
def load_rgb_image(filename, data_dir=DATA_DIR):
    """
    Read an image file and return it as a float32 tensor scaled to [0, 1].

    Args:
        filename: Name of the image file, relative to ``data_dir``.
        data_dir: Directory the file is read from (defaults to DATA_DIR).

    Returns:
        float32 tensor of the pixel values divided by 255, keeping only the
        first three entries of the last axis (for an H x W x C image this
        keeps RGB and drops any alpha channel).
    """
    image_path = data_dir / filename
    with Image.open(image_path) as pil_image:
        pixels = np.array(pil_image)
    # np.array materialises the pixel data, so it stays valid after close.
    normalized = torch.from_numpy(pixels / 255.0).to(dtype=torch.float32)
    return normalized[..., :3]
|
|
|
|
|
|
class TestRenderingMeshes(unittest.TestCase):
    """Render simple meshes and compare the outputs with saved reference images."""

    @staticmethod
    def _save_debug_image(image, filename):
        """
        Save a rendered image under DATA_DIR so it can be inspected by hand.

        Args:
            image: CPU float tensor, presumably with values in [0, 1]; it is
                scaled by 255 and cast to uint8 before saving.
            filename: Name of the output file inside DATA_DIR. Callers pass
                names prefixed with DEBUG_ so they are easy to tell apart
                from the reference images.
        """
        Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
            DATA_DIR / filename
        )

    def _assert_close_allowing_flaky_values(self, rgb, image_ref, atol=0.05):
        """
        Assert that ``rgb`` matches ``image_ref`` up to a handful of outliers.

        Some pixels of the textured renders can be flaky, so the check passes
        when either every value is within ``atol`` of the reference, or fewer
        than 5 individual tensor entries differ by more than ``atol``.
        """
        all_close = torch.allclose(rgb, image_ref, atol=atol)
        few_outliers = ((rgb - image_ref).abs() > atol).sum() < 5
        self.assertTrue(all_close or few_outliers)

    def test_simple_sphere(self, elevated_camera=False):
        """
        Test output of phong, gouraud and flat shading matches a reference
        image using the default values for the light sources.

        Args:
            elevated_camera: Defines whether the camera observing the scene should
                have an elevation of 45 degrees.
        """
        device = torch.device("cuda:0")

        # Init mesh with an all-ones (white) per-vertex texture.
        sphere_mesh = ico_sphere(5, device)
        verts_padded = sphere_mesh.verts_padded()
        faces_padded = sphere_mesh.faces_padded()
        textures = Textures(verts_rgb=torch.ones_like(verts_padded))
        sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)

        # Init rasterizer settings
        if elevated_camera:
            # Elevated and rotated camera
            R, T = look_at_view_transform(dist=2.7, elev=45.0, azim=45.0)
            postfix = "_elevated_camera"
            # If y axis is up, the spot of light should
            # be on the bottom left of the sphere.
        else:
            # No elevation or azimuth rotation
            R, T = look_at_view_transform(2.7, 0.0, 0.0)
            postfix = ""
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

        # Init shader settings
        materials = Materials(device=device)
        lights = PointLights(device=device)
        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]

        raster_settings = RasterizationSettings(
            image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
        )
        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)

        # Test several shaders
        shaders = {
            "phong": HardPhongShader,
            "gouraud": HardGouraudShader,
            "flat": HardFlatShader,
        }
        for name, shader_init in shaders.items():
            shader = shader_init(lights=lights, cameras=cameras, materials=materials)
            renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
            images = renderer(sphere_mesh)
            filename = "simple_sphere_light_%s%s.png" % (name, postfix)
            image_ref = load_rgb_image("test_%s" % filename)
            rgb = images[0, ..., :3].squeeze().cpu()
            if DEBUG:
                # "DEBUG_%s" needs the placeholder; "DEBUG_" % filename would
                # raise TypeError (not all arguments converted).
                self._save_debug_image(rgb, "DEBUG_%s" % filename)
            self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))

        ########################################################
        # Move the light from z = +2 to z = -2 in world space so
        # it is behind the sphere. Note that +Z is in, +Y up,
        # +X left for both world and camera space.
        ########################################################
        lights.location[..., 2] = -2.0
        phong_shader = HardPhongShader(
            lights=lights, cameras=cameras, materials=materials
        )
        phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
        images = phong_renderer(sphere_mesh, lights=lights)
        rgb = images[0, ..., :3].squeeze().cpu()
        if DEBUG:
            self._save_debug_image(rgb, "DEBUG_simple_sphere_dark%s.png" % postfix)

        # Load reference image
        image_ref_phong_dark = load_rgb_image("test_simple_sphere_dark%s.png" % postfix)
        self.assertTrue(torch.allclose(rgb, image_ref_phong_dark, atol=0.05))

    def test_simple_sphere_elevated_camera(self):
        """
        Test output of phong, gouraud and flat shading matches a reference
        image using the default values for the light sources.

        The rendering is performed with a camera that has non-zero elevation.
        """
        self.test_simple_sphere(elevated_camera=True)

    def test_simple_sphere_batched(self):
        """
        Test a mesh with vertex textures can be extended to form a batch, and
        is rendered correctly with Phong, Gouraud and Flat Shaders.
        """
        batch_size = 20
        device = torch.device("cuda:0")

        # Init mesh with vertex textures.
        sphere_meshes = ico_sphere(5, device).extend(batch_size)
        verts_padded = sphere_meshes.verts_padded()
        faces_padded = sphere_meshes.faces_padded()
        textures = Textures(verts_rgb=torch.ones_like(verts_padded))
        sphere_meshes = Meshes(
            verts=verts_padded, faces=faces_padded, textures=textures
        )

        # Init rasterizer settings: the same camera for every batch element.
        dist = torch.tensor([2.7]).repeat(batch_size).to(device)
        elev = torch.zeros_like(dist)
        azim = torch.zeros_like(dist)
        R, T = look_at_view_transform(dist, elev, azim)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        raster_settings = RasterizationSettings(
            image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
        )

        # Init shader settings
        materials = Materials(device=device)
        lights = PointLights(device=device)
        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]

        # Init renderer
        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
        # Each key must name the shader it maps to: the reference image for
        # "phong" is a Phong render, so it must be compared to HardPhongShader.
        shaders = {
            "phong": HardPhongShader,
            "gouraud": HardGouraudShader,
            "flat": HardFlatShader,
        }
        for name, shader_init in shaders.items():
            shader = shader_init(lights=lights, cameras=cameras, materials=materials)
            renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
            images = renderer(sphere_meshes)
            image_ref = load_rgb_image("test_simple_sphere_light_%s.png" % name)
            for i in range(batch_size):
                rgb = images[i, ..., :3].squeeze().cpu()
                self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))

    def test_silhouette_with_grad(self):
        """
        Test silhouette blending. Also check that gradient calculation works.
        """
        device = torch.device("cuda:0")
        ref_filename = "test_silhouette.png"
        image_ref_filename = DATA_DIR / ref_filename
        sphere_mesh = ico_sphere(5, device)
        verts, faces = sphere_mesh.get_mesh_verts_faces(0)
        sphere_mesh = Meshes(verts=[verts], faces=[faces])

        blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=512,
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
            faces_per_pixel=80,
            bin_size=0,
        )

        # Init rasterizer settings
        R, T = look_at_view_transform(2.7, 0, 0)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

        # Init renderer
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
            shader=SoftSilhouetteShader(blend_params=blend_params),
        )
        images = renderer(sphere_mesh)
        # Channel 3 of the rendered image is compared against the reference.
        alpha = images[0, ..., 3].squeeze().cpu()
        if DEBUG:
            self._save_debug_image(alpha, "DEBUG_silhouette.png")

        # The reference is loaded directly (not via load_rgb_image) because
        # it is compared as a whole rather than sliced to three channels.
        with Image.open(image_ref_filename) as raw_image_ref:
            image_ref = torch.from_numpy(np.array(raw_image_ref))
        image_ref = image_ref.to(dtype=torch.float32) / 255.0
        self.assertTrue(torch.allclose(alpha, image_ref, atol=0.055))

        # Check grad exist
        verts.requires_grad = True
        sphere_mesh = Meshes(verts=[verts], faces=[faces])
        images = renderer(sphere_mesh)
        images[0, ...].sum().backward()
        self.assertIsNotNone(verts.grad)

    def test_texture_map(self):
        """
        Test a mesh with a texture map is loaded and rendered correctly.
        The pupils in the eyes of the cow should always be looking to the left.
        """
        device = torch.device("cuda:0")
        # A distinct local name: rebinding DATA_DIR here would shadow the
        # module-level constant and send DEBUG images to the tutorials folder.
        tutorial_data_dir = (
            Path(__file__).resolve().parent.parent / "docs/tutorials/data"
        )
        obj_filename = tutorial_data_dir / "cow_mesh/cow.obj"

        # Load mesh + texture
        mesh = load_objs_as_meshes([obj_filename], device=device)

        # Init rasterizer settings
        R, T = look_at_view_transform(2.7, 0, 0)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        raster_settings = RasterizationSettings(
            image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
        )

        # Init shader settings
        materials = Materials(device=device)
        lights = PointLights(device=device)

        # Place light behind the cow in world space. The front of
        # the cow is facing the -z direction.
        lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]

        # Init renderer
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
            shader=TexturedSoftPhongShader(
                lights=lights, cameras=cameras, materials=materials
            ),
        )
        images = renderer(mesh)
        rgb = images[0, ..., :3].squeeze().cpu()

        # Load reference image
        image_ref = load_rgb_image("test_texture_map_back.png")

        if DEBUG:
            self._save_debug_image(rgb, "DEBUG_texture_map_back.png")

        self._assert_close_allowing_flaky_values(rgb, image_ref)

        # Check grad exists
        [verts] = mesh.verts_list()
        verts.requires_grad = True
        mesh2 = Meshes(verts=[verts], faces=mesh.faces_list(), textures=mesh.textures)
        images = renderer(mesh2)
        images[0, ...].sum().backward()
        self.assertIsNotNone(verts.grad)

        ##########################################
        # Check rendering of the front of the cow
        ##########################################

        R, T = look_at_view_transform(2.7, 0, 180)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

        # Move light to the front of the cow in world space. The shader holds
        # this same `lights` object, so later renders also see this location.
        lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
        images = renderer(mesh, cameras=cameras, lights=lights)
        rgb = images[0, ..., :3].squeeze().cpu()

        # Load reference image
        image_ref = load_rgb_image("test_texture_map_front.png")

        if DEBUG:
            self._save_debug_image(rgb, "DEBUG_texture_map_front.png")

        # Compare the front view against its reference as well.
        self._assert_close_allowing_flaky_values(rgb, image_ref)

        #################################
        # Add blurring to rasterization
        #################################
        # The front-facing camera above (dist=2.7, elev=0, azim=180) is reused.
        blend_params = BlendParams(sigma=5e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=512,
            blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
            faces_per_pixel=100,
            bin_size=0,
        )

        images = renderer(
            mesh.clone(),
            cameras=cameras,
            raster_settings=raster_settings,
            blend_params=blend_params,
        )
        rgb = images[0, ..., :3].squeeze().cpu()

        # Load reference image
        image_ref = load_rgb_image("test_blurry_textured_rendering.png")

        if DEBUG:
            self._save_debug_image(rgb, "DEBUG_blurry_textured_rendering.png")

        self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
|