diff --git a/tests/test_meshes.py b/tests/test_meshes.py index afe6e941..0b11c054 100644 --- a/tests/test_meshes.py +++ b/tests/test_meshes.py @@ -166,6 +166,73 @@ def mesh_structures_equal(mesh1, mesh2) -> bool: return True +def to_sorted(mesh: Meshes) -> "Meshes": + """ + Create a new Meshes object, where each sub-mesh's vertices are sorted + alphabetically. + + Returns: + A Meshes object with the same topology as this mesh, with vertices sorted + alphabetically. + + Example: + + For a mesh with verts [[2.3, .2, .4], [.0, .1, .2], [.0, .0, .1]] and a single + face [[0, 1, 2]], to_sorted will create a new mesh with verts [[.0, .0, .1], + [.0, .1, .2], [2.3, .2, .4]] and a single face [[2, 1, 0]]. This is useful to + create a semi-canonical representation of the mesh that is invariant to vertex + permutations, but not invariant to coordinate frame changes. + """ + if mesh.textures is not None: + raise NotImplementedError( + "to_sorted is not implemented for meshes with " + f"{type(mesh.textures).__name__} textures." + ) + + verts_list = mesh.verts_list() + faces_list = mesh.faces_list() + verts_sorted_list = [] + faces_sorted_list = [] + + for verts, faces in zip(verts_list, faces_list): + # Argsort the vertices alphabetically: sort_ids[k] corresponds to the id of + # the vertex in the non-sorted mesh that should sit at index k in the sorted mesh. + sort_ids = torch.tensor( + [ + idx_and_val[0] + for idx_and_val in sorted( + enumerate(verts.tolist()), + key=lambda idx_and_val: idx_and_val[1], + ) + ], + device=mesh.device, + ) + + # Resort the vertices. index_select allocates new memory. + verts_sorted = verts[sort_ids] + verts_sorted_list.append(verts_sorted) + + # The `faces` tensor contains vertex ids. Substitute old vertex ids for the + # new ones. new_vertex_ids is the inverse of sort_ids: new_vertex_ids[k] + # corresponds to the id of the vertex in the sorted mesh that is the same as + # vertex k in the non-sorted mesh. + new_vertex_ids = torch.argsort(sort_ids) + faces_sorted = ( + torch.gather(new_vertex_ids, 0, faces.flatten()) + .reshape(faces.shape) + .clone() + ) + faces_sorted_list.append(faces_sorted) + + other = mesh.__class__(verts=verts_sorted_list, faces=faces_sorted_list) + for k in mesh._INTERNAL_TENSORS: + v = getattr(mesh, k) + if torch.is_tensor(v): + setattr(other, k, v.clone()) + + return other + + class TestMeshes(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: np.random.seed(42) @@ -1223,6 +1290,57 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): self.assertFalse(mesh_structures_equal(meshes1, meshes2)) self.assertFalse(mesh_structures_equal(meshes1, meshes3)) + def test_to_sorted(self): + mesh = init_simple_mesh() + sorted_mesh = to_sorted(mesh) + + expected_verts = [ + torch.tensor( + [[0.1, 0.3, 0.5], [0.5, 0.2, 0.1], [0.6, 0.8, 0.7]], + dtype=torch.float32, + ), + torch.tensor( + # Vertex permutation: 0->0, 1->3, 2->2, 3->1 + [[0.1, 0.3, 0.3], [0.1, 0.5, 0.3], [0.2, 0.3, 0.4], [0.6, 0.7, 0.8]], + dtype=torch.float32, + ), + torch.tensor( + # Vertex permutation: 0->2, 1->1, 2->4, 3->0, 4->3 + [ + [0.2, 0.3, 0.4], + [0.2, 0.4, 0.8], + [0.7, 0.3, 0.6], + [0.9, 0.3, 0.8], + [0.9, 0.5, 0.2], + ], + dtype=torch.float32, + ), + ] + + expected_faces = [ + torch.tensor([[0, 1, 2]], dtype=torch.int64), + torch.tensor([[0, 3, 2], [3, 2, 1]], dtype=torch.int64), + torch.tensor( + [ + [1, 4, 2], + [2, 1, 0], + [4, 0, 1], + [3, 0, 4], + [3, 2, 1], + [3, 0, 1], + [3, 4, 1], + ], + dtype=torch.int64, + ), + ] + + self.assertFalse(mesh_structures_equal(mesh, sorted_mesh)) + self.assertTrue( + mesh_structures_equal( + Meshes(verts=expected_verts, faces=expected_faces), sorted_mesh + ) + ) + @staticmethod def compute_packed_with_init( num_meshes: int = 10, max_v: int = 100, max_f: int = 300, device: str = "cpu"