add align modes for cubify

Summary: Add alignment modes for cubify operation.

Reviewed By: nikhilaravi

Differential Revision: D21393199

fbshipit-source-id: 7022044e591229a6ed5efc361fd3215e65f43f86
This commit is contained in:
Georgia Gkioxari 2020-05-05 11:07:23 -07:00 committed by Facebook GitHub Bot
parent 8fc28baa27
commit a61c9376d5
2 changed files with 290 additions and 232 deletions

View File

@ -45,7 +45,7 @@ def ravel_index(idx, dims) -> torch.Tensor:
@torch.no_grad() @torch.no_grad()
def cubify(voxels, thresh, device=None) -> Meshes: def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
r""" r"""
Converts a voxel to a mesh by replacing each occupied voxel with a cube Converts a voxel to a mesh by replacing each occupied voxel with a cube
consisting of 12 faces and 8 vertices. Shared vertices are merged, and consisting of 12 faces and 8 vertices. Shared vertices are merged, and
@ -54,13 +54,38 @@ def cubify(voxels, thresh, device=None) -> Meshes:
voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities. voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
thresh: A scalar threshold. If a voxel occupancy is larger than thresh: A scalar threshold. If a voxel occupancy is larger than
thresh, the voxel is considered occupied. thresh, the voxel is considered occupied.
device: The device of the output meshes
align: Defines the alignment of the mesh vertices and the grid locations.
Has to be one of {"topleft", "corner", "center"}. See below for explanation.
Default is "topleft".
Returns: Returns:
meshes: A Meshes object of the corresponding meshes. meshes: A Meshes object of the corresponding meshes.
The alignment between the vertices of the cubified mesh and the voxel locations (or pixels)
is defined by the choice of `align`. We support three modes, as shown below for a 2x2 grid:
X---X---- X-------X ---------
| | | | | | | X | X |
X---X---- --------- ---------
| | | | | | | X | X |
--------- X-------X ---------
topleft corner center
In the figure, X denote the grid locations and the squares represent the added cuboids.
When `align="topleft"`, then the top left corner of each cuboid corresponds to the
pixel coordinate of the input grid.
When `align="corner"`, then the corners of the output mesh span the whole grid.
When `align="center"`, then the grid locations form the center of the cuboids.
""" """
if device is None: if device is None:
device = voxels.device device = voxels.device
if align not in ["topleft", "corner", "center"]:
raise ValueError("Align mode must be one of (topleft, corner, center).")
if len(voxels) == 0: if len(voxels) == 0:
return Meshes(verts=[], faces=[]) return Meshes(verts=[], faces=[])
@ -146,7 +171,7 @@ def cubify(voxels, thresh, device=None) -> Meshes:
# boolean to linear index # boolean to linear index
# NF x 2 # NF x 2
linind = torch.nonzero(faces_idx) linind = torch.nonzero(faces_idx, as_tuple=False)
# NF x 4 # NF x 4
nyxz = unravel_index(linind[:, 0], (N, H, W, D)) nyxz = unravel_index(linind[:, 0], (N, H, W, D))
@ -170,11 +195,19 @@ def cubify(voxels, thresh, device=None) -> Meshes:
torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1) torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)
) )
y = y.to(device=device, dtype=torch.float32) y = y.to(device=device, dtype=torch.float32)
y = y * 2.0 / (H - 1.0) - 1.0
x = x.to(device=device, dtype=torch.float32) x = x.to(device=device, dtype=torch.float32)
x = x * 2.0 / (W - 1.0) - 1.0
z = z.to(device=device, dtype=torch.float32) z = z.to(device=device, dtype=torch.float32)
z = z * 2.0 / (D - 1.0) - 1.0
if align == "center":
x = x - 0.5
y = y - 0.5
z = z - 0.5
margin = 0.0 if align == "corner" else 1.0
y = y * 2.0 / (H - margin) - 1.0
x = x * 2.0 / (W - margin) - 1.0
z = z * 2.0 / (D - margin) - 1.0
# ((H+1)(W+1)(D+1)) x 3 # ((H+1)(W+1)(D+1)) x 3
grid_verts = torch.stack((x, y, z), dim=3).view(-1, 3) grid_verts = torch.stack((x, y, z), dim=3).view(-1, 3)
@ -196,7 +229,7 @@ def cubify(voxels, thresh, device=None) -> Meshes:
idlenum = idleverts.cumsum(1) idlenum = idleverts.cumsum(1)
verts_list = [ verts_list = [
grid_verts.index_select(0, (idleverts[n] == 0).nonzero()[:, 0]) grid_verts.index_select(0, (idleverts[n] == 0).nonzero(as_tuple=False)[:, 0])
for n in range(N) for n in range(N)
] ]
faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)] faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)]

View File

@ -3,16 +3,17 @@
import unittest import unittest
import torch import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import cubify from pytorch3d.ops import cubify
class TestCubify(unittest.TestCase): class TestCubify(TestCaseMixin, unittest.TestCase):
def test_allempty(self): def test_allempty(self):
N, V = 32, 14 N, V = 32, 14
device = torch.device("cuda:0") device = torch.device("cuda:0")
voxels = torch.zeros((N, V, V, V), dtype=torch.float32, device=device) voxels = torch.zeros((N, V, V, V), dtype=torch.float32, device=device)
meshes = cubify(voxels, 0.5, 0) meshes = cubify(voxels, 0.5)
self.assertTrue(meshes.isempty) self.assertTrue(meshes.isempty())
def test_cubify(self): def test_cubify(self):
N, V = 4, 2 N, V = 4, 2
@ -29,155 +30,143 @@ class TestCubify(unittest.TestCase):
voxels[3, 1, 1, 0] = 1.0 voxels[3, 1, 1, 0] = 1.0
# compute cubify # compute cubify
meshes = cubify(voxels, 0.5, 0) meshes = cubify(voxels, 0.5)
# 1st-check # 1st-check
verts, faces = meshes.get_mesh_verts_faces(0) verts, faces = meshes.get_mesh_verts_faces(0)
self.assertTrue( self.assertClose(faces.max().cpu(), torch.tensor(verts.size(0) - 1))
torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1])) self.assertClose(
verts,
torch.tensor(
[
[-1.0, -1.0, -1.0],
[-1.0, -1.0, 1.0],
[1.0, -1.0, -1.0],
[1.0, -1.0, 1.0],
[-1.0, 1.0, -1.0],
[-1.0, 1.0, 1.0],
[1.0, 1.0, -1.0],
[1.0, 1.0, 1.0],
],
dtype=torch.float32,
device=device,
),
) )
self.assertTrue( self.assertClose(
torch.allclose( faces,
verts, torch.tensor(
torch.tensor( [
[ [0, 1, 4],
[-1.0, -1.0, -1.0], [1, 5, 4],
[-1.0, -1.0, 1.0], [4, 5, 6],
[1.0, -1.0, -1.0], [5, 7, 6],
[1.0, -1.0, 1.0], [0, 4, 6],
[-1.0, 1.0, -1.0], [0, 6, 2],
[-1.0, 1.0, 1.0], [0, 3, 1],
[1.0, 1.0, -1.0], [0, 2, 3],
[1.0, 1.0, 1.0], [6, 7, 3],
], [6, 3, 2],
dtype=torch.float32, [1, 7, 5],
device=device, [1, 3, 7],
), ],
) dtype=torch.int64,
) device=device,
self.assertTrue( ),
torch.allclose(
faces,
torch.tensor(
[
[0, 1, 4],
[1, 5, 4],
[4, 5, 6],
[5, 7, 6],
[0, 4, 6],
[0, 6, 2],
[0, 3, 1],
[0, 2, 3],
[6, 7, 3],
[6, 3, 2],
[1, 7, 5],
[1, 3, 7],
],
dtype=torch.int64,
device=device,
),
)
) )
# 2nd-check # 2nd-check
verts, faces = meshes.get_mesh_verts_faces(1) verts, faces = meshes.get_mesh_verts_faces(1)
self.assertTrue( self.assertClose(faces.max().cpu(), torch.tensor(verts.size(0) - 1))
torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1])) self.assertClose(
verts,
torch.tensor(
[
[-1.0, -1.0, -1.0],
[-1.0, -1.0, 1.0],
[-1.0, -1.0, 3.0],
[1.0, -1.0, -1.0],
[1.0, -1.0, 1.0],
[1.0, -1.0, 3.0],
[3.0, -1.0, -1.0],
[3.0, -1.0, 1.0],
[3.0, -1.0, 3.0],
[-1.0, 1.0, -1.0],
[-1.0, 1.0, 1.0],
[-1.0, 1.0, 3.0],
[1.0, 1.0, -1.0],
[1.0, 1.0, 3.0],
[3.0, 1.0, -1.0],
[3.0, 1.0, 1.0],
[3.0, 1.0, 3.0],
[-1.0, 3.0, -1.0],
[-1.0, 3.0, 1.0],
[-1.0, 3.0, 3.0],
[1.0, 3.0, -1.0],
[1.0, 3.0, 1.0],
[1.0, 3.0, 3.0],
[3.0, 3.0, -1.0],
[3.0, 3.0, 1.0],
[3.0, 3.0, 3.0],
],
dtype=torch.float32,
device=device,
),
) )
self.assertTrue( self.assertClose(
torch.allclose( faces,
verts, torch.tensor(
torch.tensor( [
[ [0, 1, 9],
[-1.0, -1.0, -1.0], [1, 10, 9],
[-1.0, -1.0, 1.0], [0, 9, 12],
[-1.0, -1.0, 3.0], [0, 12, 3],
[1.0, -1.0, -1.0], [0, 4, 1],
[1.0, -1.0, 1.0], [0, 3, 4],
[1.0, -1.0, 3.0], [1, 2, 10],
[3.0, -1.0, -1.0], [2, 11, 10],
[3.0, -1.0, 1.0], [1, 5, 2],
[3.0, -1.0, 3.0], [1, 4, 5],
[-1.0, 1.0, -1.0], [2, 13, 11],
[-1.0, 1.0, 1.0], [2, 5, 13],
[-1.0, 1.0, 3.0], [3, 12, 14],
[1.0, 1.0, -1.0], [3, 14, 6],
[1.0, 1.0, 3.0], [3, 7, 4],
[3.0, 1.0, -1.0], [3, 6, 7],
[3.0, 1.0, 1.0], [14, 15, 7],
[3.0, 1.0, 3.0], [14, 7, 6],
[-1.0, 3.0, -1.0], [4, 8, 5],
[-1.0, 3.0, 1.0], [4, 7, 8],
[-1.0, 3.0, 3.0], [15, 16, 8],
[1.0, 3.0, -1.0], [15, 8, 7],
[1.0, 3.0, 1.0], [5, 16, 13],
[1.0, 3.0, 3.0], [5, 8, 16],
[3.0, 3.0, -1.0], [9, 10, 17],
[3.0, 3.0, 1.0], [10, 18, 17],
[3.0, 3.0, 3.0], [17, 18, 20],
], [18, 21, 20],
dtype=torch.float32, [9, 17, 20],
device=device, [9, 20, 12],
), [10, 11, 18],
) [11, 19, 18],
) [18, 19, 21],
self.assertTrue( [19, 22, 21],
torch.allclose( [11, 22, 19],
faces, [11, 13, 22],
torch.tensor( [20, 21, 23],
[ [21, 24, 23],
[0, 1, 9], [12, 20, 23],
[1, 10, 9], [12, 23, 14],
[0, 9, 12], [23, 24, 15],
[0, 12, 3], [23, 15, 14],
[0, 4, 1], [21, 22, 24],
[0, 3, 4], [22, 25, 24],
[1, 2, 10], [24, 25, 16],
[2, 11, 10], [24, 16, 15],
[1, 5, 2], [13, 25, 22],
[1, 4, 5], [13, 16, 25],
[2, 13, 11], ],
[2, 5, 13], dtype=torch.int64,
[3, 12, 14], device=device,
[3, 14, 6], ),
[3, 7, 4],
[3, 6, 7],
[14, 15, 7],
[14, 7, 6],
[4, 8, 5],
[4, 7, 8],
[15, 16, 8],
[15, 8, 7],
[5, 16, 13],
[5, 8, 16],
[9, 10, 17],
[10, 18, 17],
[17, 18, 20],
[18, 21, 20],
[9, 17, 20],
[9, 20, 12],
[10, 11, 18],
[11, 19, 18],
[18, 19, 21],
[19, 22, 21],
[11, 22, 19],
[11, 13, 22],
[20, 21, 23],
[21, 24, 23],
[12, 20, 23],
[12, 23, 14],
[23, 24, 15],
[23, 15, 14],
[21, 22, 24],
[22, 25, 24],
[24, 25, 16],
[24, 16, 15],
[13, 25, 22],
[13, 16, 25],
],
dtype=torch.int64,
device=device,
),
)
) )
# 3rd-check # 3rd-check
@ -187,91 +176,127 @@ class TestCubify(unittest.TestCase):
# 4th-check # 4th-check
verts, faces = meshes.get_mesh_verts_faces(3) verts, faces = meshes.get_mesh_verts_faces(3)
self.assertTrue( self.assertClose(
torch.allclose( verts,
verts, torch.tensor(
torch.tensor( [
[ [1.0, -1.0, -1.0],
[1.0, -1.0, -1.0], [1.0, -1.0, 1.0],
[1.0, -1.0, 1.0], [1.0, -1.0, 3.0],
[1.0, -1.0, 3.0], [3.0, -1.0, -1.0],
[3.0, -1.0, -1.0], [3.0, -1.0, 1.0],
[3.0, -1.0, 1.0], [3.0, -1.0, 3.0],
[3.0, -1.0, 3.0], [-1.0, 1.0, 1.0],
[-1.0, 1.0, 1.0], [-1.0, 1.0, 3.0],
[-1.0, 1.0, 3.0], [1.0, 1.0, -1.0],
[1.0, 1.0, -1.0], [1.0, 1.0, 1.0],
[1.0, 1.0, 1.0], [1.0, 1.0, 3.0],
[1.0, 1.0, 3.0], [3.0, 1.0, -1.0],
[3.0, 1.0, -1.0], [3.0, 1.0, 1.0],
[3.0, 1.0, 1.0], [3.0, 1.0, 3.0],
[3.0, 1.0, 3.0], [-1.0, 3.0, 1.0],
[-1.0, 3.0, 1.0], [-1.0, 3.0, 3.0],
[-1.0, 3.0, 3.0], [1.0, 3.0, -1.0],
[1.0, 3.0, -1.0], [1.0, 3.0, 1.0],
[1.0, 3.0, 1.0], [1.0, 3.0, 3.0],
[1.0, 3.0, 3.0], [3.0, 3.0, -1.0],
[3.0, 3.0, -1.0], [3.0, 3.0, 1.0],
[3.0, 3.0, 1.0], [3.0, 3.0, 3.0],
[3.0, 3.0, 3.0], ],
], dtype=torch.float32,
dtype=torch.float32, device=device,
device=device, ),
),
)
) )
self.assertTrue( self.assertClose(
torch.allclose( faces,
faces, torch.tensor(
torch.tensor( [
[ [0, 1, 8],
[0, 1, 8], [1, 9, 8],
[1, 9, 8], [0, 8, 11],
[0, 8, 11], [0, 11, 3],
[0, 11, 3], [0, 4, 1],
[0, 4, 1], [0, 3, 4],
[0, 3, 4], [11, 12, 4],
[11, 12, 4], [11, 4, 3],
[11, 4, 3], [1, 2, 9],
[1, 2, 9], [2, 10, 9],
[2, 10, 9], [1, 5, 2],
[1, 5, 2], [1, 4, 5],
[1, 4, 5], [12, 13, 5],
[12, 13, 5], [12, 5, 4],
[12, 5, 4], [2, 13, 10],
[2, 13, 10], [2, 5, 13],
[2, 5, 13], [6, 7, 14],
[6, 7, 14], [7, 15, 14],
[7, 15, 14], [14, 15, 17],
[14, 15, 17], [15, 18, 17],
[15, 18, 17], [6, 14, 17],
[6, 14, 17], [6, 17, 9],
[6, 17, 9], [6, 10, 7],
[6, 10, 7], [6, 9, 10],
[6, 9, 10], [7, 18, 15],
[7, 18, 15], [7, 10, 18],
[7, 10, 18], [8, 9, 16],
[8, 9, 16], [9, 17, 16],
[9, 17, 16], [16, 17, 19],
[16, 17, 19], [17, 20, 19],
[17, 20, 19], [8, 16, 19],
[8, 16, 19], [8, 19, 11],
[8, 19, 11], [19, 20, 12],
[19, 20, 12], [19, 12, 11],
[19, 12, 11], [17, 18, 20],
[17, 18, 20], [18, 21, 20],
[18, 21, 20], [20, 21, 13],
[20, 21, 13], [20, 13, 12],
[20, 13, 12], [10, 21, 18],
[10, 21, 18], [10, 13, 21],
[10, 13, 21], ],
], dtype=torch.int64,
dtype=torch.int64, device=device,
device=device, ),
),
)
) )
def test_align(self):
N, V = 1, 2
device = torch.device("cuda:0")
voxels = torch.ones((N, V, V, V), dtype=torch.float32, device=device)
# topleft align
mesh = cubify(voxels, 0.5)
verts, faces = mesh.get_mesh_verts_faces(0)
self.assertClose(verts.min(), torch.tensor(-1.0, device=device))
self.assertClose(verts.max(), torch.tensor(3.0, device=device))
# corner align
mesh = cubify(voxels, 0.5, align="corner")
verts, faces = mesh.get_mesh_verts_faces(0)
self.assertClose(verts.min(), torch.tensor(-1.0, device=device))
self.assertClose(verts.max(), torch.tensor(1.0, device=device))
# center align
mesh = cubify(voxels, 0.5, align="center")
verts, faces = mesh.get_mesh_verts_faces(0)
self.assertClose(verts.min(), torch.tensor(-2.0, device=device))
self.assertClose(verts.max(), torch.tensor(2.0, device=device))
# invalid align
with self.assertRaisesRegex(ValueError, "Align mode must be one of"):
cubify(voxels, 0.5, align="")
# invalid align
with self.assertRaisesRegex(ValueError, "Align mode must be one of"):
cubify(voxels, 0.5, align="topright")
# inside occupancy, similar to GH#185 use case
N, V = 1, 4
voxels = torch.zeros((N, V, V, V), dtype=torch.float32, device=device)
voxels[0, : V // 2, : V // 2, : V // 2] = 1.0
mesh = cubify(voxels, 0.5, align="corner")
verts, faces = mesh.get_mesh_verts_faces(0)
self.assertClose(verts.min(), torch.tensor(-1.0, device=device))
self.assertClose(verts.max(), torch.tensor(0.0, device=device))
@staticmethod @staticmethod
def cubify_with_init(batch_size: int, V: int): def cubify_with_init(batch_size: int, V: int):
device = torch.device("cuda:0") device = torch.device("cuda:0")