From ef16253953b7035da69f25107d83cc402562ce08 Mon Sep 17 00:00:00 2001 From: Nikhila Ravi Date: Fri, 11 Jun 2021 13:38:01 -0700 Subject: [PATCH] textures dimension check Summary: When textures are set on `Meshes` we need to check if the dimensions actually match that of the verts/faces in the mesh. There was a github issue where someone tried to set the attribute after construction of the `Meshes` object and ran into an error when trying to sample textures. The desired usage is to initialize the class with the textures (not set an attribute afterwards) but in any case we need to check the dimensions match before sampling textures. Reviewed By: bottler Differential Revision: D29020249 fbshipit-source-id: 9fb8a5368b83c9ec53652df92b96fc8b2613f591 --- pytorch3d/renderer/mesh/textures.py | 37 ++++++++++ pytorch3d/structures/meshes.py | 16 ++++- tests/test_texturing.py | 102 ++++++++++++++++++++++++++-- 3 files changed, 148 insertions(+), 7 deletions(-) diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index 6588c8e0..6f855ddb 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -574,6 +574,15 @@ class TexturesAtlas(TexturesBase): """ return self.__class__(atlas=[torch.cat(self.atlas_list())]) + def check_shapes( + self, batch_size: int, max_num_verts: int, max_num_faces: int + ) -> bool: + """ + Check if the dimensions of the atlas match that of the mesh faces + """ + # (N, F) should be the same + return self.atlas_padded().shape[0:2] == (batch_size, max_num_faces) + class TexturesUV(TexturesBase): def __init__( @@ -1213,6 +1222,18 @@ class TexturesUV(TexturesBase): centers = centers[0, :, 0].T return centers + def check_shapes( + self, batch_size: int, max_num_verts: int, max_num_faces: int + ) -> bool: + """ + Check if the dimensions of the verts/faces uvs match that of the mesh + """ + # (N, F) should be the same + # (N, V) is not guaranteed to be the same + return (self.faces_uvs_padded().shape[0:2] == (batch_size, max_num_faces)) and ( + self.verts_uvs_padded().shape[0] == batch_size + ) + class TexturesVertex(TexturesBase): def __init__( @@ -1292,6 +1313,13 @@ class TexturesVertex(TexturesBase): new_props = self._getitem(index, props) verts_features = new_props["verts_features_list"] if isinstance(verts_features, list): + # Handle the case of an empty list + if len(verts_features) == 0: + verts_features = torch.empty( + size=(0, 0, 3), + dtype=torch.float32, + device=self.verts_features_padded().device, + ) new_tex = self.__class__(verts_features=verts_features) elif torch.is_tensor(verts_features): new_tex = self.__class__(verts_features=[verts_features]) @@ -1410,3 +1438,12 @@ class TexturesVertex(TexturesBase): Return a new TexturesVertex amalgamating the batch. """ return self.__class__(verts_features=[torch.cat(self.verts_features_list())]) + + def check_shapes( + self, batch_size: int, max_num_verts: int, max_num_faces: int + ) -> bool: + """ + Check if the dimensions of the verts features match that of the mesh verts + """ + # (N, V) should be the same + return self.verts_features_padded().shape[:-1] == (batch_size, max_num_verts) diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index ac88cd05..92ba4da1 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -255,6 +255,7 @@ class Meshes: if textures is not None and not hasattr(textures, "sample_textures"): msg = "Expected textures to be an instance of type TexturesBase; got %r" raise ValueError(msg % type(textures)) + self.textures = textures # Indicates whether the meshes in the list/batch have the same number @@ -424,10 +425,14 @@ class Meshes: ) # Set the num verts/faces on the textures if present. - if self.textures is not None: + if textures is not None: + shape_ok = self.textures.check_shapes(self._N, self._V, self._F) + if not shape_ok: + msg = "Textures do not match the dimensions of Meshes." + raise ValueError(msg) + self.textures._num_faces_per_mesh = self._num_faces_per_mesh.tolist() self.textures._num_verts_per_mesh = self._num_verts_per_mesh.tolist() - self.textures._N = self._N self.textures.valid = self.valid if verts_normals is not None: @@ -1560,6 +1565,13 @@ class Meshes: def sample_textures(self, fragments): if self.textures is not None: + + # Check dimensions of textures match that of meshes + shape_ok = self.textures.check_shapes(self._N, self._V, self._F) + if not shape_ok: + msg = "Textures do not match the dimensions of Meshes." + raise ValueError(msg) + # Pass in faces packed. If the textures are defined per # vertex, the face indices are needed in order to interpolate # the vertex attributes across the face. diff --git a/tests/test_texturing.py b/tests/test_texturing.py index c82377c4..ceb0dc0d 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -251,7 +251,7 @@ class TestTexturesVertex(TestCaseMixin, unittest.TestCase): def test_getitem(self): N = 5 V = 20 - source = {"verts_features": torch.randn(size=(N, 10, 128))} + source = {"verts_features": torch.randn(size=(N, V, 128))} tex = TexturesVertex(verts_features=source["verts_features"]) verts = torch.rand(size=(N, V, 3)) @@ -268,6 +268,30 @@ class TestTexturesVertex(TestCaseMixin, unittest.TestCase): tryindex(self, index, tex, meshes, source) tryindex(self, [2, 4], tex, meshes, source) + def test_sample_textures_error(self): + N = 5 + V = 20 + verts = torch.rand(size=(N, V, 3)) + faces = torch.randint(size=(N, 10, 3), high=V) + tex = TexturesVertex(verts_features=torch.randn(size=(N, 10, 128))) + + # Verts features have the wrong number of verts + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + Meshes(verts=verts, faces=faces, textures=tex) + + # Verts features have the wrong batch dim + tex = TexturesVertex(verts_features=torch.randn(size=(1, V, 128))) + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + Meshes(verts=verts, faces=faces, textures=tex) + + meshes = Meshes(verts=verts, faces=faces) + meshes.textures = tex + + # Cannot use the texture attribute set on meshes for sampling + # textures if the dimensions don't match + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + meshes.sample_textures(None) + class TestTexturesAtlas(TestCaseMixin, unittest.TestCase): def test_sample_texture_atlas(self): @@ -456,11 +480,12 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase): def test_getitem(self): N = 5 V = 20 - source = {"atlas": torch.randn(size=(N, 10, 4, 4, 3))} + F = 10 + source = {"atlas": torch.randn(size=(N, F, 4, 4, 3))} tex = TexturesAtlas(atlas=source["atlas"]) verts = torch.rand(size=(N, V, 3)) - faces = torch.randint(size=(N, 10, 3), high=V) + faces = torch.randint(size=(N, F, 3), high=V) meshes = Meshes(verts=verts, faces=faces, textures=tex) tryindex(self, 2, tex, meshes, source) @@ -473,6 +498,32 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase): tryindex(self, index, tex, meshes, source) tryindex(self, [2, 4], tex, meshes, source) + def test_sample_textures_error(self): + N = 1 + V = 20 + F = 10 + verts = torch.rand(size=(5, V, 3)) + faces = torch.randint(size=(5, F, 3), high=V) + meshes = Meshes(verts=verts, faces=faces) + + # TexturesAtlas have the wrong batch dim + tex = TexturesAtlas(atlas=torch.randn(size=(1, F, 4, 4, 3))) + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + Meshes(verts=verts, faces=faces, textures=tex) + + # TexturesAtlas have the wrong number of faces + tex = TexturesAtlas(atlas=torch.randn(size=(N, 15, 4, 4, 3))) + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + Meshes(verts=verts, faces=faces, textures=tex) + + meshes = Meshes(verts=verts, faces=faces) + meshes.textures = tex + + # Cannot use the texture attribute set on meshes for sampling + # textures if the dimensions don't match + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + meshes.sample_textures(None) + class TestTexturesUV(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: @@ -824,9 +875,10 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): def test_getitem(self): N = 5 V = 20 + F = 10 source = { "maps": torch.rand(size=(N, 1, 1, 3)), - "faces_uvs": torch.randint(size=(N, 10, 3), high=V), + "faces_uvs": torch.randint(size=(N, F, 3), high=V), "verts_uvs": torch.randn(size=(N, V, 2)), } tex = TexturesUV( @@ -836,7 +888,7 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): ) verts = torch.rand(size=(N, V, 3)) - faces = torch.randint(size=(N, 10, 3), high=V) + faces = torch.randint(size=(N, F, 3), high=V) meshes = Meshes(verts=verts, faces=faces, textures=tex) tryindex(self, 2, tex, meshes, source) @@ -858,6 +910,46 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): expected = torch.FloatTensor([[32, 224], [64, 96], [64, 128]]) self.assertClose(tex.centers_for_image(0), expected) + def test_sample_textures_error(self): + N = 1 + V = 20 + F = 10 + maps = torch.rand(size=(N, 1, 1, 3)) + verts_uvs = torch.randn(size=(N, V, 2)) + tex = TexturesUV( + maps=maps, + faces_uvs=torch.randint(size=(N, 15, 3), high=V), + verts_uvs=verts_uvs, + ) + verts = torch.rand(size=(5, V, 3)) + faces = torch.randint(size=(5, 10, 3), high=V) + meshes = Meshes(verts=verts, faces=faces) + + # Wrong number of faces + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + Meshes(verts=verts, faces=faces, textures=tex) + + # Wrong batch dim for faces + tex = TexturesUV( + maps=maps, + faces_uvs=torch.randint(size=(1, F, 3), high=V), + verts_uvs=verts_uvs, + ) + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + Meshes(verts=verts, faces=faces, textures=tex) + + # Wrong batch dim for verts_uvs is not necessary to check as + # there is already a check inside TexturesUV for a batch dim + # mismatch with faces_uvs + + meshes = Meshes(verts=verts, faces=faces) + meshes.textures = tex + + # Cannot use the texture attribute set on meshes for sampling + # textures if the dimensions don't match + with self.assertRaisesRegex(ValueError, "do not match the dimensions"): + meshes.sample_textures(None) + class TestRectanglePacking(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: