CUDA/C++ Rasterizer updates to handle clipped faces

Summary:
- Updated the C++/CUDA mesh rasterization kernels to handle the clipped faces. In particular this required careful handling of the distance calculation for faces which are cut into a quadrilateral by the image plane and then split into two sub triangles i.e. both sub triangles can't be part of the top K faces.
- Updated `rasterize_meshes.py` to use the utils functions to clip the meshes and convert the fragments back to in terms of the unclipped mesh
- Added end to end tests

Reviewed By: jcjohnson

Differential Revision: D26169685

fbshipit-source-id: d64cd0d656109b965f44a35c301b7c81f451cfa0
This commit is contained in:
Nikhila Ravi
2021-02-08 14:30:55 -08:00
committed by Facebook GitHub Bot
parent 838b73d3b6
commit 340662e98e
12 changed files with 733 additions and 46 deletions

View File

@@ -66,7 +66,7 @@ def bm_rasterize_meshes() -> None:
# Square and non square cases
image_size = [64, 128, 512, (512, 256), (256, 512)]
blur = [1e-6]
faces_per_pixel = [50]
faces_per_pixel = [40]
test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)
for case in test_cases:
@@ -87,6 +87,35 @@ def bm_rasterize_meshes() -> None:
warmup_iters=1,
)
# Test a subset of the cases with the
# image plane intersecting the mesh.
kwargs_list = []
num_meshes = [8, 16]
# Square and non square cases
image_size = [64, 128, 512, (512, 256), (256, 512)]
dist = [3, 0.8, 0.5]
test_cases = product(num_meshes, dist, image_size)
for case in test_cases:
n, d, im = case
kwargs_list.append(
{
"num_meshes": n,
"ico_level": 4,
"image_size": im,
"blur_radius": 1e-6,
"faces_per_pixel": 40,
"dist": d,
}
)
benchmark(
TestRasterizeMeshes.bm_rasterize_meshes_with_clipping,
"RASTERIZE_MESHES_CUDA_CLIPPING",
kwargs_list,
warmup_iters=1,
)
if __name__ == "__main__":
bm_rasterize_meshes()

BIN
tests/data/room.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.2 KiB

View File

@@ -6,6 +6,8 @@ import unittest
import torch
from common_testing import TestCaseMixin, get_random_cuda_device
from pytorch3d import _C
from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.mesh import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.mesh.rasterize_meshes import (
rasterize_meshes,
rasterize_meshes_python,
@@ -1204,3 +1206,50 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
torch.cuda.synchronize(device)
return rasterize
@staticmethod
def bm_rasterize_meshes_with_clipping(
num_meshes: int,
ico_level: int,
image_size: int,
blur_radius: float,
faces_per_pixel: int,
dist: float,
):
device = get_random_cuda_device()
meshes = ico_sphere(ico_level, device)
meshes_batch = meshes.extend(num_meshes)
settings = RasterizationSettings(
image_size=image_size,
blur_radius=blur_radius,
faces_per_pixel=faces_per_pixel,
z_clip_value=1e-2,
perspective_correct=True,
cull_to_frustum=True,
)
# The camera is positioned so that the image plane intersects
# the mesh and some faces are partially behind the image plane.
R, T = look_at_view_transform(dist, 0, 0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
rasterizer = MeshRasterizer(raster_settings=settings, cameras=cameras)
# Transform the meshes to projec them onto the image plane
meshes_screen = rasterizer.transform(meshes_batch)
torch.cuda.synchronize(device)
def rasterize():
# Only measure rasterization speed (including clipping)
rasterize_meshes(
meshes_screen,
image_size,
blur_radius,
faces_per_pixel,
z_clip_value=1e-2,
perspective_correct=True,
cull_to_frustum=True,
)
torch.cuda.synchronize(device)
return rasterize

View File

@@ -245,6 +245,7 @@ class TestRenderImplicit(TestCaseMixin, unittest.TestCase):
image_size=image_size,
blur_radius=1e-3,
faces_per_pixel=10,
z_clip_value=None,
perspective_correct=False,
),
),

View File

@@ -994,6 +994,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
blur_radius=0.0,
faces_per_pixel=1,
cull_backfaces=True,
perspective_correct=False,
)
# Init shader settings

View File

@@ -9,14 +9,161 @@ See pytorch3d/renderer/mesh/clip.py for more details about the
clipping process.
"""
import unittest
from pathlib import Path
import imageio
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer.mesh import ClipFrustum, clip_faces
from common_testing import TestCaseMixin, load_rgb_image
from pytorch3d.io import save_obj
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.mesh import (
ClipFrustum,
TexturesUV,
clip_faces,
convert_clipped_rasterization_to_original_faces,
)
from pytorch3d.renderer.mesh.rasterize_meshes import _RasterizeFaceVerts
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import SoftPhongShader
from pytorch3d.structures.meshes import Meshes
# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"
class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
def load_cube_mesh_with_texture(self, device="cpu", with_grad: bool = False):
verts = torch.tensor(
[
[-1, 1, 1],
[1, 1, 1],
[1, -1, 1],
[-1, -1, 1],
[-1, 1, -1],
[1, 1, -1],
[1, -1, -1],
[-1, -1, -1],
],
device=device,
dtype=torch.float32,
requires_grad=with_grad,
)
# all faces correctly wound
faces = torch.tensor(
[
[0, 1, 4],
[4, 1, 5],
[1, 2, 5],
[5, 2, 6],
[2, 7, 6],
[2, 3, 7],
[3, 4, 7],
[0, 4, 3],
[4, 5, 6],
[4, 6, 7],
],
device=device,
dtype=torch.int64,
)
verts_uvs = torch.tensor(
[
[
[0, 1],
[1, 1],
[1, 0],
[0, 0],
[0.204, 0.743],
[0.781, 0.743],
[0.781, 0.154],
[0.204, 0.154],
]
],
device=device,
dtype=torch.float,
)
texture_map = load_rgb_image("room.jpg", DATA_DIR).to(device)
textures = TexturesUV(
maps=[texture_map], faces_uvs=faces.unsqueeze(0), verts_uvs=verts_uvs
)
mesh = Meshes([verts], [faces], textures=textures)
if with_grad:
return mesh, verts
return mesh
def test_cube_mesh_render(self):
"""
End-End test of rendering a cube mesh with texture
from decreasing camera distances. The camera starts
outside the cube and enters the inside of the cube.
"""
device = torch.device("cuda:0")
mesh = self.load_cube_mesh_with_texture(device)
raster_settings = RasterizationSettings(
image_size=512,
blur_radius=1e-8,
faces_per_pixel=5,
z_clip_value=1e-2,
perspective_correct=True,
bin_size=0,
)
# Only ambient, no diffuse or specular
lights = PointLights(
device=device,
ambient_color=((1.0, 1.0, 1.0),),
diffuse_color=((0.0, 0.0, 0.0),),
specular_color=((0.0, 0.0, 0.0),),
location=[[0.0, 0.0, -3.0]],
)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(raster_settings=raster_settings),
shader=SoftPhongShader(device=device, lights=lights),
)
# Render the cube by decreasing the distance from the camera until
# the camera enters the cube. Check the output looks correct.
images_list = []
dists = np.linspace(0.1, 2.5, 20)[::-1]
for d in dists:
R, T = look_at_view_transform(d, 0, 0)
T[0, 1] -= 0.1 # move down in the y axis
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
images = renderer(mesh, cameras=cameras)
rgb = images[0, ..., :3].cpu().detach()
filename = "DEBUG_cube_dist=%.1f.jpg" % d
im = (rgb.numpy() * 255).astype(np.uint8)
images_list.append(im)
# Check one of the images where the camera is inside the mesh
if d == 0.5:
filename = "test_render_mesh_clipped_cam_dist=0.5.jpg"
image_ref = load_rgb_image(filename, DATA_DIR)
self.assertClose(rgb, image_ref, atol=0.05)
# Save a gif of the output - this should show
# the camera moving inside the cube.
if DEBUG:
gif_filename = (
"room_original.gif"
if raster_settings.z_clip_value is None
else "room_clipped.gif"
)
imageio.mimsave(DATA_DIR / gif_filename, images_list, fps=2)
save_obj(
f=DATA_DIR / "cube.obj",
verts=mesh.verts_packed().cpu(),
faces=mesh.faces_packed().cpu(),
)
@staticmethod
def clip_faces(meshes):
verts_packed = meshes.verts_packed()
@@ -42,6 +189,34 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
)
return clipped_faces
def test_grad(self):
"""
Check that gradient flow is unaffected when the camera is inside the mesh
"""
device = torch.device("cuda:0")
mesh, verts = self.load_cube_mesh_with_texture(device=device, with_grad=True)
raster_settings = RasterizationSettings(
image_size=512,
blur_radius=1e-5,
faces_per_pixel=5,
z_clip_value=1e-2,
perspective_correct=True,
bin_size=0,
)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(raster_settings=raster_settings),
shader=SoftPhongShader(device=device),
)
dist = 0.4 # Camera is inside the cube
R, T = look_at_view_transform(dist, 0, 0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
images = renderer(mesh, cameras=cameras)
images.sum().backward()
# Check gradients exist
self.assertIsNotNone(verts.grad)
def test_case_1(self):
"""
Case 1: Single triangle fully in front of the image plane (z=0)
@@ -350,3 +525,134 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
# barycentric conversion matrix.
bary_idx = idx.new_tensor([1, 4, 2, 5, 0, -1, 3, 6])
self.assertClose(clipped_faces.faces_clipped_to_conversion_idx, bary_idx)
def test_convert_clipped_to_unclipped_case_4(self):
"""
Test with a single case 4 triangle which is clipped into
a quadrilateral and subdivided.
"""
device = "cuda:0"
# fmt: off
verts = torch.tensor(
[
[-1.0, 0.0, -1.0], # noqa: E241, E201
[ 0.0, 1.0, -1.0], # noqa: E241, E201
[ 1.0, 0.0, -1.0], # noqa: E241, E201
[ 0.0, -1.0, -1.0], # noqa: E241, E201
[-1.0, 0.5, 0.5], # noqa: E241, E201
[ 1.0, 1.0, 1.0], # noqa: E241, E201
[ 0.0, -1.0, 1.0], # noqa: E241, E201
[-1.0, 0.5, -0.5], # noqa: E241, E201
[ 1.0, 1.0, -1.0], # noqa: E241, E201
[-1.0, 0.0, 1.0], # noqa: E241, E201
[ 0.0, 1.0, 1.0], # noqa: E241, E201
[ 1.0, 0.0, 1.0], # noqa: E241, E201
],
dtype=torch.float32,
device=device,
)
faces = torch.tensor(
[
[0, 1, 2], # noqa: E241, E201 Case 2 fully clipped
[3, 4, 5], # noqa: E241, E201 Case 4 clipped and subdivided
[5, 4, 3], # noqa: E241, E201 Repeat of Case 4
[6, 7, 8], # noqa: E241, E201 Case 3 clipped
[9, 10, 11], # noqa: E241, E201 Case 1 untouched
],
dtype=torch.int64,
device=device,
)
# fmt: on
meshes = Meshes(verts=[verts], faces=[faces])
# Clip meshes
clipped_faces = self.clip_faces(meshes)
# 4x faces (from Case 4) + 1 (from Case 3) + 1 (from Case 1)
self.assertEqual(clipped_faces.face_verts.shape[0], 6)
image_size = (10, 10)
blur_radius = 0.05
faces_per_pixel = 2
perspective_correct = True
bin_size = 0
max_faces_per_bin = 20
clip_barycentric_coords = False
cull_backfaces = False
# Rasterize clipped mesh
pix_to_face, zbuf, barycentric_coords, dists = _RasterizeFaceVerts.apply(
clipped_faces.face_verts,
clipped_faces.mesh_to_face_first_idx,
clipped_faces.num_faces_per_mesh,
clipped_faces.clipped_faces_neighbor_idx,
image_size,
blur_radius,
faces_per_pixel,
bin_size,
max_faces_per_bin,
perspective_correct,
clip_barycentric_coords,
cull_backfaces,
)
# Convert outputs so they are in terms of the unclipped mesh.
outputs = convert_clipped_rasterization_to_original_faces(
pix_to_face,
barycentric_coords,
clipped_faces,
)
pix_to_face_unclipped, barycentric_coords_unclipped = outputs
# In the clipped mesh there are more faces than in the unclipped mesh
self.assertTrue(pix_to_face.max() > pix_to_face_unclipped.max())
# Unclipped pix_to_face indices must be in the limit of the number
# of faces in the unclipped mesh.
self.assertTrue(pix_to_face_unclipped.max() < faces.shape[0])
def test_case_4_no_duplicates(self):
"""
In the case of an simple mesh with one face that is cut by the image
plane into a quadrilateral, there shouldn't be duplicates indices of
the face in the pix_to_face output of rasterization.
"""
for (device, bin_size) in [("cpu", 0), ("cuda:0", 0), ("cuda:0", None)]:
verts = torch.tensor(
[[0.0, -10.0, 1.0], [-1.0, 2.0, -2.0], [1.0, 5.0, -10.0]],
dtype=torch.float32,
device=device,
)
faces = torch.tensor(
[
[0, 1, 2],
],
dtype=torch.int64,
device=device,
)
meshes = Meshes(verts=[verts], faces=[faces])
k = 3
settings = RasterizationSettings(
image_size=10,
blur_radius=0.05,
faces_per_pixel=k,
z_clip_value=1e-2,
perspective_correct=True,
cull_to_frustum=True,
bin_size=bin_size,
)
# The camera is positioned so that the image plane cuts
# the mesh face into a quadrilateral.
R, T = look_at_view_transform(0.2, 0, 0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
rasterizer = MeshRasterizer(raster_settings=settings, cameras=cameras)
fragments = rasterizer(meshes)
p2f = fragments.pix_to_face.reshape(-1, k)
unique_vals, idx_counts = p2f.unique(dim=0, return_counts=True)
# There is only one face in this mesh so if it hits a pixel
# it can only be at position k = 0
# For any pixel, the values [0, 0, 1] for the top K faces cannot be possible
double_hit = torch.tensor([0, 0, -1], device=device)
check_double_hit = any(torch.allclose(i, double_hit) for i in unique_vals)
self.assertFalse(check_double_hit)