diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index a0f34a6f..aaefaa92 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -415,10 +415,12 @@ class Meshes(object): else: raise IndexError(index) + textures = None if self.textures is None else self.textures[index] + if torch.is_tensor(verts) and torch.is_tensor(faces): - return Meshes(verts=[verts], faces=[faces]) + return Meshes(verts=[verts], faces=[faces], textures=textures) elif isinstance(verts, list) and isinstance(faces, list): - return Meshes(verts=verts, faces=faces) + return Meshes(verts=verts, faces=faces, textures=textures) else: raise ValueError("(verts, faces) not defined correctly") diff --git a/pytorch3d/structures/textures.py b/pytorch3d/structures/textures.py index 4daf69bb..0f30f0aa 100644 --- a/pytorch3d/structures/textures.py +++ b/pytorch3d/structures/textures.py @@ -115,8 +115,14 @@ class Textures(object): self._verts_rgb_padded = verts_rgb self._maps_padded = maps self._num_faces_per_mesh = None + self._set_num_faces_per_mesh() + def _set_num_faces_per_mesh(self) -> None: + """ + Determines and sets the number of textured faces for each mesh. + """ if self._faces_uvs_padded is not None: + faces_uvs = self._faces_uvs_padded self._num_faces_per_mesh = faces_uvs.gt(-1).all(-1).sum(-1).tolist() def clone(self): @@ -134,6 +140,18 @@ class Textures(object): setattr(self, k, v.to(device)) return self + def __getitem__(self, index): + other = Textures() + for key in dir(self): + value = getattr(self, key) + if torch.is_tensor(value): + if isinstance(index, int): + setattr(other, key, value[index][None]) + else: + setattr(other, key, value[index]) + other._set_num_faces_per_mesh() + return other + def faces_uvs_padded(self) -> torch.Tensor: return self._faces_uvs_padded diff --git a/tests/test_texturing.py b/tests/test_texturing.py index aea04553..07e53495 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -167,6 +167,55 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded) self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded) + def test_getitem(self): + N = 5 + V = 20 + source = { + "maps": torch.rand(size=(N, 16, 16, 3)), + "faces_uvs": torch.randint(size=(N, 10, 3), low=0, high=V), + "verts_uvs": torch.rand((N, V, 2)), + } + tex = Textures( + maps=source["maps"], + faces_uvs=source["faces_uvs"], + verts_uvs=source["verts_uvs"], + ) + + verts = torch.rand(size=(N, V, 3)) + faces = torch.randint(size=(N, 10, 3), high=V) + meshes = Meshes(verts=verts, faces=faces, textures=tex) + + def tryindex(index): + tex2 = tex[index] + meshes2 = meshes[index] + tex_from_meshes = meshes2.textures + for item in source: + basic = source[item][index] + from_texture = getattr(tex2, item + "_padded")() + from_meshes = getattr(tex_from_meshes, item + "_padded")() + if isinstance(index, int): + basic = basic[None] + self.assertClose(basic, from_texture) + self.assertClose(basic, from_meshes) + self.assertEqual( + from_texture.ndim, getattr(tex, item + "_padded")().ndim + ) + if item == "faces_uvs": + faces_uvs_list = tex_from_meshes.faces_uvs_list() + self.assertEqual(basic.shape[0], len(faces_uvs_list)) + for i, faces_uvs in enumerate(faces_uvs_list): + self.assertClose(faces_uvs, basic[i]) + + tryindex(2) + tryindex(slice(0, 2, 1)) + index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool) + tryindex(index) + index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool) + tryindex(index) + index = torch.tensor([1, 2], dtype=torch.int64) + tryindex(index) + tryindex([2, 4]) + def test_to(self): V = 20 tex = Textures(