add align modes for cubify

Summary: Add alignment modes for cubify operation.

Reviewed By: nikhilaravi

Differential Revision: D21393199

fbshipit-source-id: 7022044e591229a6ed5efc361fd3215e65f43f86
This commit is contained in:
Georgia Gkioxari 2020-05-05 11:07:23 -07:00 committed by Facebook GitHub Bot
parent 8fc28baa27
commit a61c9376d5
2 changed files with 290 additions and 232 deletions

View File

@ -45,7 +45,7 @@ def ravel_index(idx, dims) -> torch.Tensor:
@torch.no_grad() @torch.no_grad()
def cubify(voxels, thresh, device=None) -> Meshes: def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
r""" r"""
Converts a voxel to a mesh by replacing each occupied voxel with a cube Converts a voxel to a mesh by replacing each occupied voxel with a cube
consisting of 12 faces and 8 vertices. Shared vertices are merged, and consisting of 12 faces and 8 vertices. Shared vertices are merged, and
@ -54,13 +54,38 @@ def cubify(voxels, thresh, device=None) -> Meshes:
voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities. voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
thresh: A scalar threshold. If a voxel occupancy is larger than thresh: A scalar threshold. If a voxel occupancy is larger than
thresh, the voxel is considered occupied. thresh, the voxel is considered occupied.
device: The device of the output meshes
align: Defines the alignment of the mesh vertices and the grid locations.
Has to be one of {"topleft", "corner", "center"}. See below for explanation.
Default is "topleft".
Returns: Returns:
meshes: A Meshes object of the corresponding meshes. meshes: A Meshes object of the corresponding meshes.
The alignment between the vertices of the cubified mesh and the voxel locations (or pixels)
is defined by the choice of `align`. We support three modes, as shown below for a 2x2 grid:
X---X---- X-------X ---------
| | | | | | | X | X |
X---X---- --------- ---------
| | | | | | | X | X |
--------- X-------X ---------
topleft corner center
In the figure, X denote the grid locations and the squares represent the added cuboids.
When `align="topleft"`, then the top left corner of each cuboid corresponds to the
pixel coordinate of the input grid.
When `align="corner"`, then the corners of the output mesh span the whole grid.
When `align="center"`, then the grid locations form the center of the cuboids.
""" """
if device is None: if device is None:
device = voxels.device device = voxels.device
if align not in ["topleft", "corner", "center"]:
raise ValueError("Align mode must be one of (topleft, corner, center).")
if len(voxels) == 0: if len(voxels) == 0:
return Meshes(verts=[], faces=[]) return Meshes(verts=[], faces=[])
@ -146,7 +171,7 @@ def cubify(voxels, thresh, device=None) -> Meshes:
# boolean to linear index # boolean to linear index
# NF x 2 # NF x 2
linind = torch.nonzero(faces_idx) linind = torch.nonzero(faces_idx, as_tuple=False)
# NF x 4 # NF x 4
nyxz = unravel_index(linind[:, 0], (N, H, W, D)) nyxz = unravel_index(linind[:, 0], (N, H, W, D))
@ -170,11 +195,19 @@ def cubify(voxels, thresh, device=None) -> Meshes:
torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1) torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)
) )
y = y.to(device=device, dtype=torch.float32) y = y.to(device=device, dtype=torch.float32)
y = y * 2.0 / (H - 1.0) - 1.0
x = x.to(device=device, dtype=torch.float32) x = x.to(device=device, dtype=torch.float32)
x = x * 2.0 / (W - 1.0) - 1.0
z = z.to(device=device, dtype=torch.float32) z = z.to(device=device, dtype=torch.float32)
z = z * 2.0 / (D - 1.0) - 1.0
if align == "center":
x = x - 0.5
y = y - 0.5
z = z - 0.5
margin = 0.0 if align == "corner" else 1.0
y = y * 2.0 / (H - margin) - 1.0
x = x * 2.0 / (W - margin) - 1.0
z = z * 2.0 / (D - margin) - 1.0
# ((H+1)(W+1)(D+1)) x 3 # ((H+1)(W+1)(D+1)) x 3
grid_verts = torch.stack((x, y, z), dim=3).view(-1, 3) grid_verts = torch.stack((x, y, z), dim=3).view(-1, 3)
@ -196,7 +229,7 @@ def cubify(voxels, thresh, device=None) -> Meshes:
idlenum = idleverts.cumsum(1) idlenum = idleverts.cumsum(1)
verts_list = [ verts_list = [
grid_verts.index_select(0, (idleverts[n] == 0).nonzero()[:, 0]) grid_verts.index_select(0, (idleverts[n] == 0).nonzero(as_tuple=False)[:, 0])
for n in range(N) for n in range(N)
] ]
faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)] faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)]

View File

@ -3,16 +3,17 @@
import unittest import unittest
import torch import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import cubify from pytorch3d.ops import cubify
class TestCubify(unittest.TestCase): class TestCubify(TestCaseMixin, unittest.TestCase):
def test_allempty(self): def test_allempty(self):
N, V = 32, 14 N, V = 32, 14
device = torch.device("cuda:0") device = torch.device("cuda:0")
voxels = torch.zeros((N, V, V, V), dtype=torch.float32, device=device) voxels = torch.zeros((N, V, V, V), dtype=torch.float32, device=device)
meshes = cubify(voxels, 0.5, 0) meshes = cubify(voxels, 0.5)
self.assertTrue(meshes.isempty) self.assertTrue(meshes.isempty())
def test_cubify(self): def test_cubify(self):
N, V = 4, 2 N, V = 4, 2
@ -29,15 +30,12 @@ class TestCubify(unittest.TestCase):
voxels[3, 1, 1, 0] = 1.0 voxels[3, 1, 1, 0] = 1.0
# compute cubify # compute cubify
meshes = cubify(voxels, 0.5, 0) meshes = cubify(voxels, 0.5)
# 1st-check # 1st-check
verts, faces = meshes.get_mesh_verts_faces(0) verts, faces = meshes.get_mesh_verts_faces(0)
self.assertTrue( self.assertClose(faces.max().cpu(), torch.tensor(verts.size(0) - 1))
torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1])) self.assertClose(
)
self.assertTrue(
torch.allclose(
verts, verts,
torch.tensor( torch.tensor(
[ [
@ -54,9 +52,7 @@ class TestCubify(unittest.TestCase):
device=device, device=device,
), ),
) )
) self.assertClose(
self.assertTrue(
torch.allclose(
faces, faces,
torch.tensor( torch.tensor(
[ [
@ -77,14 +73,10 @@ class TestCubify(unittest.TestCase):
device=device, device=device,
), ),
) )
)
# 2nd-check # 2nd-check
verts, faces = meshes.get_mesh_verts_faces(1) verts, faces = meshes.get_mesh_verts_faces(1)
self.assertTrue( self.assertClose(faces.max().cpu(), torch.tensor(verts.size(0) - 1))
torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1])) self.assertClose(
)
self.assertTrue(
torch.allclose(
verts, verts,
torch.tensor( torch.tensor(
[ [
@ -119,9 +111,7 @@ class TestCubify(unittest.TestCase):
device=device, device=device,
), ),
) )
) self.assertClose(
self.assertTrue(
torch.allclose(
faces, faces,
torch.tensor( torch.tensor(
[ [
@ -178,7 +168,6 @@ class TestCubify(unittest.TestCase):
device=device, device=device,
), ),
) )
)
# 3rd-check # 3rd-check
verts, faces = meshes.get_mesh_verts_faces(2) verts, faces = meshes.get_mesh_verts_faces(2)
@ -187,8 +176,7 @@ class TestCubify(unittest.TestCase):
# 4th-check # 4th-check
verts, faces = meshes.get_mesh_verts_faces(3) verts, faces = meshes.get_mesh_verts_faces(3)
self.assertTrue( self.assertClose(
torch.allclose(
verts, verts,
torch.tensor( torch.tensor(
[ [
@ -219,9 +207,7 @@ class TestCubify(unittest.TestCase):
device=device, device=device,
), ),
) )
) self.assertClose(
self.assertTrue(
torch.allclose(
faces, faces,
torch.tensor( torch.tensor(
[ [
@ -270,7 +256,46 @@ class TestCubify(unittest.TestCase):
device=device, device=device,
), ),
) )
)
def test_align(self):
N, V = 1, 2
device = torch.device("cuda:0")
voxels = torch.ones((N, V, V, V), dtype=torch.float32, device=device)
# topleft align
mesh = cubify(voxels, 0.5)
verts, faces = mesh.get_mesh_verts_faces(0)
self.assertClose(verts.min(), torch.tensor(-1.0, device=device))
self.assertClose(verts.max(), torch.tensor(3.0, device=device))
# corner align
mesh = cubify(voxels, 0.5, align="corner")
verts, faces = mesh.get_mesh_verts_faces(0)
self.assertClose(verts.min(), torch.tensor(-1.0, device=device))
self.assertClose(verts.max(), torch.tensor(1.0, device=device))
# center align
mesh = cubify(voxels, 0.5, align="center")
verts, faces = mesh.get_mesh_verts_faces(0)
self.assertClose(verts.min(), torch.tensor(-2.0, device=device))
self.assertClose(verts.max(), torch.tensor(2.0, device=device))
# invalid align
with self.assertRaisesRegex(ValueError, "Align mode must be one of"):
cubify(voxels, 0.5, align="")
# invalid align
with self.assertRaisesRegex(ValueError, "Align mode must be one of"):
cubify(voxels, 0.5, align="topright")
# inside occupancy, similar to GH#185 use case
N, V = 1, 4
voxels = torch.zeros((N, V, V, V), dtype=torch.float32, device=device)
voxels[0, : V // 2, : V // 2, : V // 2] = 1.0
mesh = cubify(voxels, 0.5, align="corner")
verts, faces = mesh.get_mesh_verts_faces(0)
self.assertClose(verts.min(), torch.tensor(-1.0, device=device))
self.assertClose(verts.max(), torch.tensor(0.0, device=device))
@staticmethod @staticmethod
def cubify_with_init(batch_size: int, V: int): def cubify_with_init(batch_size: int, V: int):