mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
Add MeshRasterizerOpenGL
Summary: Adding MeshRasterizerOpenGL, a faster alternative to MeshRasterizer. The new rasterizer follows the ideas from "Differentiable Surface Rendering via non-Differentiable Sampling". The new rasterizer 20x faster on a 2M face mesh (try pose optimization on Nefertiti from https://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/!). The larger the mesh, the larger the speedup. There are two main disadvantages: * The new rasterizer works with an OpenGL backend, so requires pycuda.gl and pyopengl installed (though we avoided writing any C++ code, everything is in Python!) * The new rasterizer is non-differentiable. However, you can still differentiate the rendering function if you use if with the new SplatterPhongShader which we recently added to PyTorch3D (see the original paper cited above). Reviewed By: patricklabatut, jcjohnson Differential Revision: D37698816 fbshipit-source-id: 54d120639d3cb001f096237807e54aced0acda25
This commit is contained in:
committed by
Facebook GitHub Bot
parent
36edf2b302
commit
cb49550486
@@ -14,6 +14,7 @@ from pytorch3d.renderer import (
|
||||
HardGouraudShader,
|
||||
Materials,
|
||||
MeshRasterizer,
|
||||
MeshRasterizerOpenGL,
|
||||
MeshRenderer,
|
||||
PointLights,
|
||||
PointsRasterizationSettings,
|
||||
@@ -21,18 +22,19 @@ from pytorch3d.renderer import (
|
||||
PointsRenderer,
|
||||
RasterizationSettings,
|
||||
SoftPhongShader,
|
||||
SplatterPhongShader,
|
||||
TexturesVertex,
|
||||
)
|
||||
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
|
||||
from pytorch3d.structures import Meshes, Pointclouds
|
||||
from pytorch3d.utils.ico_sphere import ico_sphere
|
||||
|
||||
from .common_testing import get_random_cuda_device, TestCaseMixin
|
||||
from .common_testing import TestCaseMixin
|
||||
|
||||
|
||||
# Set the number of GPUS you want to test with
|
||||
NUM_GPUS = 3
|
||||
GPU_LIST = list({get_random_cuda_device() for _ in range(NUM_GPUS)})
|
||||
NUM_GPUS = 2
|
||||
GPU_LIST = [f"cuda:{idx}" for idx in range(NUM_GPUS)]
|
||||
print("GPUs: %s" % ", ".join(GPU_LIST))
|
||||
|
||||
|
||||
@@ -56,12 +58,12 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(renderer.shader.materials.device, device)
|
||||
self.assertEqual(renderer.shader.materials.ambient_color.device, device)
|
||||
|
||||
def test_mesh_renderer_to(self):
|
||||
def _mesh_renderer_to(self, rasterizer_class, shader_class):
|
||||
"""
|
||||
Test moving all the tensors in the mesh renderer to a new device.
|
||||
"""
|
||||
|
||||
device1 = torch.device("cpu")
|
||||
device1 = torch.device("cuda:0")
|
||||
|
||||
R, T = look_at_view_transform(1500, 0.0, 0.0)
|
||||
|
||||
@@ -71,12 +73,12 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
|
||||
lights.location = torch.tensor([0.0, 0.0, +1000.0], device=device1)[None]
|
||||
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=256, blur_radius=0.0, faces_per_pixel=1
|
||||
image_size=128, blur_radius=0.0, faces_per_pixel=1
|
||||
)
|
||||
cameras = FoVPerspectiveCameras(
|
||||
device=device1, R=R, T=T, aspect_ratio=1.0, fov=60.0, zfar=100
|
||||
)
|
||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
rasterizer = rasterizer_class(cameras=cameras, raster_settings=raster_settings)
|
||||
|
||||
blend_params = BlendParams(
|
||||
1e-4,
|
||||
@@ -84,7 +86,7 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
|
||||
background_color=torch.zeros(3, dtype=torch.float32, device=device1),
|
||||
)
|
||||
|
||||
shader = SoftPhongShader(
|
||||
shader = shader_class(
|
||||
lights=lights,
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
@@ -107,26 +109,32 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
|
||||
# Move renderer and mesh to another device and re render
|
||||
# This also tests that background_color is correctly moved to
|
||||
# the new device
|
||||
device2 = torch.device("cuda:0")
|
||||
device2 = torch.device("cuda:1")
|
||||
renderer = renderer.to(device2)
|
||||
mesh = mesh.to(device2)
|
||||
self._check_mesh_renderer_props_on_device(renderer, device2)
|
||||
output_images = renderer(mesh)
|
||||
self.assertEqual(output_images.device, device2)
|
||||
|
||||
def test_render_meshes(self):
|
||||
def test_mesh_renderer_to(self):
|
||||
self._mesh_renderer_to(MeshRasterizer, SoftPhongShader)
|
||||
|
||||
def test_mesh_renderer_opengl_to(self):
|
||||
self._mesh_renderer_to(MeshRasterizerOpenGL, SplatterPhongShader)
|
||||
|
||||
def _render_meshes(self, rasterizer_class, shader_class):
|
||||
test = self
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, device):
|
||||
super(Model, self).__init__()
|
||||
mesh = ico_sphere(3)
|
||||
mesh = ico_sphere(3).to(device)
|
||||
self.register_buffer("faces", mesh.faces_padded())
|
||||
self.renderer = self.init_render()
|
||||
self.renderer = self.init_render(device)
|
||||
|
||||
def init_render(self):
|
||||
def init_render(self, device):
|
||||
|
||||
cameras = FoVPerspectiveCameras()
|
||||
cameras = FoVPerspectiveCameras().to(device)
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=128, blur_radius=0.0, faces_per_pixel=1
|
||||
)
|
||||
@@ -135,12 +143,12 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
|
||||
diffuse_color=((0, 0.0, 0),),
|
||||
specular_color=((0.0, 0, 0),),
|
||||
location=((0.0, 0.0, 1e5),),
|
||||
)
|
||||
).to(device)
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=MeshRasterizer(
|
||||
rasterizer=rasterizer_class(
|
||||
cameras=cameras, raster_settings=raster_settings
|
||||
),
|
||||
shader=HardGouraudShader(cameras=cameras, lights=lights),
|
||||
shader=shader_class(cameras=cameras, lights=lights),
|
||||
)
|
||||
return renderer
|
||||
|
||||
@@ -155,20 +163,25 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
|
||||
img_render = self.renderer(mesh)
|
||||
return img_render[:, :, :, :3]
|
||||
|
||||
# DataParallel requires every input tensor be provided
|
||||
# on the first device in its device_ids list.
|
||||
verts = ico_sphere(3).verts_padded()
|
||||
# Make sure we use all GPUs in GPU_LIST by making the batch size 4 x GPU count.
|
||||
verts = ico_sphere(3).verts_padded().expand(len(GPU_LIST) * 4, 642, 3)
|
||||
texs = verts.new_ones(verts.shape)
|
||||
model = Model()
|
||||
model.to(GPU_LIST[0])
|
||||
model = Model(device=GPU_LIST[0])
|
||||
model = nn.DataParallel(model, device_ids=GPU_LIST)
|
||||
|
||||
# Test a few iterations
|
||||
for _ in range(100):
|
||||
model(verts, texs)
|
||||
|
||||
def test_render_meshes(self):
|
||||
self._render_meshes(MeshRasterizer, HardGouraudShader)
|
||||
|
||||
class TestRenderPointssMultiGPU(TestCaseMixin, unittest.TestCase):
|
||||
# @unittest.skip("Multi-GPU OpenGL training is currently not supported.")
|
||||
def test_render_meshes_opengl(self):
|
||||
self._render_meshes(MeshRasterizerOpenGL, SplatterPhongShader)
|
||||
|
||||
|
||||
class TestRenderPointsMultiGPU(TestCaseMixin, unittest.TestCase):
|
||||
def _check_points_renderer_props_on_device(self, renderer, device):
|
||||
"""
|
||||
Helper function to check that all the properties have
|
||||
|
||||
Reference in New Issue
Block a user