CUDA/C++ Rasterizer updates to handle clipped faces

Summary:
- Updated the C++/CUDA mesh rasterization kernels to handle clipped faces. In particular, this required careful handling of the distance calculation for faces which are cut into a quadrilateral by the image plane and then split into two sub-triangles: both sub-triangles cannot be part of the top K faces for a pixel.
- Updated `rasterize_meshes.py` to use the utility functions to clip the meshes and to convert the fragments back so that they are in terms of the unclipped mesh (see the usage sketch below).
- Added end-to-end tests.

Reviewed By: jcjohnson

Differential Revision: D26169685

fbshipit-source-id: d64cd0d656109b965f44a35c301b7c81f451cfa0
Committed by: Facebook GitHub Bot
Parent: 838b73d3b6
Commit: 340662e98e
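
As a rough usage sketch of what this change enables (an editor's illustration, not part of the diff): setting z_clip_value on RasterizationSettings makes the rasterizer clip faces that cross the image plane and return fragments expressed in terms of the original, unclipped faces. All API names below appear in this diff; the image size, faces_per_pixel, and the camera distance of 0.5 (chosen so the image plane cuts the mesh) are placeholder values.

import torch
from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.mesh import MeshRasterizer, RasterizationSettings
from pytorch3d.utils import ico_sphere

device = torch.device("cuda:0")
meshes = ico_sphere(4, device)

# z_clip_value enables clipping of faces that cross the image plane.
settings = RasterizationSettings(
    image_size=128,
    blur_radius=1e-6,
    faces_per_pixel=8,
    z_clip_value=1e-2,
    perspective_correct=True,
    cull_to_frustum=True,
)

# Place the camera close enough that the image plane intersects the mesh.
R, T = look_at_view_transform(0.5, 0, 0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
rasterizer = MeshRasterizer(raster_settings=settings, cameras=cameras)
fragments = rasterizer(meshes)

# pix_to_face indexes the unclipped mesh's faces, even for faces that were
# internally split into two sub-triangles by the image plane.
assert fragments.pix_to_face.max() < meshes.faces_packed().shape[0]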
@@ -66,7 +66,7 @@ def bm_rasterize_meshes() -> None:
     # Square and non square cases
     image_size = [64, 128, 512, (512, 256), (256, 512)]
     blur = [1e-6]
-    faces_per_pixel = [50]
+    faces_per_pixel = [40]
     test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)

     for case in test_cases:
@@ -87,6 +87,35 @@ def bm_rasterize_meshes() -> None:
             warmup_iters=1,
         )

+    # Test a subset of the cases with the
+    # image plane intersecting the mesh.
+    kwargs_list = []
+    num_meshes = [8, 16]
+    # Square and non square cases
+    image_size = [64, 128, 512, (512, 256), (256, 512)]
+    dist = [3, 0.8, 0.5]
+    test_cases = product(num_meshes, dist, image_size)
+
+    for case in test_cases:
+        n, d, im = case
+        kwargs_list.append(
+            {
+                "num_meshes": n,
+                "ico_level": 4,
+                "image_size": im,
+                "blur_radius": 1e-6,
+                "faces_per_pixel": 40,
+                "dist": d,
+            }
+        )
+
+    benchmark(
+        TestRasterizeMeshes.bm_rasterize_meshes_with_clipping,
+        "RASTERIZE_MESHES_CUDA_CLIPPING",
+        kwargs_list,
+        warmup_iters=1,
+    )
+

 if __name__ == "__main__":
     bm_rasterize_meshes()
BIN tests/data/room.jpg (new binary file, 7.8 KiB; not shown)
BIN tests/data/test_render_mesh_clipped_cam_dist=0.5.jpg (new binary file, 7.2 KiB; not shown)
@@ -6,6 +6,8 @@ import unittest
 import torch
 from common_testing import TestCaseMixin, get_random_cuda_device
 from pytorch3d import _C
+from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform
+from pytorch3d.renderer.mesh import MeshRasterizer, RasterizationSettings
 from pytorch3d.renderer.mesh.rasterize_meshes import (
     rasterize_meshes,
     rasterize_meshes_python,
@@ -1204,3 +1206,50 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
             torch.cuda.synchronize(device)

         return rasterize
+
+    @staticmethod
+    def bm_rasterize_meshes_with_clipping(
+        num_meshes: int,
+        ico_level: int,
+        image_size: int,
+        blur_radius: float,
+        faces_per_pixel: int,
+        dist: float,
+    ):
+        device = get_random_cuda_device()
+        meshes = ico_sphere(ico_level, device)
+        meshes_batch = meshes.extend(num_meshes)
+
+        settings = RasterizationSettings(
+            image_size=image_size,
+            blur_radius=blur_radius,
+            faces_per_pixel=faces_per_pixel,
+            z_clip_value=1e-2,
+            perspective_correct=True,
+            cull_to_frustum=True,
+        )
+
+        # The camera is positioned so that the image plane intersects
+        # the mesh and some faces are partially behind the image plane.
+        R, T = look_at_view_transform(dist, 0, 0)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
+        rasterizer = MeshRasterizer(raster_settings=settings, cameras=cameras)
+
+        # Transform the meshes to project them onto the image plane
+        meshes_screen = rasterizer.transform(meshes_batch)
+        torch.cuda.synchronize(device)
+
+        def rasterize():
+            # Only measure rasterization speed (including clipping)
+            rasterize_meshes(
+                meshes_screen,
+                image_size,
+                blur_radius,
+                faces_per_pixel,
+                z_clip_value=1e-2,
+                perspective_correct=True,
+                cull_to_frustum=True,
+            )
+            torch.cuda.synchronize(device)
+
+        return rasterize
@@ -245,6 +245,7 @@ class TestRenderImplicit(TestCaseMixin, unittest.TestCase):
                     image_size=image_size,
                     blur_radius=1e-3,
                     faces_per_pixel=10,
+                    z_clip_value=None,
                     perspective_correct=False,
                 ),
             ),
@@ -994,6 +994,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
             blur_radius=0.0,
             faces_per_pixel=1,
             cull_backfaces=True,
+            perspective_correct=False,
         )

         # Init shader settings
@@ -9,14 +9,161 @@ See pytorch3d/renderer/mesh/clip.py for more details about the
 clipping process.
 """
 import unittest
+from pathlib import Path

+import imageio
+import numpy as np
 import torch
-from common_testing import TestCaseMixin
-from pytorch3d.renderer.mesh import ClipFrustum, clip_faces
+from common_testing import TestCaseMixin, load_rgb_image
+from pytorch3d.io import save_obj
+from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
+from pytorch3d.renderer.lighting import PointLights
+from pytorch3d.renderer.mesh import (
+    ClipFrustum,
+    TexturesUV,
+    clip_faces,
+    convert_clipped_rasterization_to_original_faces,
+)
+from pytorch3d.renderer.mesh.rasterize_meshes import _RasterizeFaceVerts
+from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
+from pytorch3d.renderer.mesh.renderer import MeshRenderer
+from pytorch3d.renderer.mesh.shader import SoftPhongShader
 from pytorch3d.structures.meshes import Meshes


+# If DEBUG=True, save out images generated in the tests for debugging.
+# All saved images have prefix DEBUG_
+DEBUG = False
+DATA_DIR = Path(__file__).resolve().parent / "data"
+
+
 class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
+    def load_cube_mesh_with_texture(self, device="cpu", with_grad: bool = False):
+        verts = torch.tensor(
+            [
+                [-1, 1, 1],
+                [1, 1, 1],
+                [1, -1, 1],
+                [-1, -1, 1],
+                [-1, 1, -1],
+                [1, 1, -1],
+                [1, -1, -1],
+                [-1, -1, -1],
+            ],
+            device=device,
+            dtype=torch.float32,
+            requires_grad=with_grad,
+        )
+
+        # all faces correctly wound
+        faces = torch.tensor(
+            [
+                [0, 1, 4],
+                [4, 1, 5],
+                [1, 2, 5],
+                [5, 2, 6],
+                [2, 7, 6],
+                [2, 3, 7],
+                [3, 4, 7],
+                [0, 4, 3],
+                [4, 5, 6],
+                [4, 6, 7],
+            ],
+            device=device,
+            dtype=torch.int64,
+        )
+
+        verts_uvs = torch.tensor(
+            [
+                [
+                    [0, 1],
+                    [1, 1],
+                    [1, 0],
+                    [0, 0],
+                    [0.204, 0.743],
+                    [0.781, 0.743],
+                    [0.781, 0.154],
+                    [0.204, 0.154],
+                ]
+            ],
+            device=device,
+            dtype=torch.float,
+        )
+        texture_map = load_rgb_image("room.jpg", DATA_DIR).to(device)
+        textures = TexturesUV(
+            maps=[texture_map], faces_uvs=faces.unsqueeze(0), verts_uvs=verts_uvs
+        )
+        mesh = Meshes([verts], [faces], textures=textures)
+        if with_grad:
+            return mesh, verts
+        return mesh
+
+    def test_cube_mesh_render(self):
+        """
+        End-to-end test of rendering a cube mesh with texture
+        from decreasing camera distances. The camera starts
+        outside the cube and enters the inside of the cube.
+        """
+        device = torch.device("cuda:0")
+        mesh = self.load_cube_mesh_with_texture(device)
+        raster_settings = RasterizationSettings(
+            image_size=512,
+            blur_radius=1e-8,
+            faces_per_pixel=5,
+            z_clip_value=1e-2,
+            perspective_correct=True,
+            bin_size=0,
+        )
+
+        # Only ambient, no diffuse or specular
+        lights = PointLights(
+            device=device,
+            ambient_color=((1.0, 1.0, 1.0),),
+            diffuse_color=((0.0, 0.0, 0.0),),
+            specular_color=((0.0, 0.0, 0.0),),
+            location=[[0.0, 0.0, -3.0]],
+        )
+
+        renderer = MeshRenderer(
+            rasterizer=MeshRasterizer(raster_settings=raster_settings),
+            shader=SoftPhongShader(device=device, lights=lights),
+        )
+
+        # Render the cube by decreasing the distance from the camera until
+        # the camera enters the cube. Check the output looks correct.
+        images_list = []
+        dists = np.linspace(0.1, 2.5, 20)[::-1]
+        for d in dists:
+            R, T = look_at_view_transform(d, 0, 0)
+            T[0, 1] -= 0.1  # move down in the y axis
+            cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
+            images = renderer(mesh, cameras=cameras)
+            rgb = images[0, ..., :3].cpu().detach()
+            filename = "DEBUG_cube_dist=%.1f.jpg" % d
+            im = (rgb.numpy() * 255).astype(np.uint8)
+            images_list.append(im)
+
+            # Check one of the images where the camera is inside the mesh
+            if d == 0.5:
+                filename = "test_render_mesh_clipped_cam_dist=0.5.jpg"
+                image_ref = load_rgb_image(filename, DATA_DIR)
+                self.assertClose(rgb, image_ref, atol=0.05)
+
+        # Save a gif of the output - this should show
+        # the camera moving inside the cube.
+        if DEBUG:
+            gif_filename = (
+                "room_original.gif"
+                if raster_settings.z_clip_value is None
+                else "room_clipped.gif"
+            )
+            imageio.mimsave(DATA_DIR / gif_filename, images_list, fps=2)
+            save_obj(
+                f=DATA_DIR / "cube.obj",
+                verts=mesh.verts_packed().cpu(),
+                faces=mesh.faces_packed().cpu(),
+            )
+
     @staticmethod
     def clip_faces(meshes):
         verts_packed = meshes.verts_packed()
@@ -42,6 +189,34 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
         )
         return clipped_faces

+    def test_grad(self):
+        """
+        Check that gradient flow is unaffected when the camera is inside the mesh
+        """
+        device = torch.device("cuda:0")
+        mesh, verts = self.load_cube_mesh_with_texture(device=device, with_grad=True)
+        raster_settings = RasterizationSettings(
+            image_size=512,
+            blur_radius=1e-5,
+            faces_per_pixel=5,
+            z_clip_value=1e-2,
+            perspective_correct=True,
+            bin_size=0,
+        )
+
+        renderer = MeshRenderer(
+            rasterizer=MeshRasterizer(raster_settings=raster_settings),
+            shader=SoftPhongShader(device=device),
+        )
+        dist = 0.4  # Camera is inside the cube
+        R, T = look_at_view_transform(dist, 0, 0)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
+        images = renderer(mesh, cameras=cameras)
+        images.sum().backward()
+
+        # Check gradients exist
+        self.assertIsNotNone(verts.grad)
+
     def test_case_1(self):
         """
         Case 1: Single triangle fully in front of the image plane (z=0)
@@ -350,3 +525,134 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
         # barycentric conversion matrix.
         bary_idx = idx.new_tensor([1, 4, 2, 5, 0, -1, 3, 6])
         self.assertClose(clipped_faces.faces_clipped_to_conversion_idx, bary_idx)
+
+    def test_convert_clipped_to_unclipped_case_4(self):
+        """
+        Test with a single case 4 triangle which is clipped into
+        a quadrilateral and subdivided.
+        """
+        device = "cuda:0"
+        # fmt: off
+        verts = torch.tensor(
+            [
+                [-1.0,  0.0, -1.0],  # noqa: E241, E201
+                [ 0.0,  1.0, -1.0],  # noqa: E241, E201
+                [ 1.0,  0.0, -1.0],  # noqa: E241, E201
+                [ 0.0, -1.0, -1.0],  # noqa: E241, E201
+                [-1.0,  0.5,  0.5],  # noqa: E241, E201
+                [ 1.0,  1.0,  1.0],  # noqa: E241, E201
+                [ 0.0, -1.0,  1.0],  # noqa: E241, E201
+                [-1.0,  0.5, -0.5],  # noqa: E241, E201
+                [ 1.0,  1.0, -1.0],  # noqa: E241, E201
+                [-1.0,  0.0,  1.0],  # noqa: E241, E201
+                [ 0.0,  1.0,  1.0],  # noqa: E241, E201
+                [ 1.0,  0.0,  1.0],  # noqa: E241, E201
+            ],
+            dtype=torch.float32,
+            device=device,
+        )
+        faces = torch.tensor(
+            [
+                [0, 1, 2],    # noqa: E241, E201  Case 2 fully clipped
+                [3, 4, 5],    # noqa: E241, E201  Case 4 clipped and subdivided
+                [5, 4, 3],    # noqa: E241, E201  Repeat of Case 4
+                [6, 7, 8],    # noqa: E241, E201  Case 3 clipped
+                [9, 10, 11],  # noqa: E241, E201  Case 1 untouched
+            ],
+            dtype=torch.int64,
+            device=device,
+        )
+        # fmt: on
+        meshes = Meshes(verts=[verts], faces=[faces])
+
+        # Clip meshes
+        clipped_faces = self.clip_faces(meshes)
+
+        # 4x faces (from Case 4) + 1 (from Case 3) + 1 (from Case 1)
+        self.assertEqual(clipped_faces.face_verts.shape[0], 6)
+
+        image_size = (10, 10)
+        blur_radius = 0.05
+        faces_per_pixel = 2
+        perspective_correct = True
+        bin_size = 0
+        max_faces_per_bin = 20
+        clip_barycentric_coords = False
+        cull_backfaces = False
+
+        # Rasterize clipped mesh
+        pix_to_face, zbuf, barycentric_coords, dists = _RasterizeFaceVerts.apply(
+            clipped_faces.face_verts,
+            clipped_faces.mesh_to_face_first_idx,
+            clipped_faces.num_faces_per_mesh,
+            clipped_faces.clipped_faces_neighbor_idx,
+            image_size,
+            blur_radius,
+            faces_per_pixel,
+            bin_size,
+            max_faces_per_bin,
+            perspective_correct,
+            clip_barycentric_coords,
+            cull_backfaces,
+        )
+
+        # Convert outputs so they are in terms of the unclipped mesh.
+        outputs = convert_clipped_rasterization_to_original_faces(
+            pix_to_face,
+            barycentric_coords,
+            clipped_faces,
+        )
+        pix_to_face_unclipped, barycentric_coords_unclipped = outputs
+
+        # In the clipped mesh there are more faces than in the unclipped mesh
+        self.assertTrue(pix_to_face.max() > pix_to_face_unclipped.max())
+        # Unclipped pix_to_face indices must be within the limit of the number
+        # of faces in the unclipped mesh.
+        self.assertTrue(pix_to_face_unclipped.max() < faces.shape[0])
+
+    def test_case_4_no_duplicates(self):
+        """
+        In the case of a simple mesh with one face that is cut by the image
+        plane into a quadrilateral, there shouldn't be duplicate indices of
+        the face in the pix_to_face output of rasterization.
+        """
+        for (device, bin_size) in [("cpu", 0), ("cuda:0", 0), ("cuda:0", None)]:
+            verts = torch.tensor(
+                [[0.0, -10.0, 1.0], [-1.0, 2.0, -2.0], [1.0, 5.0, -10.0]],
+                dtype=torch.float32,
+                device=device,
+            )
+            faces = torch.tensor(
+                [
+                    [0, 1, 2],
+                ],
+                dtype=torch.int64,
+                device=device,
+            )
+            meshes = Meshes(verts=[verts], faces=[faces])
+            k = 3
+            settings = RasterizationSettings(
+                image_size=10,
+                blur_radius=0.05,
+                faces_per_pixel=k,
+                z_clip_value=1e-2,
+                perspective_correct=True,
+                cull_to_frustum=True,
+                bin_size=bin_size,
+            )
+
+            # The camera is positioned so that the image plane cuts
+            # the mesh face into a quadrilateral.
+            R, T = look_at_view_transform(0.2, 0, 0)
+            cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
+            rasterizer = MeshRasterizer(raster_settings=settings, cameras=cameras)
+            fragments = rasterizer(meshes)
+
+            p2f = fragments.pix_to_face.reshape(-1, k)
+            unique_vals, idx_counts = p2f.unique(dim=0, return_counts=True)
+            # There is only one face in this mesh so if it hits a pixel
+            # it can only be at position k = 0.
+            # For any pixel, the values [0, 0, -1] for the top K faces cannot be possible.
+            double_hit = torch.tensor([0, 0, -1], device=device)
+            check_double_hit = any(torch.allclose(i, double_hit) for i in unique_vals)
+            self.assertFalse(check_double_hit)
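
The invariant checked by test_case_4_no_duplicates generalizes: after the clipped-to-unclipped conversion, no pixel's top-K list should contain the same original face index twice. A minimal standalone checker (an editor's sketch; the helper name is hypothetical, and all it assumes is a pix_to_face tensor of shape (N, H, W, K) with -1 for background):

import torch

def has_duplicate_faces(pix_to_face: torch.Tensor) -> bool:
    # Flatten to (num_pixels, K): one row of top-K face indices per pixel.
    k = pix_to_face.shape[-1]
    p2f = pix_to_face.reshape(-1, k)
    # Sorting each row makes duplicate indices adjacent; compare neighbors,
    # ignoring background entries (-1).
    sorted_p2f, _ = p2f.sort(dim=1)
    dup = (sorted_p2f[:, 1:] == sorted_p2f[:, :-1]) & (sorted_p2f[:, 1:] >= 0)
    return bool(dup.any())

With the kernel changes in this diff, has_duplicate_faces(fragments.pix_to_face) should return False even when a face is split into two sub-triangles by the image plane.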