Add MeshRasterizerOpenGL

Summary:
Adding MeshRasterizerOpenGL, a faster alternative to MeshRasterizer. The new rasterizer follows the ideas from "Differentiable Surface Rendering via non-Differentiable Sampling".

The new rasterizer 20x faster on a 2M face mesh (try pose optimization on Nefertiti from https://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/!). The larger the mesh, the larger the speedup.

There are two main disadvantages:
* The new rasterizer works with an OpenGL backend, so requires pycuda.gl and pyopengl installed (though we avoided writing any C++ code, everything is in Python!)
* The new rasterizer is non-differentiable. However, you can still differentiate the rendering function if you use if with the new SplatterPhongShader which we recently added to PyTorch3D (see the original paper cited above).

Reviewed By: patricklabatut, jcjohnson

Differential Revision: D37698816

fbshipit-source-id: 54d120639d3cb001f096237807e54aced0acda25
This commit is contained in:
Krzysztof Chalupka
2022-07-22 15:52:50 -07:00
committed by Facebook GitHub Bot
parent 36edf2b302
commit cb49550486
66 changed files with 1556 additions and 337 deletions

View File

@@ -10,16 +10,29 @@ import unittest
import numpy as np
import torch
from PIL import Image
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.points.rasterizer import (
from pytorch3d.renderer import (
FoVOrthographicCameras,
FoVPerspectiveCameras,
look_at_view_transform,
MeshRasterizer,
MeshRasterizerOpenGL,
OrthographicCameras,
PerspectiveCameras,
PointsRasterizationSettings,
PointsRasterizer,
RasterizationSettings,
)
from pytorch3d.renderer.opengl.rasterizer_opengl import (
_check_cameras,
_check_raster_settings,
_convert_meshes_to_gl_ndc,
_parse_and_verify_image_size,
)
from pytorch3d.structures import Pointclouds
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere
from .common_testing import get_tests_dir
from .common_testing import get_tests_dir, TestCaseMixin
DATA_DIR = get_tests_dir() / "data"
@@ -36,8 +49,14 @@ def convert_image_to_binary_mask(filename):
class TestMeshRasterizer(unittest.TestCase):
def test_simple_sphere(self):
self._simple_sphere(MeshRasterizer)
def test_simple_sphere_opengl(self):
self._simple_sphere(MeshRasterizerOpenGL)
def _simple_sphere(self, rasterizer_type):
device = torch.device("cuda:0")
ref_filename = "test_rasterized_sphere.png"
ref_filename = f"test_rasterized_sphere_{rasterizer_type.__name__}.png"
image_ref_filename = DATA_DIR / ref_filename
# Rescale image_ref to the 0 - 1 range and convert to a binary mask.
@@ -54,7 +73,7 @@ class TestMeshRasterizer(unittest.TestCase):
)
# Init rasterizer
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
rasterizer = rasterizer_type(cameras=cameras, raster_settings=raster_settings)
####################################
# 1. Test rasterizing a single mesh
@@ -68,7 +87,8 @@ class TestMeshRasterizer(unittest.TestCase):
if DEBUG:
Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_test_rasterized_sphere.png"
DATA_DIR
/ f"DEBUG_test_rasterized_sphere_{rasterizer_type.__name__}.png"
)
self.assertTrue(torch.allclose(image, image_ref))
@@ -90,20 +110,21 @@ class TestMeshRasterizer(unittest.TestCase):
# 3. Test that passing kwargs to rasterizer works.
####################################################
# Change the view transform to zoom in.
R, T = look_at_view_transform(2.0, 0, 0, device=device)
# Change the view transform to zoom out.
R, T = look_at_view_transform(20.0, 0, 0, device=device)
fragments = rasterizer(sphere_mesh, R=R, T=T)
image = fragments.pix_to_face[0, ..., 0].squeeze().cpu()
image[image >= 0] = 1.0
image[image < 0] = 0.0
ref_filename = "test_rasterized_sphere_zoom.png"
ref_filename = f"test_rasterized_sphere_zoom_{rasterizer_type.__name__}.png"
image_ref_filename = DATA_DIR / ref_filename
image_ref = convert_image_to_binary_mask(image_ref_filename)
if DEBUG:
Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_test_rasterized_sphere_zoom.png"
DATA_DIR
/ f"DEBUG_test_rasterized_sphere_zoom_{rasterizer_type.__name__}.png"
)
self.assertTrue(torch.allclose(image, image_ref))
@@ -112,7 +133,7 @@ class TestMeshRasterizer(unittest.TestCase):
##################################
# Create a new empty rasterizer:
rasterizer = MeshRasterizer()
rasterizer = rasterizer_type(raster_settings=raster_settings)
# Check that omitting the cameras in both initialization
# and the forward pass throws an error:
@@ -120,9 +141,7 @@ class TestMeshRasterizer(unittest.TestCase):
rasterizer(sphere_mesh)
# Now pass in the cameras as a kwarg
fragments = rasterizer(
sphere_mesh, cameras=cameras, raster_settings=raster_settings
)
fragments = rasterizer(sphere_mesh, cameras=cameras)
image = fragments.pix_to_face[0, ..., 0].squeeze().cpu()
# Convert pix_to_face to a binary mask
image[image >= 0] = 1.0
@@ -130,7 +149,8 @@ class TestMeshRasterizer(unittest.TestCase):
if DEBUG:
Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_test_rasterized_sphere.png"
DATA_DIR
/ f"DEBUG_test_rasterized_sphere_{rasterizer_type.__name__}.png"
)
self.assertTrue(torch.allclose(image, image_ref))
@@ -141,6 +161,187 @@ class TestMeshRasterizer(unittest.TestCase):
rasterizer = MeshRasterizer()
rasterizer.to(device)
rasterizer = MeshRasterizerOpenGL()
rasterizer.to(device)
def test_compare_rasterizers(self):
device = torch.device("cuda:0")
# Init rasterizer settings
R, T = look_at_view_transform(2.7, 0, 0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=512,
blur_radius=0.0,
faces_per_pixel=1,
bin_size=0,
perspective_correct=True,
)
from pytorch3d.io import load_obj
from pytorch3d.renderer import TexturesAtlas
from .common_testing import get_pytorch3d_dir
TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
obj_filename = TUTORIAL_DATA_DIR / "cow_mesh/cow.obj"
# Load mesh and texture as a per face texture atlas.
verts, faces, aux = load_obj(
obj_filename,
device=device,
load_textures=True,
create_texture_atlas=True,
texture_atlas_size=8,
texture_wrap=None,
)
atlas = aux.texture_atlas
mesh = Meshes(
verts=[verts],
faces=[faces.verts_idx],
textures=TexturesAtlas(atlas=[atlas]),
)
# Rasterize using both rasterizers and compare results.
rasterizer = MeshRasterizerOpenGL(
cameras=cameras, raster_settings=raster_settings
)
fragments_opengl = rasterizer(mesh)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
fragments = rasterizer(mesh)
# Ensure that 99.9% of bary_coords is at most 0.001 different.
self.assertLess(
torch.quantile(
(fragments.bary_coords - fragments_opengl.bary_coords).abs(), 0.999
),
0.001,
)
# Ensure that 99.9% of zbuf vals is at most 0.001 different.
self.assertLess(
torch.quantile((fragments.zbuf - fragments_opengl.zbuf).abs(), 0.999), 0.001
)
# Ensure that 99.99% of pix_to_face is identical.
self.assertEqual(
torch.quantile(
(fragments.pix_to_face != fragments_opengl.pix_to_face).float(), 0.9999
),
0,
)
class TestMeshRasterizerOpenGLUtils(TestCaseMixin, unittest.TestCase):
def setUp(self):
verts = torch.tensor(
[[-1, 1, 0], [1, 1, 0], [1, -1, 0]], dtype=torch.float32
).cuda()
faces = torch.tensor([[0, 1, 2]]).cuda()
self.meshes_world = Meshes(verts=[verts], faces=[faces])
# Test various utils specific to the OpenGL rasterizer. Full "integration tests"
# live in test_render_meshes and test_render_multigpu.
def test_check_cameras(self):
_check_cameras(FoVPerspectiveCameras())
_check_cameras(FoVPerspectiveCameras())
with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
_check_cameras(None)
with self.assertRaisesRegex(ValueError, "MeshRasterizerOpenGL only works with"):
_check_cameras(PerspectiveCameras())
with self.assertRaisesRegex(ValueError, "MeshRasterizerOpenGL only works with"):
_check_cameras(OrthographicCameras())
MeshRasterizerOpenGL(FoVPerspectiveCameras().cuda())(self.meshes_world)
MeshRasterizerOpenGL(FoVOrthographicCameras().cuda())(self.meshes_world)
MeshRasterizerOpenGL()(
self.meshes_world, cameras=FoVPerspectiveCameras().cuda()
)
with self.assertRaisesRegex(ValueError, "MeshRasterizerOpenGL only works with"):
MeshRasterizerOpenGL(PerspectiveCameras().cuda())(self.meshes_world)
with self.assertRaisesRegex(ValueError, "MeshRasterizerOpenGL only works with"):
MeshRasterizerOpenGL(OrthographicCameras().cuda())(self.meshes_world)
with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
MeshRasterizerOpenGL()(self.meshes_world)
def test_check_raster_settings(self):
raster_settings = RasterizationSettings()
raster_settings.faces_per_pixel = 100
with self.assertWarnsRegex(UserWarning, ".* one face per pixel"):
_check_raster_settings(raster_settings)
with self.assertWarnsRegex(UserWarning, ".* one face per pixel"):
MeshRasterizerOpenGL(raster_settings=raster_settings)(
self.meshes_world, cameras=FoVPerspectiveCameras().cuda()
)
def test_convert_meshes_to_gl_ndc_square_img(self):
R, T = look_at_view_transform(1, 90, 180)
cameras = FoVOrthographicCameras(R=R, T=T).cuda()
meshes_gl_ndc = _convert_meshes_to_gl_ndc(
self.meshes_world, (100, 100), cameras
)
# After look_at_view_transform rotating 180 deg around z-axis, we recover
# the original coordinates. After additionally elevating the view by 90
# deg, we "zero out" the y-coordinate. Finally, we negate the x and y axes
# to adhere to OpenGL conventions (which go against the PyTorch3D convention).
self.assertClose(
meshes_gl_ndc.verts_list()[0],
torch.tensor(
[[-1, 0, 0], [1, 0, 0], [1, 0, 2]], dtype=torch.float32
).cuda(),
atol=1e-5,
)
def test_parse_and_verify_image_size(self):
img_size = _parse_and_verify_image_size(512)
self.assertEqual(img_size, (512, 512))
img_size = _parse_and_verify_image_size((2047, 10))
self.assertEqual(img_size, (2047, 10))
img_size = _parse_and_verify_image_size((10, 2047))
self.assertEqual(img_size, (10, 2047))
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
_parse_and_verify_image_size((2049, 512))
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
_parse_and_verify_image_size((512, 2049))
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
_parse_and_verify_image_size((2049, 2049))
rasterizer = MeshRasterizerOpenGL(FoVPerspectiveCameras().cuda())
raster_settings = RasterizationSettings()
raster_settings.image_size = 512
fragments = rasterizer(self.meshes_world, raster_settings=raster_settings)
self.assertEqual(fragments.pix_to_face.shape, torch.Size([1, 512, 512, 1]))
raster_settings.image_size = (2047, 10)
fragments = rasterizer(self.meshes_world, raster_settings=raster_settings)
self.assertEqual(fragments.pix_to_face.shape, torch.Size([1, 2047, 10, 1]))
raster_settings.image_size = (10, 2047)
fragments = rasterizer(self.meshes_world, raster_settings=raster_settings)
self.assertEqual(fragments.pix_to_face.shape, torch.Size([1, 10, 2047, 1]))
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
raster_settings.image_size = (2049, 512)
rasterizer(self.meshes_world, raster_settings=raster_settings)
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
raster_settings.image_size = (512, 2049)
rasterizer(self.meshes_world, raster_settings=raster_settings)
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
raster_settings.image_size = (2049, 2049)
rasterizer(self.meshes_world, raster_settings=raster_settings)
class TestPointRasterizer(unittest.TestCase):
def test_simple_sphere(self):