mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-23 07:40:34 +08:00
Add MeshRasterizerOpenGL
Summary: Adding MeshRasterizerOpenGL, a faster alternative to MeshRasterizer. The new rasterizer follows the ideas from "Differentiable Surface Rendering via non-Differentiable Sampling". The new rasterizer 20x faster on a 2M face mesh (try pose optimization on Nefertiti from https://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/!). The larger the mesh, the larger the speedup. There are two main disadvantages: * The new rasterizer works with an OpenGL backend, so requires pycuda.gl and pyopengl installed (though we avoided writing any C++ code, everything is in Python!) * The new rasterizer is non-differentiable. However, you can still differentiate the rendering function if you use if with the new SplatterPhongShader which we recently added to PyTorch3D (see the original paper cited above). Reviewed By: patricklabatut, jcjohnson Differential Revision: D37698816 fbshipit-source-id: 54d120639d3cb001f096237807e54aced0acda25
This commit is contained in:
committed by
Facebook GitHub Bot
parent
36edf2b302
commit
cb49550486
@@ -10,16 +10,29 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
|
||||
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
|
||||
from pytorch3d.renderer.points.rasterizer import (
|
||||
from pytorch3d.renderer import (
|
||||
FoVOrthographicCameras,
|
||||
FoVPerspectiveCameras,
|
||||
look_at_view_transform,
|
||||
MeshRasterizer,
|
||||
MeshRasterizerOpenGL,
|
||||
OrthographicCameras,
|
||||
PerspectiveCameras,
|
||||
PointsRasterizationSettings,
|
||||
PointsRasterizer,
|
||||
RasterizationSettings,
|
||||
)
|
||||
from pytorch3d.renderer.opengl.rasterizer_opengl import (
|
||||
_check_cameras,
|
||||
_check_raster_settings,
|
||||
_convert_meshes_to_gl_ndc,
|
||||
_parse_and_verify_image_size,
|
||||
)
|
||||
from pytorch3d.structures import Pointclouds
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
from pytorch3d.utils.ico_sphere import ico_sphere
|
||||
|
||||
from .common_testing import get_tests_dir
|
||||
from .common_testing import get_tests_dir, TestCaseMixin
|
||||
|
||||
|
||||
DATA_DIR = get_tests_dir() / "data"
|
||||
@@ -36,8 +49,14 @@ def convert_image_to_binary_mask(filename):
|
||||
|
||||
class TestMeshRasterizer(unittest.TestCase):
|
||||
def test_simple_sphere(self):
|
||||
self._simple_sphere(MeshRasterizer)
|
||||
|
||||
def test_simple_sphere_opengl(self):
|
||||
self._simple_sphere(MeshRasterizerOpenGL)
|
||||
|
||||
def _simple_sphere(self, rasterizer_type):
|
||||
device = torch.device("cuda:0")
|
||||
ref_filename = "test_rasterized_sphere.png"
|
||||
ref_filename = f"test_rasterized_sphere_{rasterizer_type.__name__}.png"
|
||||
image_ref_filename = DATA_DIR / ref_filename
|
||||
|
||||
# Rescale image_ref to the 0 - 1 range and convert to a binary mask.
|
||||
@@ -54,7 +73,7 @@ class TestMeshRasterizer(unittest.TestCase):
|
||||
)
|
||||
|
||||
# Init rasterizer
|
||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
rasterizer = rasterizer_type(cameras=cameras, raster_settings=raster_settings)
|
||||
|
||||
####################################
|
||||
# 1. Test rasterizing a single mesh
|
||||
@@ -68,7 +87,8 @@ class TestMeshRasterizer(unittest.TestCase):
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_test_rasterized_sphere.png"
|
||||
DATA_DIR
|
||||
/ f"DEBUG_test_rasterized_sphere_{rasterizer_type.__name__}.png"
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(image, image_ref))
|
||||
@@ -90,20 +110,21 @@ class TestMeshRasterizer(unittest.TestCase):
|
||||
# 3. Test that passing kwargs to rasterizer works.
|
||||
####################################################
|
||||
|
||||
# Change the view transform to zoom in.
|
||||
R, T = look_at_view_transform(2.0, 0, 0, device=device)
|
||||
# Change the view transform to zoom out.
|
||||
R, T = look_at_view_transform(20.0, 0, 0, device=device)
|
||||
fragments = rasterizer(sphere_mesh, R=R, T=T)
|
||||
image = fragments.pix_to_face[0, ..., 0].squeeze().cpu()
|
||||
image[image >= 0] = 1.0
|
||||
image[image < 0] = 0.0
|
||||
|
||||
ref_filename = "test_rasterized_sphere_zoom.png"
|
||||
ref_filename = f"test_rasterized_sphere_zoom_{rasterizer_type.__name__}.png"
|
||||
image_ref_filename = DATA_DIR / ref_filename
|
||||
image_ref = convert_image_to_binary_mask(image_ref_filename)
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_test_rasterized_sphere_zoom.png"
|
||||
DATA_DIR
|
||||
/ f"DEBUG_test_rasterized_sphere_zoom_{rasterizer_type.__name__}.png"
|
||||
)
|
||||
self.assertTrue(torch.allclose(image, image_ref))
|
||||
|
||||
@@ -112,7 +133,7 @@ class TestMeshRasterizer(unittest.TestCase):
|
||||
##################################
|
||||
|
||||
# Create a new empty rasterizer:
|
||||
rasterizer = MeshRasterizer()
|
||||
rasterizer = rasterizer_type(raster_settings=raster_settings)
|
||||
|
||||
# Check that omitting the cameras in both initialization
|
||||
# and the forward pass throws an error:
|
||||
@@ -120,9 +141,7 @@ class TestMeshRasterizer(unittest.TestCase):
|
||||
rasterizer(sphere_mesh)
|
||||
|
||||
# Now pass in the cameras as a kwarg
|
||||
fragments = rasterizer(
|
||||
sphere_mesh, cameras=cameras, raster_settings=raster_settings
|
||||
)
|
||||
fragments = rasterizer(sphere_mesh, cameras=cameras)
|
||||
image = fragments.pix_to_face[0, ..., 0].squeeze().cpu()
|
||||
# Convert pix_to_face to a binary mask
|
||||
image[image >= 0] = 1.0
|
||||
@@ -130,7 +149,8 @@ class TestMeshRasterizer(unittest.TestCase):
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_test_rasterized_sphere.png"
|
||||
DATA_DIR
|
||||
/ f"DEBUG_test_rasterized_sphere_{rasterizer_type.__name__}.png"
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(image, image_ref))
|
||||
@@ -141,6 +161,187 @@ class TestMeshRasterizer(unittest.TestCase):
|
||||
rasterizer = MeshRasterizer()
|
||||
rasterizer.to(device)
|
||||
|
||||
rasterizer = MeshRasterizerOpenGL()
|
||||
rasterizer.to(device)
|
||||
|
||||
def test_compare_rasterizers(self):
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# Init rasterizer settings
|
||||
R, T = look_at_view_transform(2.7, 0, 0)
|
||||
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512,
|
||||
blur_radius=0.0,
|
||||
faces_per_pixel=1,
|
||||
bin_size=0,
|
||||
perspective_correct=True,
|
||||
)
|
||||
from pytorch3d.io import load_obj
|
||||
from pytorch3d.renderer import TexturesAtlas
|
||||
|
||||
from .common_testing import get_pytorch3d_dir
|
||||
|
||||
TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
|
||||
obj_filename = TUTORIAL_DATA_DIR / "cow_mesh/cow.obj"
|
||||
|
||||
# Load mesh and texture as a per face texture atlas.
|
||||
verts, faces, aux = load_obj(
|
||||
obj_filename,
|
||||
device=device,
|
||||
load_textures=True,
|
||||
create_texture_atlas=True,
|
||||
texture_atlas_size=8,
|
||||
texture_wrap=None,
|
||||
)
|
||||
atlas = aux.texture_atlas
|
||||
mesh = Meshes(
|
||||
verts=[verts],
|
||||
faces=[faces.verts_idx],
|
||||
textures=TexturesAtlas(atlas=[atlas]),
|
||||
)
|
||||
|
||||
# Rasterize using both rasterizers and compare results.
|
||||
rasterizer = MeshRasterizerOpenGL(
|
||||
cameras=cameras, raster_settings=raster_settings
|
||||
)
|
||||
fragments_opengl = rasterizer(mesh)
|
||||
|
||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
fragments = rasterizer(mesh)
|
||||
|
||||
# Ensure that 99.9% of bary_coords is at most 0.001 different.
|
||||
self.assertLess(
|
||||
torch.quantile(
|
||||
(fragments.bary_coords - fragments_opengl.bary_coords).abs(), 0.999
|
||||
),
|
||||
0.001,
|
||||
)
|
||||
|
||||
# Ensure that 99.9% of zbuf vals is at most 0.001 different.
|
||||
self.assertLess(
|
||||
torch.quantile((fragments.zbuf - fragments_opengl.zbuf).abs(), 0.999), 0.001
|
||||
)
|
||||
|
||||
# Ensure that 99.99% of pix_to_face is identical.
|
||||
self.assertEqual(
|
||||
torch.quantile(
|
||||
(fragments.pix_to_face != fragments_opengl.pix_to_face).float(), 0.9999
|
||||
),
|
||||
0,
|
||||
)
|
||||
|
||||
|
||||
class TestMeshRasterizerOpenGLUtils(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self):
|
||||
verts = torch.tensor(
|
||||
[[-1, 1, 0], [1, 1, 0], [1, -1, 0]], dtype=torch.float32
|
||||
).cuda()
|
||||
faces = torch.tensor([[0, 1, 2]]).cuda()
|
||||
self.meshes_world = Meshes(verts=[verts], faces=[faces])
|
||||
|
||||
# Test various utils specific to the OpenGL rasterizer. Full "integration tests"
|
||||
# live in test_render_meshes and test_render_multigpu.
|
||||
def test_check_cameras(self):
|
||||
_check_cameras(FoVPerspectiveCameras())
|
||||
_check_cameras(FoVPerspectiveCameras())
|
||||
with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
|
||||
_check_cameras(None)
|
||||
with self.assertRaisesRegex(ValueError, "MeshRasterizerOpenGL only works with"):
|
||||
_check_cameras(PerspectiveCameras())
|
||||
with self.assertRaisesRegex(ValueError, "MeshRasterizerOpenGL only works with"):
|
||||
_check_cameras(OrthographicCameras())
|
||||
|
||||
MeshRasterizerOpenGL(FoVPerspectiveCameras().cuda())(self.meshes_world)
|
||||
MeshRasterizerOpenGL(FoVOrthographicCameras().cuda())(self.meshes_world)
|
||||
MeshRasterizerOpenGL()(
|
||||
self.meshes_world, cameras=FoVPerspectiveCameras().cuda()
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "MeshRasterizerOpenGL only works with"):
|
||||
MeshRasterizerOpenGL(PerspectiveCameras().cuda())(self.meshes_world)
|
||||
with self.assertRaisesRegex(ValueError, "MeshRasterizerOpenGL only works with"):
|
||||
MeshRasterizerOpenGL(OrthographicCameras().cuda())(self.meshes_world)
|
||||
with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
|
||||
MeshRasterizerOpenGL()(self.meshes_world)
|
||||
|
||||
def test_check_raster_settings(self):
|
||||
raster_settings = RasterizationSettings()
|
||||
raster_settings.faces_per_pixel = 100
|
||||
with self.assertWarnsRegex(UserWarning, ".* one face per pixel"):
|
||||
_check_raster_settings(raster_settings)
|
||||
|
||||
with self.assertWarnsRegex(UserWarning, ".* one face per pixel"):
|
||||
MeshRasterizerOpenGL(raster_settings=raster_settings)(
|
||||
self.meshes_world, cameras=FoVPerspectiveCameras().cuda()
|
||||
)
|
||||
|
||||
def test_convert_meshes_to_gl_ndc_square_img(self):
|
||||
R, T = look_at_view_transform(1, 90, 180)
|
||||
cameras = FoVOrthographicCameras(R=R, T=T).cuda()
|
||||
|
||||
meshes_gl_ndc = _convert_meshes_to_gl_ndc(
|
||||
self.meshes_world, (100, 100), cameras
|
||||
)
|
||||
|
||||
# After look_at_view_transform rotating 180 deg around z-axis, we recover
|
||||
# the original coordinates. After additionally elevating the view by 90
|
||||
# deg, we "zero out" the y-coordinate. Finally, we negate the x and y axes
|
||||
# to adhere to OpenGL conventions (which go against the PyTorch3D convention).
|
||||
self.assertClose(
|
||||
meshes_gl_ndc.verts_list()[0],
|
||||
torch.tensor(
|
||||
[[-1, 0, 0], [1, 0, 0], [1, 0, 2]], dtype=torch.float32
|
||||
).cuda(),
|
||||
atol=1e-5,
|
||||
)
|
||||
|
||||
def test_parse_and_verify_image_size(self):
|
||||
img_size = _parse_and_verify_image_size(512)
|
||||
self.assertEqual(img_size, (512, 512))
|
||||
|
||||
img_size = _parse_and_verify_image_size((2047, 10))
|
||||
self.assertEqual(img_size, (2047, 10))
|
||||
|
||||
img_size = _parse_and_verify_image_size((10, 2047))
|
||||
self.assertEqual(img_size, (10, 2047))
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
|
||||
_parse_and_verify_image_size((2049, 512))
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
|
||||
_parse_and_verify_image_size((512, 2049))
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
|
||||
_parse_and_verify_image_size((2049, 2049))
|
||||
|
||||
rasterizer = MeshRasterizerOpenGL(FoVPerspectiveCameras().cuda())
|
||||
raster_settings = RasterizationSettings()
|
||||
|
||||
raster_settings.image_size = 512
|
||||
fragments = rasterizer(self.meshes_world, raster_settings=raster_settings)
|
||||
self.assertEqual(fragments.pix_to_face.shape, torch.Size([1, 512, 512, 1]))
|
||||
|
||||
raster_settings.image_size = (2047, 10)
|
||||
fragments = rasterizer(self.meshes_world, raster_settings=raster_settings)
|
||||
self.assertEqual(fragments.pix_to_face.shape, torch.Size([1, 2047, 10, 1]))
|
||||
|
||||
raster_settings.image_size = (10, 2047)
|
||||
fragments = rasterizer(self.meshes_world, raster_settings=raster_settings)
|
||||
self.assertEqual(fragments.pix_to_face.shape, torch.Size([1, 10, 2047, 1]))
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
|
||||
raster_settings.image_size = (2049, 512)
|
||||
rasterizer(self.meshes_world, raster_settings=raster_settings)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
|
||||
raster_settings.image_size = (512, 2049)
|
||||
rasterizer(self.meshes_world, raster_settings=raster_settings)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
|
||||
raster_settings.image_size = (2049, 2049)
|
||||
rasterizer(self.meshes_world, raster_settings=raster_settings)
|
||||
|
||||
|
||||
class TestPointRasterizer(unittest.TestCase):
|
||||
def test_simple_sphere(self):
|
||||
|
||||
Reference in New Issue
Block a user