Add MeshRasterizerOpenGL

Summary:
Adding MeshRasterizerOpenGL, a faster alternative to MeshRasterizer. The new rasterizer follows the ideas from "Differentiable Surface Rendering via non-Differentiable Sampling".

The new rasterizer 20x faster on a 2M face mesh (try pose optimization on Nefertiti from https://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/!). The larger the mesh, the larger the speedup.

There are two main disadvantages:
* The new rasterizer works with an OpenGL backend, so requires pycuda.gl and pyopengl installed (though we avoided writing any C++ code, everything is in Python!)
* The new rasterizer is non-differentiable. However, you can still differentiate the rendering function if you use if with the new SplatterPhongShader which we recently added to PyTorch3D (see the original paper cited above).

Reviewed By: patricklabatut, jcjohnson

Differential Revision: D37698816

fbshipit-source-id: 54d120639d3cb001f096237807e54aced0acda25
This commit is contained in:
Krzysztof Chalupka
2022-07-22 15:52:50 -07:00
committed by Facebook GitHub Bot
parent 36edf2b302
commit cb49550486
66 changed files with 1556 additions and 337 deletions

View File

@@ -14,6 +14,7 @@ from pytorch3d.renderer import (
HardGouraudShader,
Materials,
MeshRasterizer,
MeshRasterizerOpenGL,
MeshRenderer,
PointLights,
PointsRasterizationSettings,
@@ -21,18 +22,19 @@ from pytorch3d.renderer import (
PointsRenderer,
RasterizationSettings,
SoftPhongShader,
SplatterPhongShader,
TexturesVertex,
)
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.utils.ico_sphere import ico_sphere
from .common_testing import get_random_cuda_device, TestCaseMixin
from .common_testing import TestCaseMixin
# Set the number of GPUS you want to test with
NUM_GPUS = 3
GPU_LIST = list({get_random_cuda_device() for _ in range(NUM_GPUS)})
NUM_GPUS = 2
GPU_LIST = [f"cuda:{idx}" for idx in range(NUM_GPUS)]
print("GPUs: %s" % ", ".join(GPU_LIST))
@@ -56,12 +58,12 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
self.assertEqual(renderer.shader.materials.device, device)
self.assertEqual(renderer.shader.materials.ambient_color.device, device)
def test_mesh_renderer_to(self):
def _mesh_renderer_to(self, rasterizer_class, shader_class):
"""
Test moving all the tensors in the mesh renderer to a new device.
"""
device1 = torch.device("cpu")
device1 = torch.device("cuda:0")
R, T = look_at_view_transform(1500, 0.0, 0.0)
@@ -71,12 +73,12 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
lights.location = torch.tensor([0.0, 0.0, +1000.0], device=device1)[None]
raster_settings = RasterizationSettings(
image_size=256, blur_radius=0.0, faces_per_pixel=1
image_size=128, blur_radius=0.0, faces_per_pixel=1
)
cameras = FoVPerspectiveCameras(
device=device1, R=R, T=T, aspect_ratio=1.0, fov=60.0, zfar=100
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
rasterizer = rasterizer_class(cameras=cameras, raster_settings=raster_settings)
blend_params = BlendParams(
1e-4,
@@ -84,7 +86,7 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
background_color=torch.zeros(3, dtype=torch.float32, device=device1),
)
shader = SoftPhongShader(
shader = shader_class(
lights=lights,
cameras=cameras,
materials=materials,
@@ -107,26 +109,32 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
# Move renderer and mesh to another device and re render
# This also tests that background_color is correctly moved to
# the new device
device2 = torch.device("cuda:0")
device2 = torch.device("cuda:1")
renderer = renderer.to(device2)
mesh = mesh.to(device2)
self._check_mesh_renderer_props_on_device(renderer, device2)
output_images = renderer(mesh)
self.assertEqual(output_images.device, device2)
def test_render_meshes(self):
def test_mesh_renderer_to(self):
self._mesh_renderer_to(MeshRasterizer, SoftPhongShader)
def test_mesh_renderer_opengl_to(self):
self._mesh_renderer_to(MeshRasterizerOpenGL, SplatterPhongShader)
def _render_meshes(self, rasterizer_class, shader_class):
test = self
class Model(nn.Module):
def __init__(self):
def __init__(self, device):
super(Model, self).__init__()
mesh = ico_sphere(3)
mesh = ico_sphere(3).to(device)
self.register_buffer("faces", mesh.faces_padded())
self.renderer = self.init_render()
self.renderer = self.init_render(device)
def init_render(self):
def init_render(self, device):
cameras = FoVPerspectiveCameras()
cameras = FoVPerspectiveCameras().to(device)
raster_settings = RasterizationSettings(
image_size=128, blur_radius=0.0, faces_per_pixel=1
)
@@ -135,12 +143,12 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
diffuse_color=((0, 0.0, 0),),
specular_color=((0.0, 0, 0),),
location=((0.0, 0.0, 1e5),),
)
).to(device)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(
rasterizer=rasterizer_class(
cameras=cameras, raster_settings=raster_settings
),
shader=HardGouraudShader(cameras=cameras, lights=lights),
shader=shader_class(cameras=cameras, lights=lights),
)
return renderer
@@ -155,20 +163,25 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
img_render = self.renderer(mesh)
return img_render[:, :, :, :3]
# DataParallel requires every input tensor be provided
# on the first device in its device_ids list.
verts = ico_sphere(3).verts_padded()
# Make sure we use all GPUs in GPU_LIST by making the batch size 4 x GPU count.
verts = ico_sphere(3).verts_padded().expand(len(GPU_LIST) * 4, 642, 3)
texs = verts.new_ones(verts.shape)
model = Model()
model.to(GPU_LIST[0])
model = Model(device=GPU_LIST[0])
model = nn.DataParallel(model, device_ids=GPU_LIST)
# Test a few iterations
for _ in range(100):
model(verts, texs)
def test_render_meshes(self):
self._render_meshes(MeshRasterizer, HardGouraudShader)
class TestRenderPointssMultiGPU(TestCaseMixin, unittest.TestCase):
# @unittest.skip("Multi-GPU OpenGL training is currently not supported.")
def test_render_meshes_opengl(self):
self._render_meshes(MeshRasterizerOpenGL, SplatterPhongShader)
class TestRenderPointsMultiGPU(TestCaseMixin, unittest.TestCase):
def _check_points_renderer_props_on_device(self, renderer, device):
"""
Helper function to check that all the properties have