diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index 09fa5c1f..036d26de 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -462,9 +462,9 @@ class Meshes(object): assert ( self._verts_padded is not None ), "verts_padded is required to compute verts_list." - self._verts_list = [ - v[0] for v in self._verts_padded.split([1] * self._N, 0) - ] + self._verts_list = struct_utils.padded_to_list( + self._verts_padded, self.num_verts_per_mesh().tolist() + ) return self._verts_list def faces_list(self): @@ -478,10 +478,9 @@ class Meshes(object): assert ( self._faces_padded is not None ), "faces_padded is required to compute faces_list." - self._faces_list = [] - for i in range(self._N): - valid = self._faces_padded[i].gt(-1).all(1) - self._faces_list.append(self._faces_padded[i, valid, :]) + self._faces_list = struct_utils.padded_to_list( + self._faces_padded, self.num_faces_per_mesh().tolist() + ) return self._faces_list def verts_packed(self): @@ -525,7 +524,6 @@ class Meshes(object): Returns: 1D tensor of sizes. """ - self._compute_packed() return self._num_verts_per_mesh def faces_packed(self): @@ -590,7 +588,6 @@ class Meshes(object): Returns: 1D tensor of sizes. """ - self._compute_packed() return self._num_faces_per_mesh def edges_packed(self): @@ -664,14 +661,13 @@ class Meshes(object): Returns: 1D tensor of indices. """ - self._compute_packed() if self._verts_padded_to_packed_idx is not None: return self._verts_padded_to_packed_idx self._verts_padded_to_packed_idx = torch.cat( [ torch.arange(v, dtype=torch.int64, device=self.device) + i * self._V - for (i, v) in enumerate(self._num_verts_per_mesh) + for (i, v) in enumerate(self.num_verts_per_mesh()) ], dim=0, ) @@ -862,8 +858,8 @@ class Meshes(object): ): return - verts_list = self._verts_list - faces_list = self._faces_list + verts_list = self.verts_list() + faces_list = self.faces_list() assert ( faces_list is not None and verts_list is not None ), "faces_list and verts_list arguments are required" @@ -943,13 +939,15 @@ class Meshes(object): verts_list_to_packed = struct_utils.list_to_packed(verts_list) self._verts_packed = verts_list_to_packed[0] - self._num_verts_per_mesh = verts_list_to_packed[1] + if not torch.allclose(self.num_verts_per_mesh(), verts_list_to_packed[1]): + raise ValueError("The number of verts per mesh should be consistent.") self._mesh_to_verts_packed_first_idx = verts_list_to_packed[2] self._verts_packed_to_mesh_idx = verts_list_to_packed[3] faces_list_to_packed = struct_utils.list_to_packed(faces_list) faces_packed = faces_list_to_packed[0] - self._num_faces_per_mesh = faces_list_to_packed[1] + if not torch.allclose(self.num_faces_per_mesh(), faces_list_to_packed[1]): + raise ValueError("The number of faces per mesh should be consistent.") self._mesh_to_faces_packed_first_idx = faces_list_to_packed[2] self._faces_packed_to_mesh_idx = faces_list_to_packed[3] @@ -1328,6 +1326,100 @@ class Meshes(object): new_mesh = self.clone() return new_mesh.scale_verts_(scale) + def update_padded(self, new_verts_padded): + """ + This function allows for an pdate of verts_padded without having to + explicitly convert it to the list representation for heterogeneous batches. + Returns a Meshes structure with updated padded tensors and copies of the + auxiliary tensors at construction time. + It updates self._verts_padded with new_verts_padded, and does a + shallow copy of (faces_padded, faces_list, num_verts_per_mesh, num_faces_per_mesh). + If packed representations are computed in self, they are updated as well. + + Args: + new_points_padded: FloatTensor of shape (N, V, 3) + + Returns: + Meshes with updated padded representations + """ + + def check_shapes(x, size): + if x.shape[0] != size[0]: + raise ValueError("new values must have the same batch dimension.") + if x.shape[1] != size[1]: + raise ValueError("new values must have the same number of points.") + if x.shape[2] != size[2]: + raise ValueError("new values must have the same dimension.") + + check_shapes(new_verts_padded, [self._N, self._V, 3]) + + new = self.__class__(verts=new_verts_padded, faces=self.faces_padded()) + + if new._N != self._N or new._V != self._V or new._F != self._F: + raise ValueError("Inconsistent sizes after construction.") + + # overwrite the equisized flag + new.equisized = self.equisized + + # overwrite textures if any + new.textures = self.textures + + # copy auxiliary tensors + copy_tensors = ["_num_verts_per_mesh", "_num_faces_per_mesh", "valid"] + + for k in copy_tensors: + v = getattr(self, k) + if torch.is_tensor(v): + setattr(new, k, v) # shallow copy + + # shallow copy of faces_list if any, st new.faces_list() + # does not re-compute from _faces_padded + new._faces_list = self._faces_list + + # update verts/faces packed if they are computed in self + if self._verts_packed is not None: + copy_tensors = [ + "_faces_packed", + "_verts_packed_to_mesh_idx", + "_faces_packed_to_mesh_idx", + "_mesh_to_verts_packed_first_idx", + "_mesh_to_faces_packed_first_idx", + ] + for k in copy_tensors: + v = getattr(self, k) + assert torch.is_tensor(v) + setattr(new, k, v) # shallow copy + # update verts_packed + pad_to_packed = self.verts_padded_to_packed_idx() + new_verts_packed = new_verts_padded.reshape(-1, 3)[pad_to_packed, :] + new._verts_packed = new_verts_packed + new._verts_padded_to_packed_idx = pad_to_packed + + # update edges packed if they are computed in self + if self._edges_packed is not None: + copy_tensors = [ + "_edges_packed", + "_edges_packed_to_mesh_idx", + "_mesh_to_edges_packed_first_idx", + "_faces_packed_to_edges_packed", + "_num_edges_per_mesh", + ] + for k in copy_tensors: + v = getattr(self, k) + assert torch.is_tensor(v) + setattr(new, k, v) # shallow copy + + # update laplacian if it is compute in self + if self._laplacian_packed is not None: + new._laplacian_packed = self._laplacian_packed + + assert new._verts_list is None + assert new._verts_normals_packed is None + assert new._faces_normals_packed is None + assert new._faces_areas_packed is None + + return new + # TODO(nikhilar) Move function to utils file. def get_bounding_boxes(self): """ diff --git a/tests/test_meshes.py b/tests/test_meshes.py index 84175368..cfc152e7 100644 --- a/tests/test_meshes.py +++ b/tests/test_meshes.py @@ -342,16 +342,11 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): # Modify tensors in both meshes. new_mesh._verts_list[0] = new_mesh._verts_list[0] * 5 - mesh._num_verts_per_mesh = torch.randint_like( - mesh.num_verts_per_mesh(), high=10 - ) + # Check cloned and original Meshes objects do not share tensors. self.assertFalse( torch.allclose(new_mesh._verts_list[0], mesh._verts_list[0]) ) - self.assertFalse( - torch.allclose(mesh.num_verts_per_mesh(), new_mesh.num_verts_per_mesh()) - ) self.assertSeparate(new_mesh.verts_packed(), mesh.verts_packed()) self.assertSeparate(new_mesh.verts_padded(), mesh.verts_padded()) self.assertSeparate(new_mesh.faces_packed(), mesh.faces_packed()) @@ -690,6 +685,99 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): with self.assertRaises(ValueError): mesh.split(split_sizes) + def test_update_padded(self): + # Define the test mesh object either as a list or tensor of faces/verts. + N = 10 + for lists_to_tensors in (False, True): + for force in (True, False): + mesh = TestMeshes.init_mesh( + N, 100, 300, lists_to_tensors=lists_to_tensors + ) + num_verts_per_mesh = mesh.num_verts_per_mesh() + if force: + # force mesh to have computed attributes + mesh.verts_packed() + mesh.edges_packed() + mesh.laplacian_packed() + mesh.faces_areas_packed() + + new_verts = torch.rand((mesh._N, mesh._V, 3), device=mesh.device) + new_verts_list = [ + new_verts[i, : num_verts_per_mesh[i]] for i in range(N) + ] + new_mesh = mesh.update_padded(new_verts) + + # check the attributes assigned at construction time + self.assertEqual(new_mesh._N, mesh._N) + self.assertEqual(new_mesh._F, mesh._F) + self.assertEqual(new_mesh._V, mesh._V) + self.assertEqual(new_mesh.equisized, mesh.equisized) + self.assertTrue(all(new_mesh.valid == mesh.valid)) + self.assertNotSeparate( + new_mesh.num_verts_per_mesh(), mesh.num_verts_per_mesh() + ) + self.assertClose( + new_mesh.num_verts_per_mesh(), mesh.num_verts_per_mesh() + ) + self.assertNotSeparate( + new_mesh.num_faces_per_mesh(), mesh.num_faces_per_mesh() + ) + self.assertClose( + new_mesh.num_faces_per_mesh(), mesh.num_faces_per_mesh() + ) + + # check that the following attributes are not assigned + self.assertIsNone(new_mesh._verts_list) + self.assertIsNone(new_mesh._faces_areas_packed) + self.assertIsNone(new_mesh._faces_normals_packed) + self.assertIsNone(new_mesh._verts_normals_packed) + + check_tensors = [ + "_faces_packed", + "_verts_packed_to_mesh_idx", + "_faces_packed_to_mesh_idx", + "_mesh_to_verts_packed_first_idx", + "_mesh_to_faces_packed_first_idx", + "_edges_packed", + "_edges_packed_to_mesh_idx", + "_mesh_to_edges_packed_first_idx", + "_faces_packed_to_edges_packed", + "_num_edges_per_mesh", + ] + for k in check_tensors: + v = getattr(new_mesh, k) + if not force: + self.assertIsNone(v) + else: + v_old = getattr(mesh, k) + self.assertNotSeparate(v, v_old) + self.assertClose(v, v_old) + + # check verts/faces padded + self.assertClose(new_mesh.verts_padded(), new_verts) + self.assertNotSeparate(new_mesh.verts_padded(), new_verts) + self.assertClose(new_mesh.faces_padded(), mesh.faces_padded()) + self.assertNotSeparate(new_mesh.faces_padded(), mesh.faces_padded()) + # check verts/faces list + for i in range(N): + self.assertNotSeparate( + new_mesh.faces_list()[i], mesh.faces_list()[i] + ) + self.assertClose(new_mesh.faces_list()[i], mesh.faces_list()[i]) + self.assertSeparate(new_mesh.verts_list()[i], mesh.verts_list()[i]) + self.assertClose(new_mesh.verts_list()[i], new_verts_list[i]) + # check verts/faces packed + self.assertClose(new_mesh.verts_packed(), torch.cat(new_verts_list)) + self.assertSeparate(new_mesh.verts_packed(), mesh.verts_packed()) + self.assertClose(new_mesh.faces_packed(), mesh.faces_packed()) + # check pad_to_packed + self.assertClose( + new_mesh.verts_padded_to_packed_idx(), + mesh.verts_padded_to_packed_idx(), + ) + # check edges + self.assertClose(new_mesh.edges_packed(), mesh.edges_packed()) + def test_get_mesh_verts_faces(self): device = torch.device("cuda:0") verts_list = []