diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index 6588c8e0..6f855ddb 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -574,6 +574,15 @@ class TexturesAtlas(TexturesBase): """ return self.__class__(atlas=[torch.cat(self.atlas_list())]) + def check_shapes( + self, batch_size: int, max_num_verts: int, max_num_faces: int + ) -> bool: + """ + Check if the dimensions of the atlas match that of the mesh faces + """ + # (N, F) should be the same + return self.atlas_padded().shape[0:2] == (batch_size, max_num_faces) + class TexturesUV(TexturesBase): def __init__( @@ -1213,6 +1222,18 @@ class TexturesUV(TexturesBase): centers = centers[0, :, 0].T return centers + def check_shapes( + self, batch_size: int, max_num_verts: int, max_num_faces: int + ) -> bool: + """ + Check if the dimensions of the verts/faces uvs match that of the mesh + """ + # (N, F) should be the same + # (N, V) is not guaranteed to be the same + return (self.faces_uvs_padded().shape[0:2] == (batch_size, max_num_faces)) and ( + self.verts_uvs_padded().shape[0] == batch_size + ) + class TexturesVertex(TexturesBase): def __init__( @@ -1292,6 +1313,13 @@ class TexturesVertex(TexturesBase): new_props = self._getitem(index, props) verts_features = new_props["verts_features_list"] if isinstance(verts_features, list): + # Handle the case of an empty list + if len(verts_features) == 0: + verts_features = torch.empty( + size=(0, 0, 3), + dtype=torch.float32, + device=self.verts_features_padded().device, + ) new_tex = self.__class__(verts_features=verts_features) elif torch.is_tensor(verts_features): new_tex = self.__class__(verts_features=[verts_features]) @@ -1410,3 +1438,12 @@ class TexturesVertex(TexturesBase): Return a new TexturesVertex amalgamating the batch. """ return self.__class__(verts_features=[torch.cat(self.verts_features_list())]) + + def check_shapes( + self, batch_size: int, max_num_verts: int, max_num_faces: int + ) -> bool: + """ + Check if the dimensions of the verts features match that of the mesh verts + """ + # (N, V) should be the same + return self.verts_features_padded().shape[:-1] == (batch_size, max_num_verts) diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index ac88cd05..92ba4da1 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -255,6 +255,7 @@ class Meshes: if textures is not None and not hasattr(textures, "sample_textures"): msg = "Expected textures to be an instance of type TexturesBase; got %r" raise ValueError(msg % type(textures)) + self.textures = textures # Indicates whether the meshes in the list/batch have the same number @@ -424,10 +425,14 @@ class Meshes: ) # Set the num verts/faces on the textures if present. - if self.textures is not None: + if textures is not None: + shape_ok = self.textures.check_shapes(self._N, self._V, self._F) + if not shape_ok: + msg = "Textures do not match the dimensions of Meshes." + raise ValueError(msg) + self.textures._num_faces_per_mesh = self._num_faces_per_mesh.tolist() self.textures._num_verts_per_mesh = self._num_verts_per_mesh.tolist() - self.textures._N = self._N self.textures.valid = self.valid if verts_normals is not None: @@ -1560,6 +1565,13 @@ class Meshes: def sample_textures(self, fragments): if self.textures is not None: + + # Check dimensions of textures match that of meshes + shape_ok = self.textures.check_shapes(self._N, self._V, self._F) + if not shape_ok: + msg = "Textures do not match the dimensions of Meshes." + raise ValueError(msg) + # Pass in faces packed. If the textures are defined per # vertex, the face indices are needed in order to interpolate # the vertex attributes across the face. diff --git a/tests/test_texturing.py b/tests/test_texturing.py index c82377c4..ceb0dc0d 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -251,7 +251,7 @@ class TestTexturesVertex(TestCaseMixin, unittest.TestCase): def test_getitem(self): N = 5 V = 20 - source = {"verts_features": torch.randn(size=(N, 10, 128))} + source = {"verts_features": torch.randn(size=(N, V, 128))} tex = TexturesVertex(verts_features=source["verts_features"]) verts = torch.rand(size=(N, V, 3)) @@ -268,6 +268,30 @@ class TestTexturesVertex(TestCaseMixin, unittest.TestCase): tryindex(self, index, tex, meshes, source) tryindex(self, [2, 4], tex, meshes, source) + def test_sample_textures_error(self): + N = 5 + V = 20 + verts = torch.rand(size=(N, V, 3)) + faces = torch.randint(size=(N, 10, 3), high=V) + tex = TexturesVertex(verts_features=torch.randn(size=(N, 10, 128))) + + # Verts features have the wrong number of verts + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + Meshes(verts=verts, faces=faces, textures=tex) + + # Verts features have the wrong batch dim + tex = TexturesVertex(verts_features=torch.randn(size=(1, V, 128))) + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + Meshes(verts=verts, faces=faces, textures=tex) + + meshes = Meshes(verts=verts, faces=faces) + meshes.textures = tex + + # Cannot use the texture attribute set on meshes for sampling + # textures if the dimensions don't match + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + meshes.sample_textures(None) + class TestTexturesAtlas(TestCaseMixin, unittest.TestCase): def test_sample_texture_atlas(self): @@ -456,11 +480,12 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase): def test_getitem(self): N = 5 V = 20 - source = {"atlas": torch.randn(size=(N, 10, 4, 4, 3))} + F = 10 + source = {"atlas": torch.randn(size=(N, F, 4, 4, 3))} tex = TexturesAtlas(atlas=source["atlas"]) verts = torch.rand(size=(N, V, 3)) - faces = torch.randint(size=(N, 10, 3), high=V) + faces = torch.randint(size=(N, F, 3), high=V) meshes = Meshes(verts=verts, faces=faces, textures=tex) tryindex(self, 2, tex, meshes, source) @@ -473,6 +498,32 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase): tryindex(self, index, tex, meshes, source) tryindex(self, [2, 4], tex, meshes, source) + def test_sample_textures_error(self): + N = 1 + V = 20 + F = 10 + verts = torch.rand(size=(5, V, 3)) + faces = torch.randint(size=(5, F, 3), high=V) + meshes = Meshes(verts=verts, faces=faces) + + # TexturesAtlas have the wrong batch dim + tex = TexturesAtlas(atlas=torch.randn(size=(1, F, 4, 4, 3))) + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + Meshes(verts=verts, faces=faces, textures=tex) + + # TexturesAtlas have the wrong number of faces + tex = TexturesAtlas(atlas=torch.randn(size=(N, 15, 4, 4, 3))) + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + Meshes(verts=verts, faces=faces, textures=tex) + + meshes = Meshes(verts=verts, faces=faces) + meshes.textures = tex + + # Cannot use the texture attribute set on meshes for sampling + # textures if the dimensions don't match + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + meshes.sample_textures(None) + class TestTexturesUV(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: @@ -824,9 +875,10 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): def test_getitem(self): N = 5 V = 20 + F = 10 source = { "maps": torch.rand(size=(N, 1, 1, 3)), - "faces_uvs": torch.randint(size=(N, 10, 3), high=V), + "faces_uvs": torch.randint(size=(N, F, 3), high=V), "verts_uvs": torch.randn(size=(N, V, 2)), } tex = TexturesUV( @@ -836,7 +888,7 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): ) verts = torch.rand(size=(N, V, 3)) - faces = torch.randint(size=(N, 10, 3), high=V) + faces = torch.randint(size=(N, F, 3), high=V) meshes = Meshes(verts=verts, faces=faces, textures=tex) tryindex(self, 2, tex, meshes, source) @@ -858,6 +910,46 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): expected = torch.FloatTensor([[32, 224], [64, 96], [64, 128]]) self.assertClose(tex.centers_for_image(0), expected) + def test_sample_textures_error(self): + N = 1 + V = 20 + F = 10 + maps = torch.rand(size=(N, 1, 1, 3)) + verts_uvs = torch.randn(size=(N, V, 2)) + tex = TexturesUV( + maps=maps, + faces_uvs=torch.randint(size=(N, 15, 3), high=V), + verts_uvs=verts_uvs, + ) + verts = torch.rand(size=(5, V, 3)) + faces = torch.randint(size=(5, 10, 3), high=V) + meshes = Meshes(verts=verts, faces=faces) + + # Wrong number of faces + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + Meshes(verts=verts, faces=faces, textures=tex) + + # Wrong batch dim for faces + tex = TexturesUV( + maps=maps, + faces_uvs=torch.randint(size=(1, F, 3), high=V), + verts_uvs=verts_uvs, + ) + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + Meshes(verts=verts, faces=faces, textures=tex) + + # Wrong batch dim for verts_uvs is not necessary to check as + # there is already a check inside TexturesUV for a batch dim + # mismatch with faces_uvs + + meshes = Meshes(verts=verts, faces=faces) + meshes.textures = tex + + # Cannot use the texture attribute set on meshes for sampling + # textures if the dimensions don't match + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + meshes.sample_textures(None) + class TestRectanglePacking(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: