From e651a4299caf39dd40710e15044191cb48a3e72b Mon Sep 17 00:00:00 2001 From: Georgia Gkioxari Date: Tue, 6 Oct 2020 10:43:06 -0700 Subject: [PATCH] add texture vertex sampling functionality to textures Summary: Enhance every texture type with `faces_verts_textures_packed` that allows users to query the texture of each vertex in mesh Reviewed By: nikhilaravi Differential Revision: D24058778 fbshipit-source-id: 19d0e3a244fa96aae462c47bf52e07dfd3b7c6f0 --- pytorch3d/io/mtl_io.py | 2 +- pytorch3d/renderer/mesh/textures.py | 86 +++++++++++++++++++++++++++- tests/test_texturing.py | 89 +++++++++++++++++++++++++++++ 3 files changed, 173 insertions(+), 4 deletions(-) diff --git a/pytorch3d/io/mtl_io.py b/pytorch3d/io/mtl_io.py index 59a27dc1..f194b976 100644 --- a/pytorch3d/io/mtl_io.py +++ b/pytorch3d/io/mtl_io.py @@ -129,7 +129,7 @@ def make_material_atlas( the formulation from [1]. For a triangle with vertices (v0, v1, v2) we can create a barycentric coordinate system - with the x axis being the vector (v1 - v0) and the y axis being the vector (v2 - v0). + with the x axis being the vector (v0 - v2) and the y axis being the vector (v1 - v2). The barycentric coordinates range from [0, 1] in the +x and +y direction so this creates a triangular texture space with vertices at (0, 1), (0, 0) and (1, 0). diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index af3ebf86..b87638b8 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -236,6 +236,17 @@ class TexturesBase: """ raise NotImplementedError() + def faces_verts_textures_packed(self): + """ + Returns the texture for each vertex for each face in the mesh. + For N meshes, this function returns sum(Fi)x3xC where Fi is the + number of faces in the i-th mesh and C is the dimensional of + the feature (C = 3 for RGB textures). + You can use the utils function in structures.utils to convert the + packed respresentation to a list or padded. + """ + raise NotImplementedError() + def clone(self): """ Each texture class should implement a method @@ -286,7 +297,7 @@ def Textures( Returns: a Textures class which is an instance of TexturesBase e.g. TexturesUV, - TexturesAtlas, TexturesVerte + TexturesAtlas, TexturesVertex """ @@ -507,6 +518,23 @@ class TexturesAtlas(TexturesBase): return texels + def faces_verts_textures_packed(self) -> torch.Tensor: + """ + Samples texture from each vertex for each face in the mesh. + For N meshes with {Fi} number of faces, it returns a + tensor of shape sum(Fi)x3xD (D = 3 for RGB). + You can use the utils function in structures.utils to convert the + packed respresentation to a list or padded. + """ + atlas_packed = self.atlas_packed() + # assume each face consists of (v0, v1, v2). + # to sample from the atlas we only need the first two barycentric coordinates. + # for details on how this texture sample works refer to the sample_textures function. + t0 = atlas_packed[:, 0, -1] # corresponding to v0 with bary = (1, 0) + t1 = atlas_packed[:, -1, 0] # corresponding to v1 with bary = (0, 1) + t2 = atlas_packed[:, 0, 0] # corresponding to v2 with bary = (0, 0) + return torch.stack((t0, t1, t2), dim=1) + def join_batch(self, textures: List["TexturesAtlas"]) -> "TexturesAtlas": """ Join the list of textures given by `textures` to @@ -514,10 +542,10 @@ class TexturesAtlas(TexturesBase): TexturesAtlas object with the combined textures. Args: - textures: List of TextureAtlas objects + textures: List of TexturesAtlas objects Returns: - new_tex: TextureAtlas object with the combined + new_tex: TexturesAtlas object with the combined textures from self and the list `textures`. """ tex_types_same = all(isinstance(tex, TexturesAtlas) for tex in textures) @@ -917,6 +945,46 @@ class TexturesUV(TexturesBase): texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2) return texels + def faces_verts_textures_packed(self) -> torch.Tensor: + """ + Samples texture from each vertex and for each face in the mesh. + For N meshes with {Fi} number of faces, it returns a + tensor of shape sum(Fi)x3xC (C = 3 for RGB). + You can use the utils function in structures.utils to convert the + packed representation to a list or padded. + """ + if self.isempty(): + return torch.zeros( + (0, 3, self.maps_padded().shape[-1]), + dtype=torch.float32, + device=self.device, + ) + else: + packing_list = [ + i[j] for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list()) + ] + faces_verts_uvs = _list_to_padded_wrapper( + packing_list, pad_value=0.0 + ) # Nxmax(Fi)x3x2 + texture_maps = self.maps_padded() # NxHxWxC + texture_maps = texture_maps.permute(0, 3, 1, 2) # NxCxHxW + + faces_verts_uvs = faces_verts_uvs * 2.0 - 1.0 + texture_maps = torch.flip(texture_maps, [2]) # flip y axis of the texture map + + textures = F.grid_sample( + texture_maps, + faces_verts_uvs, + align_corners=self.align_corners, + padding_mode=self.padding_mode, + ) # NxCxmax(Fi)x3 + + textures = textures.permute(0, 2, 3, 1) # Nxmax(Fi)x3xC + textures = _padded_to_list_wrapper( + textures, split_size=self._num_faces_per_mesh + ) # list of N {Fix3xC} tensors + return list_to_packed(textures)[0] + def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV": """ Join the list of textures given by `textures` to @@ -1268,6 +1336,18 @@ class TexturesVertex(TexturesBase): ) return texels + def faces_verts_textures_packed(self, faces_packed=None) -> torch.Tensor: + """ + Samples texture from each vertex and for each face in the mesh. + For N meshes with {Fi} number of faces, it returns a + tensor of shape sum(Fi)x3xC (C = 3 for RGB). + You can use the utils function in structures.utils to convert the + packed respresentation to a list or padded. + """ + verts_features_packed = self.verts_features_packed() + faces_verts_features = verts_features_packed[faces_packed] + return faces_verts_features + def join_batch(self, textures: List["TexturesVertex"]) -> "TexturesVertex": """ Join the list of textures given by `textures` to diff --git a/tests/test_texturing.py b/tests/test_texturing.py index 543291d0..c10b3632 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -112,6 +112,31 @@ class TestTexturesVertex(TestCaseMixin, unittest.TestCase): with self.assertRaisesRegex(ValueError, "verts_features"): TexturesVertex(verts_features=(1, 1, 1)) + def test_faces_verts_textures(self): + device = torch.device("cuda:0") + verts = torch.randn((2, 4, 3), dtype=torch.float32, device=device) + faces = torch.tensor( + [[[2, 1, 0], [3, 1, 0]], [[1, 3, 0], [2, 1, 3]]], + dtype=torch.int64, + device=device, + ) + + # define TexturesVertex + verts_texture = torch.rand(verts.shape) + textures = TexturesVertex(verts_features=verts_texture) + + # compute packed faces + ff = faces.unbind(0) + faces_packed = torch.cat([ff[0], ff[1] + verts.shape[1]]) + + # face verts textures + faces_verts_texts = textures.faces_verts_textures_packed(faces_packed) + + verts_texts_packed = torch.cat(verts_texture.unbind(0)) + faces_verts_texts_packed = verts_texts_packed[faces_packed] + + self.assertClose(faces_verts_texts_packed, faces_verts_texts) + def test_clone(self): tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128))) tex.verts_features_list() @@ -303,6 +328,33 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase): with self.assertRaisesRegex(ValueError, "atlas"): TexturesAtlas(atlas=(1, 1, 1)) + def test_faces_verts_textures(self): + device = torch.device("cuda:0") + N, F, R = 2, 2, 8 + num_faces = torch.randint(low=1, high=F, size=(N,)) + faces_atlas = [ + torch.rand(size=(num_faces[i].item(), R, R, 3), device=device) + for i in range(N) + ] + tex = TexturesAtlas(atlas=faces_atlas) + + # faces_verts naive + faces_verts = [] + for n in range(N): + ff = num_faces[n].item() + temp = torch.zeros(ff, 3, 3) + for f in range(ff): + t0 = faces_atlas[n][f, 0, -1] # for v0, bary = (1, 0) + t1 = faces_atlas[n][f, -1, 0] # for v1, bary = (0, 1) + t2 = faces_atlas[n][f, 0, 0] # for v2, bary = (0, 0) + temp[f, 0] = t0 + temp[f, 1] = t1 + temp[f, 2] = t2 + faces_verts.append(temp) + faces_verts = torch.cat(faces_verts, 0) + + self.assertClose(faces_verts, tex.faces_verts_textures_packed().cpu()) + def test_clone(self): tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3))) tex.atlas_list() @@ -522,6 +574,43 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): verts_uvs=torch.rand(size=(5, 15, 2)), ) + def test_faces_verts_textures(self): + device = torch.device("cuda:0") + N, V, F, H, W = 2, 5, 12, 8, 8 + vert_uvs = torch.rand((N, V, 2), dtype=torch.float32, device=device) + face_uvs = torch.randint( + high=V, size=(N, F, 3), dtype=torch.int64, device=device + ) + maps = torch.rand((N, H, W, 3), dtype=torch.float32, device=device) + + tex = TexturesUV(maps=maps, verts_uvs=vert_uvs, faces_uvs=face_uvs) + + # naive faces_verts_textures + faces_verts_texs = [] + for n in range(N): + temp = torch.zeros((F, 3, 3), device=device, dtype=torch.float32) + for f in range(F): + uv0 = vert_uvs[n, face_uvs[n, f, 0]] + uv1 = vert_uvs[n, face_uvs[n, f, 1]] + uv2 = vert_uvs[n, face_uvs[n, f, 2]] + + idx = torch.stack((uv0, uv1, uv2), dim=0).view(1, 1, 3, 2) # 1x1x3x2 + idx = idx * 2.0 - 1.0 + imap = maps[n].view(1, H, W, 3).permute(0, 3, 1, 2) # 1x3xHxW + imap = torch.flip(imap, [2]) + + texts = torch.nn.functional.grid_sample( + imap, + idx, + align_corners=tex.align_corners, + padding_mode=tex.padding_mode, + ) # 1x3x1x3 + temp[f] = texts[0, :, 0, :].permute(1, 0) + faces_verts_texs.append(temp) + faces_verts_texs = torch.cat(faces_verts_texs, 0) + + self.assertClose(faces_verts_texs, tex.faces_verts_textures_packed()) + def test_clone(self): tex = TexturesUV( maps=torch.ones((5, 16, 16, 3)),