From a61c9376d578525c218c2e0ba7eeedef3d418076 Mon Sep 17 00:00:00 2001 From: Georgia Gkioxari Date: Tue, 5 May 2020 11:07:23 -0700 Subject: [PATCH] add align modes for cubify Summary: Add alignment modes for cubify operation. Reviewed By: nikhilaravi Differential Revision: D21393199 fbshipit-source-id: 7022044e591229a6ed5efc361fd3215e65f43f86 --- pytorch3d/ops/cubify.py | 45 +++- tests/test_cubify.py | 477 +++++++++++++++++++++------------------- 2 files changed, 290 insertions(+), 232 deletions(-) diff --git a/pytorch3d/ops/cubify.py b/pytorch3d/ops/cubify.py index e0fa3456..2d9810cb 100644 --- a/pytorch3d/ops/cubify.py +++ b/pytorch3d/ops/cubify.py @@ -45,7 +45,7 @@ def ravel_index(idx, dims) -> torch.Tensor: @torch.no_grad() -def cubify(voxels, thresh, device=None) -> Meshes: +def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes: r""" Converts a voxel to a mesh by replacing each occupied voxel with a cube consisting of 12 faces and 8 vertices. Shared vertices are merged, and @@ -54,13 +54,38 @@ def cubify(voxels, thresh, device=None) -> Meshes: voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities. thresh: A scalar threshold. If a voxel occupancy is larger than thresh, the voxel is considered occupied. + device: The device of the output meshes + align: Defines the alignment of the mesh vertices and the grid locations. + Has to be one of {"topleft", "corner", "center"}. See below for explanation. + Default is "topleft". Returns: meshes: A Meshes object of the corresponding meshes. + + + The alignment between the vertices of the cubified mesh and the voxel locations (or pixels) + is defined by the choice of `align`. We support three modes, as shown below for a 2x2 grid: + + X---X---- X-------X --------- + | | | | | | | X | X | + X---X---- --------- --------- + | | | | | | | X | X | + --------- X-------X --------- + + topleft corner center + + In the figure, X denote the grid locations and the squares represent the added cuboids. + When `align="topleft"`, then the top left corner of each cuboid corresponds to the + pixel coordinate of the input grid. + When `align="corner"`, then the corners of the output mesh span the whole grid. + When `align="center"`, then the grid locations form the center of the cuboids. """ if device is None: device = voxels.device + if align not in ["topleft", "corner", "center"]: + raise ValueError("Align mode must be one of (topleft, corner, center).") + if len(voxels) == 0: return Meshes(verts=[], faces=[]) @@ -146,7 +171,7 @@ def cubify(voxels, thresh, device=None) -> Meshes: # boolean to linear index # NF x 2 - linind = torch.nonzero(faces_idx) + linind = torch.nonzero(faces_idx, as_tuple=False) # NF x 4 nyxz = unravel_index(linind[:, 0], (N, H, W, D)) @@ -170,11 +195,19 @@ def cubify(voxels, thresh, device=None) -> Meshes: torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1) ) y = y.to(device=device, dtype=torch.float32) - y = y * 2.0 / (H - 1.0) - 1.0 x = x.to(device=device, dtype=torch.float32) - x = x * 2.0 / (W - 1.0) - 1.0 z = z.to(device=device, dtype=torch.float32) - z = z * 2.0 / (D - 1.0) - 1.0 + + if align == "center": + x = x - 0.5 + y = y - 0.5 + z = z - 0.5 + + margin = 0.0 if align == "corner" else 1.0 + y = y * 2.0 / (H - margin) - 1.0 + x = x * 2.0 / (W - margin) - 1.0 + z = z * 2.0 / (D - margin) - 1.0 + # ((H+1)(W+1)(D+1)) x 3 grid_verts = torch.stack((x, y, z), dim=3).view(-1, 3) @@ -196,7 +229,7 @@ def cubify(voxels, thresh, device=None) -> Meshes: idlenum = idleverts.cumsum(1) verts_list = [ - grid_verts.index_select(0, (idleverts[n] == 0).nonzero()[:, 0]) + grid_verts.index_select(0, (idleverts[n] == 0).nonzero(as_tuple=False)[:, 0]) for n in range(N) ] faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)] diff --git a/tests/test_cubify.py b/tests/test_cubify.py index 158b8968..ce3fda66 100644 --- a/tests/test_cubify.py +++ b/tests/test_cubify.py @@ -3,16 +3,17 @@ import unittest import torch +from common_testing import TestCaseMixin from pytorch3d.ops import cubify -class TestCubify(unittest.TestCase): +class TestCubify(TestCaseMixin, unittest.TestCase): def test_allempty(self): N, V = 32, 14 device = torch.device("cuda:0") voxels = torch.zeros((N, V, V, V), dtype=torch.float32, device=device) - meshes = cubify(voxels, 0.5, 0) - self.assertTrue(meshes.isempty) + meshes = cubify(voxels, 0.5) + self.assertTrue(meshes.isempty()) def test_cubify(self): N, V = 4, 2 @@ -29,155 +30,143 @@ class TestCubify(unittest.TestCase): voxels[3, 1, 1, 0] = 1.0 # compute cubify - meshes = cubify(voxels, 0.5, 0) + meshes = cubify(voxels, 0.5) # 1st-check verts, faces = meshes.get_mesh_verts_faces(0) - self.assertTrue( - torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1])) + self.assertClose(faces.max().cpu(), torch.tensor(verts.size(0) - 1)) + self.assertClose( + verts, + torch.tensor( + [ + [-1.0, -1.0, -1.0], + [-1.0, -1.0, 1.0], + [1.0, -1.0, -1.0], + [1.0, -1.0, 1.0], + [-1.0, 1.0, -1.0], + [-1.0, 1.0, 1.0], + [1.0, 1.0, -1.0], + [1.0, 1.0, 1.0], + ], + dtype=torch.float32, + device=device, + ), ) - self.assertTrue( - torch.allclose( - verts, - torch.tensor( - [ - [-1.0, -1.0, -1.0], - [-1.0, -1.0, 1.0], - [1.0, -1.0, -1.0], - [1.0, -1.0, 1.0], - [-1.0, 1.0, -1.0], - [-1.0, 1.0, 1.0], - [1.0, 1.0, -1.0], - [1.0, 1.0, 1.0], - ], - dtype=torch.float32, - device=device, - ), - ) - ) - self.assertTrue( - torch.allclose( - faces, - torch.tensor( - [ - [0, 1, 4], - [1, 5, 4], - [4, 5, 6], - [5, 7, 6], - [0, 4, 6], - [0, 6, 2], - [0, 3, 1], - [0, 2, 3], - [6, 7, 3], - [6, 3, 2], - [1, 7, 5], - [1, 3, 7], - ], - dtype=torch.int64, - device=device, - ), - ) + self.assertClose( + faces, + torch.tensor( + [ + [0, 1, 4], + [1, 5, 4], + [4, 5, 6], + [5, 7, 6], + [0, 4, 6], + [0, 6, 2], + [0, 3, 1], + [0, 2, 3], + [6, 7, 3], + [6, 3, 2], + [1, 7, 5], + [1, 3, 7], + ], + dtype=torch.int64, + device=device, + ), ) # 2nd-check verts, faces = meshes.get_mesh_verts_faces(1) - self.assertTrue( - torch.allclose(faces.max().cpu(), torch.tensor([verts.size(0) - 1])) + self.assertClose(faces.max().cpu(), torch.tensor(verts.size(0) - 1)) + self.assertClose( + verts, + torch.tensor( + [ + [-1.0, -1.0, -1.0], + [-1.0, -1.0, 1.0], + [-1.0, -1.0, 3.0], + [1.0, -1.0, -1.0], + [1.0, -1.0, 1.0], + [1.0, -1.0, 3.0], + [3.0, -1.0, -1.0], + [3.0, -1.0, 1.0], + [3.0, -1.0, 3.0], + [-1.0, 1.0, -1.0], + [-1.0, 1.0, 1.0], + [-1.0, 1.0, 3.0], + [1.0, 1.0, -1.0], + [1.0, 1.0, 3.0], + [3.0, 1.0, -1.0], + [3.0, 1.0, 1.0], + [3.0, 1.0, 3.0], + [-1.0, 3.0, -1.0], + [-1.0, 3.0, 1.0], + [-1.0, 3.0, 3.0], + [1.0, 3.0, -1.0], + [1.0, 3.0, 1.0], + [1.0, 3.0, 3.0], + [3.0, 3.0, -1.0], + [3.0, 3.0, 1.0], + [3.0, 3.0, 3.0], + ], + dtype=torch.float32, + device=device, + ), ) - self.assertTrue( - torch.allclose( - verts, - torch.tensor( - [ - [-1.0, -1.0, -1.0], - [-1.0, -1.0, 1.0], - [-1.0, -1.0, 3.0], - [1.0, -1.0, -1.0], - [1.0, -1.0, 1.0], - [1.0, -1.0, 3.0], - [3.0, -1.0, -1.0], - [3.0, -1.0, 1.0], - [3.0, -1.0, 3.0], - [-1.0, 1.0, -1.0], - [-1.0, 1.0, 1.0], - [-1.0, 1.0, 3.0], - [1.0, 1.0, -1.0], - [1.0, 1.0, 3.0], - [3.0, 1.0, -1.0], - [3.0, 1.0, 1.0], - [3.0, 1.0, 3.0], - [-1.0, 3.0, -1.0], - [-1.0, 3.0, 1.0], - [-1.0, 3.0, 3.0], - [1.0, 3.0, -1.0], - [1.0, 3.0, 1.0], - [1.0, 3.0, 3.0], - [3.0, 3.0, -1.0], - [3.0, 3.0, 1.0], - [3.0, 3.0, 3.0], - ], - dtype=torch.float32, - device=device, - ), - ) - ) - self.assertTrue( - torch.allclose( - faces, - torch.tensor( - [ - [0, 1, 9], - [1, 10, 9], - [0, 9, 12], - [0, 12, 3], - [0, 4, 1], - [0, 3, 4], - [1, 2, 10], - [2, 11, 10], - [1, 5, 2], - [1, 4, 5], - [2, 13, 11], - [2, 5, 13], - [3, 12, 14], - [3, 14, 6], - [3, 7, 4], - [3, 6, 7], - [14, 15, 7], - [14, 7, 6], - [4, 8, 5], - [4, 7, 8], - [15, 16, 8], - [15, 8, 7], - [5, 16, 13], - [5, 8, 16], - [9, 10, 17], - [10, 18, 17], - [17, 18, 20], - [18, 21, 20], - [9, 17, 20], - [9, 20, 12], - [10, 11, 18], - [11, 19, 18], - [18, 19, 21], - [19, 22, 21], - [11, 22, 19], - [11, 13, 22], - [20, 21, 23], - [21, 24, 23], - [12, 20, 23], - [12, 23, 14], - [23, 24, 15], - [23, 15, 14], - [21, 22, 24], - [22, 25, 24], - [24, 25, 16], - [24, 16, 15], - [13, 25, 22], - [13, 16, 25], - ], - dtype=torch.int64, - device=device, - ), - ) + self.assertClose( + faces, + torch.tensor( + [ + [0, 1, 9], + [1, 10, 9], + [0, 9, 12], + [0, 12, 3], + [0, 4, 1], + [0, 3, 4], + [1, 2, 10], + [2, 11, 10], + [1, 5, 2], + [1, 4, 5], + [2, 13, 11], + [2, 5, 13], + [3, 12, 14], + [3, 14, 6], + [3, 7, 4], + [3, 6, 7], + [14, 15, 7], + [14, 7, 6], + [4, 8, 5], + [4, 7, 8], + [15, 16, 8], + [15, 8, 7], + [5, 16, 13], + [5, 8, 16], + [9, 10, 17], + [10, 18, 17], + [17, 18, 20], + [18, 21, 20], + [9, 17, 20], + [9, 20, 12], + [10, 11, 18], + [11, 19, 18], + [18, 19, 21], + [19, 22, 21], + [11, 22, 19], + [11, 13, 22], + [20, 21, 23], + [21, 24, 23], + [12, 20, 23], + [12, 23, 14], + [23, 24, 15], + [23, 15, 14], + [21, 22, 24], + [22, 25, 24], + [24, 25, 16], + [24, 16, 15], + [13, 25, 22], + [13, 16, 25], + ], + dtype=torch.int64, + device=device, + ), ) # 3rd-check @@ -187,91 +176,127 @@ class TestCubify(unittest.TestCase): # 4th-check verts, faces = meshes.get_mesh_verts_faces(3) - self.assertTrue( - torch.allclose( - verts, - torch.tensor( - [ - [1.0, -1.0, -1.0], - [1.0, -1.0, 1.0], - [1.0, -1.0, 3.0], - [3.0, -1.0, -1.0], - [3.0, -1.0, 1.0], - [3.0, -1.0, 3.0], - [-1.0, 1.0, 1.0], - [-1.0, 1.0, 3.0], - [1.0, 1.0, -1.0], - [1.0, 1.0, 1.0], - [1.0, 1.0, 3.0], - [3.0, 1.0, -1.0], - [3.0, 1.0, 1.0], - [3.0, 1.0, 3.0], - [-1.0, 3.0, 1.0], - [-1.0, 3.0, 3.0], - [1.0, 3.0, -1.0], - [1.0, 3.0, 1.0], - [1.0, 3.0, 3.0], - [3.0, 3.0, -1.0], - [3.0, 3.0, 1.0], - [3.0, 3.0, 3.0], - ], - dtype=torch.float32, - device=device, - ), - ) + self.assertClose( + verts, + torch.tensor( + [ + [1.0, -1.0, -1.0], + [1.0, -1.0, 1.0], + [1.0, -1.0, 3.0], + [3.0, -1.0, -1.0], + [3.0, -1.0, 1.0], + [3.0, -1.0, 3.0], + [-1.0, 1.0, 1.0], + [-1.0, 1.0, 3.0], + [1.0, 1.0, -1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 3.0], + [3.0, 1.0, -1.0], + [3.0, 1.0, 1.0], + [3.0, 1.0, 3.0], + [-1.0, 3.0, 1.0], + [-1.0, 3.0, 3.0], + [1.0, 3.0, -1.0], + [1.0, 3.0, 1.0], + [1.0, 3.0, 3.0], + [3.0, 3.0, -1.0], + [3.0, 3.0, 1.0], + [3.0, 3.0, 3.0], + ], + dtype=torch.float32, + device=device, + ), ) - self.assertTrue( - torch.allclose( - faces, - torch.tensor( - [ - [0, 1, 8], - [1, 9, 8], - [0, 8, 11], - [0, 11, 3], - [0, 4, 1], - [0, 3, 4], - [11, 12, 4], - [11, 4, 3], - [1, 2, 9], - [2, 10, 9], - [1, 5, 2], - [1, 4, 5], - [12, 13, 5], - [12, 5, 4], - [2, 13, 10], - [2, 5, 13], - [6, 7, 14], - [7, 15, 14], - [14, 15, 17], - [15, 18, 17], - [6, 14, 17], - [6, 17, 9], - [6, 10, 7], - [6, 9, 10], - [7, 18, 15], - [7, 10, 18], - [8, 9, 16], - [9, 17, 16], - [16, 17, 19], - [17, 20, 19], - [8, 16, 19], - [8, 19, 11], - [19, 20, 12], - [19, 12, 11], - [17, 18, 20], - [18, 21, 20], - [20, 21, 13], - [20, 13, 12], - [10, 21, 18], - [10, 13, 21], - ], - dtype=torch.int64, - device=device, - ), - ) + self.assertClose( + faces, + torch.tensor( + [ + [0, 1, 8], + [1, 9, 8], + [0, 8, 11], + [0, 11, 3], + [0, 4, 1], + [0, 3, 4], + [11, 12, 4], + [11, 4, 3], + [1, 2, 9], + [2, 10, 9], + [1, 5, 2], + [1, 4, 5], + [12, 13, 5], + [12, 5, 4], + [2, 13, 10], + [2, 5, 13], + [6, 7, 14], + [7, 15, 14], + [14, 15, 17], + [15, 18, 17], + [6, 14, 17], + [6, 17, 9], + [6, 10, 7], + [6, 9, 10], + [7, 18, 15], + [7, 10, 18], + [8, 9, 16], + [9, 17, 16], + [16, 17, 19], + [17, 20, 19], + [8, 16, 19], + [8, 19, 11], + [19, 20, 12], + [19, 12, 11], + [17, 18, 20], + [18, 21, 20], + [20, 21, 13], + [20, 13, 12], + [10, 21, 18], + [10, 13, 21], + ], + dtype=torch.int64, + device=device, + ), ) + def test_align(self): + N, V = 1, 2 + device = torch.device("cuda:0") + voxels = torch.ones((N, V, V, V), dtype=torch.float32, device=device) + + # topleft align + mesh = cubify(voxels, 0.5) + verts, faces = mesh.get_mesh_verts_faces(0) + self.assertClose(verts.min(), torch.tensor(-1.0, device=device)) + self.assertClose(verts.max(), torch.tensor(3.0, device=device)) + + # corner align + mesh = cubify(voxels, 0.5, align="corner") + verts, faces = mesh.get_mesh_verts_faces(0) + self.assertClose(verts.min(), torch.tensor(-1.0, device=device)) + self.assertClose(verts.max(), torch.tensor(1.0, device=device)) + + # center align + mesh = cubify(voxels, 0.5, align="center") + verts, faces = mesh.get_mesh_verts_faces(0) + self.assertClose(verts.min(), torch.tensor(-2.0, device=device)) + self.assertClose(verts.max(), torch.tensor(2.0, device=device)) + + # invalid align + with self.assertRaisesRegex(ValueError, "Align mode must be one of"): + cubify(voxels, 0.5, align="") + + # invalid align + with self.assertRaisesRegex(ValueError, "Align mode must be one of"): + cubify(voxels, 0.5, align="topright") + + # inside occupancy, similar to GH#185 use case + N, V = 1, 4 + voxels = torch.zeros((N, V, V, V), dtype=torch.float32, device=device) + voxels[0, : V // 2, : V // 2, : V // 2] = 1.0 + mesh = cubify(voxels, 0.5, align="corner") + verts, faces = mesh.get_mesh_verts_faces(0) + self.assertClose(verts.min(), torch.tensor(-1.0, device=device)) + self.assertClose(verts.max(), torch.tensor(0.0, device=device)) + @staticmethod def cubify_with_init(batch_size: int, V: int): device = torch.device("cuda:0")