diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index e38cf7c0..59927155 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -549,6 +549,33 @@ class TexturesAtlas(TexturesBase): return texels + def submeshes( + self, + vertex_ids_list: List[List[torch.LongTensor]], + faces_ids_list: List[List[torch.LongTensor]], + ) -> "TexturesAtlas": + """ + Extract a sub-texture for use in a submesh. + + If the meshes batch corresponding to this TextureAtlas contains + `n = len(faces_ids_list)` meshes, then self.atlas_list() + will be of length n. After submeshing, we obtain a batch of + `k = sum(len(v) for v in atlas_list` submeshes (see Meshes.submeshes). This + function creates a corresponding TexturesAtlas object with `atlas_list` + of length `k`. + """ + if len(faces_ids_list) != len(self.atlas_list()): + raise IndexError( + "faces_ids_list must be of " "the same length as atlas_list." + ) + + sub_features = [] + for atlas, faces_ids in zip(self.atlas_list(), faces_ids_list): + for faces_ids_submesh in faces_ids: + sub_features.append(atlas[faces_ids_submesh]) + + return self.__class__(sub_features) + def faces_verts_textures_packed(self) -> torch.Tensor: """ Samples texture from each vertex for each face in the mesh. diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index 19b8fa5d..fce929be 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -1576,8 +1576,6 @@ class Meshes: Returns: Meshes object of length `sum(len(ids) for ids in face_indices)`. - Submeshing only works with no textures, TexturesVertex, or TexturesUV. - Example 1: If `meshes` has batch size 1, and `face_indices` is a 1D LongTensor, diff --git a/tests/test_texturing.py b/tests/test_texturing.py index 5234f2ae..71ffa3e2 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -576,6 +576,39 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase): with self.assertRaisesRegex(ValueError, "do not match the dimensions"): meshes.sample_textures(None) + def test_submeshes(self): + N = 2 + V = 5 + F = 5 + tex = TexturesAtlas( + atlas=torch.arange(N * F * 4 * 4 * 3, dtype=torch.float32).reshape( + N, F, 4, 4, 3 + ) + ) + + verts = torch.rand(size=(N, V, 3)) + faces = torch.randint(size=(N, F, 3), high=V) + mesh = Meshes(verts=verts, faces=faces, textures=tex) + + sub_faces = [ + [torch.tensor([0, 2]), torch.tensor([1, 2])], + [], + ] + subtex = mesh.submeshes(sub_faces).textures + subtex_faces = subtex.atlas_list() + + self.assertEqual(len(subtex_faces), 2) + self.assertClose( + subtex_faces[0].flatten().msort(), + torch.cat( + ( + torch.arange(4 * 4 * 3, dtype=torch.float32), + torch.arange(96, 96 + 4 * 4 * 3, dtype=torch.float32), + ), + 0, + ), + ) + class TestTexturesUV(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: