diff --git a/pytorch3d/renderer/mesh/texturing.py b/pytorch3d/renderer/mesh/texturing.py index fbcae3f8..1891e3fa 100644 --- a/pytorch3d/renderer/mesh/texturing.py +++ b/pytorch3d/renderer/mesh/texturing.py @@ -107,7 +107,9 @@ def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor: There will be one C dimensional value for each element in fragments.pix_to_face. """ - vertex_textures = meshes.textures.verts_rgb_padded().view(-1, 3) # (V, C) + vertex_textures = meshes.textures.verts_rgb_padded().reshape( + -1, 3 + ) # (V, C) vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :] faces_packed = meshes.faces_packed() faces_textures = vertex_textures[faces_packed] # (F, 3, C) diff --git a/pytorch3d/renderer/utils.py b/pytorch3d/renderer/utils.py index 7f0effca..20d3bfd2 100644 --- a/pytorch3d/renderer/utils.py +++ b/pytorch3d/renderer/utils.py @@ -223,27 +223,32 @@ class TensorProperties(object): self with all properties reshaped. e.g. a property with shape (N, 3) is transformed to shape (B, 3). """ + # Iterate through the attributes of the class which are tensors. for k in dir(self): v = getattr(self, k) if torch.is_tensor(v): if v.shape[0] > 1: # There are different values for each batch element - # so gather these using the batch_idx - idx_dims = batch_idx.shape + # so gather these using the batch_idx. + # First clone the input batch_idx tensor before + # modifying it. + _batch_idx = batch_idx.clone() + idx_dims = _batch_idx.shape tensor_dims = v.shape if len(idx_dims) > len(tensor_dims): msg = "batch_idx cannot have more dimensions than %s. " msg += "got shape %r and %s has shape %r" raise ValueError(msg % (k, idx_dims, k, tensor_dims)) if idx_dims != tensor_dims: - # To use torch.gather the index tensor (batch_idx) has + # To use torch.gather the index tensor (_batch_idx) has # to have the same shape as the input tensor. new_dims = len(tensor_dims) - len(idx_dims) new_shape = idx_dims + (1,) * new_dims expand_dims = (-1,) + tensor_dims[1:] - batch_idx = batch_idx.view(*new_shape) - batch_idx = batch_idx.expand(*expand_dims) - v = v.gather(0, batch_idx) + _batch_idx = _batch_idx.view(*new_shape) + _batch_idx = _batch_idx.expand(*expand_dims) + + v = v.gather(0, _batch_idx) setattr(self, k, v) return self diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index 82741e5a..0060deb4 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -324,14 +324,14 @@ class Meshes(object): ) if self._N > 0: self.device = self._verts_list[0].device - num_verts_per_mesh = torch.tensor( + self._num_verts_per_mesh = torch.tensor( [len(v) for v in self._verts_list], device=self.device ) - self._V = num_verts_per_mesh.max() - num_faces_per_mesh = torch.tensor( + self._V = self._num_verts_per_mesh.max() + self._num_faces_per_mesh = torch.tensor( [len(f) for f in self._faces_list], device=self.device ) - self._F = num_faces_per_mesh.max() + self._F = self._num_faces_per_mesh.max() self.valid = torch.tensor( [ len(v) > 0 and len(f) > 0 @@ -341,8 +341,8 @@ class Meshes(object): device=self.device, ) - if (len(num_verts_per_mesh.unique()) == 1) and ( - len(num_faces_per_mesh.unique()) == 1 + if (len(self._num_verts_per_mesh.unique()) == 1) and ( + len(self._num_faces_per_mesh.unique()) == 1 ): self.equisized = True @@ -355,6 +355,7 @@ class Meshes(object): self._faces_padded = faces.to(torch.int64) self._N = self._verts_padded.shape[0] self._V = self._verts_padded.shape[1] + self.device = self._verts_padded.device self.valid = torch.zeros( (self._N,), dtype=torch.bool, device=self.device @@ -363,18 +364,25 @@ class Meshes(object): # Check that padded faces - which have value -1 - are at the # end of the tensors faces_not_padded = self._faces_padded.gt(-1).all(2) - num_faces = faces_not_padded.sum(1) + self._num_faces_per_mesh = faces_not_padded.sum(1) if (faces_not_padded[:, :-1] < faces_not_padded[:, 1:]).any(): raise ValueError("Padding of faces must be at the end") # NOTE that we don't check for the ordering of padded verts # as long as the faces index correspond to the right vertices. - self.valid = num_faces > 0 - self._F = num_faces.max() - if len(num_faces.unique()) == 1: + self.valid = self._num_faces_per_mesh > 0 + self._F = self._num_faces_per_mesh.max() + if len(self._num_faces_per_mesh.unique()) == 1: self.equisized = True + self._num_verts_per_mesh = torch.full( + size=(self._N,), + fill_value=self._V, + dtype=torch.int64, + device=self.device, + ) + else: raise ValueError( "Verts and Faces must be either a list or a tensor with \ @@ -382,6 +390,23 @@ class Meshes(object): number of verts or faces respectively." ) + if self.isempty(): + self._num_verts_per_mesh = torch.zeros( + (0,), dtype=torch.int64, device=self.device + ) + self._num_faces_per_mesh = torch.zeros( + (0,), dtype=torch.int64, device=self.device + ) + + # Set the num verts/faces on the textures if present. + if self.textures is not None: + self.textures._num_faces_per_mesh = ( + self._num_faces_per_mesh.tolist() + ) + self.textures._num_verts_per_mesh = ( + self._num_verts_per_mesh.tolist() + ) + def __len__(self): return self._N @@ -893,11 +918,9 @@ class Meshes(object): self._verts_packed, self._verts_packed_to_mesh_idx, self._mesh_to_verts_packed_first_idx, - self._num_verts_per_mesh, self._faces_packed, self._faces_packed_to_mesh_idx, self._mesh_to_faces_packed_first_idx, - self._num_faces_per_mesh, ] ) ): @@ -920,7 +943,6 @@ class Meshes(object): self._num_verts_per_mesh = torch.zeros( (0,), dtype=torch.int64, device=self.device ) - self._faces_packed = -torch.ones( (0, 3), dtype=torch.int64, device=self.device ) @@ -1354,6 +1376,7 @@ class Meshes(object): tex = None if self.textures is not None: tex = self.textures.extend(N) + return Meshes(verts=new_verts_list, faces=new_faces_list, textures=tex) diff --git a/pytorch3d/structures/textures.py b/pytorch3d/structures/textures.py index 3ed0a1bd..94102ce3 100644 --- a/pytorch3d/structures/textures.py +++ b/pytorch3d/structures/textures.py @@ -4,7 +4,7 @@ from typing import List, Optional, Union import torch import torchvision.transforms as T -from .utils import list_to_packed, padded_to_list +from .utils import padded_to_list, padded_to_packed """ @@ -92,14 +92,19 @@ class Textures(object): faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each vertex in the face. Padding value is assumed to be -1. verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex. - verts_rgb: (N, V, 3) tensor giving the rgb color per vertex. + verts_rgb: (N, V, 3) tensor giving the rgb color per vertex. Padding + value is assumed to be -1. + + Note: only the padded representation of the textures is stored + and the packed/list representations are computed on the fly and + not cached. """ if faces_uvs is not None and faces_uvs.ndim != 3: msg = "Expected faces_uvs to be of shape (N, F, 3); got %r" raise ValueError(msg % repr(faces_uvs.shape)) if verts_uvs is not None and verts_uvs.ndim != 3: msg = "Expected verts_uvs to be of shape (N, V, 2); got %r" - raise ValueError(msg % repr(faces_uvs.shape)) + raise ValueError(msg % repr(verts_uvs.shape)) if verts_rgb is not None and verts_rgb.ndim != 3: msg = "Expected verts_rgb to be of shape (N, V, 3); got %r" raise ValueError(msg % repr(verts_rgb.shape)) @@ -109,20 +114,20 @@ class Textures(object): raise ValueError(msg % repr(maps.shape)) elif isinstance(maps, list): maps = _pad_texture_maps(maps) + if faces_uvs is None or verts_uvs is None: + msg = "To use maps, faces_uvs and verts_uvs are required" + raise ValueError(msg) + self._faces_uvs_padded = faces_uvs self._verts_uvs_padded = verts_uvs self._verts_rgb_padded = verts_rgb self._maps_padded = maps - self._num_faces_per_mesh = None - self._set_num_faces_per_mesh() - def _set_num_faces_per_mesh(self) -> None: - """ - Determines and sets the number of textured faces for each mesh. - """ - if self._faces_uvs_padded is not None: - faces_uvs = self._faces_uvs_padded - self._num_faces_per_mesh = faces_uvs.gt(-1).all(-1).sum(-1).tolist() + # The number of faces/verts for each mesh is + # set inside the Meshes object when textures is + # passed into the Meshes constructor. + self._num_faces_per_mesh = None + self._num_verts_per_mesh = None def clone(self): other = Textures() @@ -148,41 +153,67 @@ class Textures(object): setattr(other, key, value[index][None]) else: setattr(other, key, value[index]) - other._set_num_faces_per_mesh() return other def faces_uvs_padded(self) -> torch.Tensor: return self._faces_uvs_padded - def faces_uvs_list(self) -> List[torch.Tensor]: - if self._faces_uvs_padded is not None: - return padded_to_list( - self._faces_uvs_padded, split_size=self._num_faces_per_mesh - ) + def faces_uvs_list(self) -> Union[List[torch.Tensor], None]: + if self._faces_uvs_padded is None: + return None + return padded_to_list( + self._faces_uvs_padded, split_size=self._num_faces_per_mesh + ) - def faces_uvs_packed(self) -> torch.Tensor: - return list_to_packed(self.faces_uvs_list())[0] + def faces_uvs_packed(self) -> Union[torch.Tensor, None]: + if self._faces_uvs_padded is None: + return None + return padded_to_packed( + self._faces_uvs_padded, split_size=self._num_faces_per_mesh + ) - def verts_uvs_padded(self) -> torch.Tensor: + def verts_uvs_padded(self) -> Union[torch.Tensor, None]: return self._verts_uvs_padded - def verts_uvs_list(self) -> List[torch.Tensor]: + def verts_uvs_list(self) -> Union[List[torch.Tensor], None]: + if self._verts_uvs_padded is None: + return None + # Vertices shared between multiple faces + # may have a different uv coordinate for + # each face so the num_verts_uvs_per_mesh + # may be different from num_verts_per_mesh. + # Therefore don't use any split_size. return padded_to_list(self._verts_uvs_padded) - def verts_uvs_packed(self) -> torch.Tensor: - return list_to_packed(self.verts_uvs_list())[0] + def verts_uvs_packed(self) -> Union[torch.Tensor, None]: + if self._verts_uvs_padded is None: + return None + # Vertices shared between multiple faces + # may have a different uv coordinate for + # each face so the num_verts_uvs_per_mesh + # may be different from num_verts_per_mesh. + # Therefore don't use any split_size. + return padded_to_packed(self._verts_uvs_padded) - def verts_rgb_padded(self) -> torch.Tensor: + def verts_rgb_padded(self) -> Union[torch.Tensor, None]: return self._verts_rgb_padded - def verts_rgb_list(self) -> List[torch.Tensor]: - return padded_to_list(self._verts_rgb_padded) + def verts_rgb_list(self) -> Union[List[torch.Tensor], None]: + if self._verts_rgb_padded is None: + return None + return padded_to_list( + self._verts_rgb_padded, split_size=self._num_verts_per_mesh + ) - def verts_rgb_packed(self) -> torch.Tensor: - return list_to_packed(self.verts_rgb_list())[0] + def verts_rgb_packed(self) -> Union[torch.Tensor, None]: + if self._verts_rgb_padded is None: + return None + return padded_to_packed( + self._verts_rgb_padded, split_size=self._num_verts_per_mesh + ) # Currently only the padded maps are used. - def maps_padded(self) -> torch.Tensor: + def maps_padded(self) -> Union[torch.Tensor, None]: return self._maps_padded def extend(self, N: int) -> "Textures": diff --git a/tests/data/test_simple_sphere_light.png b/tests/data/test_simple_sphere_light_phong.png similarity index 100% rename from tests/data/test_simple_sphere_light.png rename to tests/data/test_simple_sphere_light_phong.png diff --git a/tests/data/test_simple_sphere_light_elevated_camera.png b/tests/data/test_simple_sphere_light_phong_elevated_camera.png similarity index 100% rename from tests/data/test_simple_sphere_light_elevated_camera.png rename to tests/data/test_simple_sphere_light_phong_elevated_camera.png diff --git a/tests/test_meshes.py b/tests/test_meshes.py index 57141e83..4d1b9338 100644 --- a/tests/test_meshes.py +++ b/tests/test_meshes.py @@ -135,6 +135,15 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): def test_simple(self): mesh = TestMeshes.init_simple_mesh("cuda:0") + # Check that faces/verts per mesh are set in init: + self.assertClose( + mesh._num_faces_per_mesh.cpu(), torch.tensor([1, 2, 7]) + ) + self.assertClose( + mesh._num_verts_per_mesh.cpu(), torch.tensor([3, 4, 5]) + ) + + # Check computed tensors self.assertClose( mesh.verts_packed_to_mesh_idx().cpu(), torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]), @@ -142,9 +151,6 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): self.assertClose( mesh.mesh_to_verts_packed_first_idx().cpu(), torch.tensor([0, 3, 7]) ) - self.assertClose( - mesh.num_verts_per_mesh().cpu(), torch.tensor([3, 4, 5]) - ) self.assertClose( mesh.verts_padded_to_packed_idx().cpu(), torch.tensor([0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 14]), @@ -156,9 +162,6 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): self.assertClose( mesh.mesh_to_faces_packed_first_idx().cpu(), torch.tensor([0, 1, 3]) ) - self.assertClose( - mesh.num_faces_per_mesh().cpu(), torch.tensor([1, 2, 7]) - ) self.assertClose( mesh.num_edges_per_mesh().cpu(), torch.tensor([3, 5, 10], dtype=torch.int32), @@ -249,6 +252,8 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): self.assertEqual(mesh.faces_padded().shape[0], 0) self.assertEqual(mesh.verts_packed().shape[0], 0) self.assertEqual(mesh.faces_packed().shape[0], 0) + self.assertEqual(mesh.num_faces_per_mesh().shape[0], 0) + self.assertEqual(mesh.num_verts_per_mesh().shape[0], 0) def test_empty(self): N, V, F = 10, 100, 300 @@ -323,9 +328,11 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): mesh = Meshes(verts=torch.stack(verts), faces=torch.stack(faces)) + # Check verts/faces per mesh are set correctly in init. self.assertListEqual( - mesh.num_faces_per_mesh().tolist(), num_faces.tolist() + mesh._num_faces_per_mesh.tolist(), num_faces.tolist() ) + self.assertListEqual(mesh._num_verts_per_mesh.tolist(), [V] * N) for n, (vv, ff) in enumerate(zip(mesh.verts_list(), mesh.faces_list())): self.assertClose(ff, faces[n][: num_faces[n]]) @@ -364,7 +371,6 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): mesh._num_verts_per_mesh = torch.randint_like( mesh.num_verts_per_mesh(), high=10 ) - # Check cloned and original Meshes objects do not share tensors. self.assertFalse( torch.allclose(new_mesh._verts_list[0], mesh._verts_list[0]) diff --git a/tests/test_rendering_meshes.py b/tests/test_rendering_meshes.py index 0e5edf5b..853fc17e 100644 --- a/tests/test_rendering_meshes.py +++ b/tests/test_rendering_meshes.py @@ -34,7 +34,7 @@ from pytorch3d.renderer.mesh.texturing import Textures from pytorch3d.structures.meshes import Meshes from pytorch3d.utils.ico_sphere import ico_sphere -# Save out images generated in the tests for debugging +# If DEBUG=True, save out images generated in the tests for debugging. # All saved images have prefix DEBUG_ DEBUG = False DATA_DIR = Path(__file__).resolve().parent / "data" @@ -90,30 +90,31 @@ class TestRenderingMeshes(unittest.TestCase): raster_settings = RasterizationSettings( image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0 ) - - # Init renderer rasterizer = MeshRasterizer( cameras=cameras, raster_settings=raster_settings ) - renderer = MeshRenderer( - rasterizer=rasterizer, - shader=HardPhongShader( - lights=lights, cameras=cameras, materials=materials - ), - ) - images = renderer(sphere_mesh) - rgb = images[0, ..., :3].squeeze().cpu() - if DEBUG: - filename = "DEBUG_simple_sphere_light%s.png" % postfix - Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( - DATA_DIR / filename - ) - # Load reference image - image_ref_phong = load_rgb_image( - "test_simple_sphere_light%s.png" % postfix - ) - self.assertTrue(torch.allclose(rgb, image_ref_phong, atol=0.05)) + # Test several shaders + shaders = { + "phong": HardPhongShader, + "gouraud": HardGouraudShader, + "flat": HardFlatShader, + } + for (name, shader_init) in shaders.items(): + shader = shader_init( + lights=lights, cameras=cameras, materials=materials + ) + renderer = MeshRenderer(rasterizer=rasterizer, shader=shader) + images = renderer(sphere_mesh) + filename = "simple_sphere_light_%s%s.png" % (name, postfix) + image_ref = load_rgb_image("test_%s" % filename) + rgb = images[0, ..., :3].squeeze().cpu() + if DEBUG: + filename = "DEBUG_" % filename + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / filename + ) + self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05)) ######################################################## # Move the light to the +z axis in world space so it is @@ -121,7 +122,13 @@ class TestRenderingMeshes(unittest.TestCase): # +X left for both world and camera space. ######################################################## lights.location[..., 2] = -2.0 - images = renderer(sphere_mesh, lights=lights) + phong_shader = HardPhongShader( + lights=lights, cameras=cameras, materials=materials + ) + phong_renderer = MeshRenderer( + rasterizer=rasterizer, shader=phong_shader + ) + images = phong_renderer(sphere_mesh, lights=lights) rgb = images[0, ..., :3].squeeze().cpu() if DEBUG: filename = "DEBUG_simple_sphere_dark%s.png" % postfix @@ -135,53 +142,6 @@ class TestRenderingMeshes(unittest.TestCase): ) self.assertTrue(torch.allclose(rgb, image_ref_phong_dark, atol=0.05)) - ###################################### - # Change the shader to a GouraudShader - ###################################### - lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None] - renderer = MeshRenderer( - rasterizer=rasterizer, - shader=HardGouraudShader( - lights=lights, cameras=cameras, materials=materials - ), - ) - images = renderer(sphere_mesh) - rgb = images[0, ..., :3].squeeze().cpu() - if DEBUG: - filename = "DEBUG_simple_sphere_light_gouraud%s.png" % postfix - Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( - DATA_DIR / filename - ) - - # Load reference image - image_ref_gouraud = load_rgb_image( - "test_simple_sphere_light_gouraud%s.png" % postfix - ) - self.assertTrue(torch.allclose(rgb, image_ref_gouraud, atol=0.005)) - - ###################################### - # Change the shader to a HardFlatShader - ###################################### - renderer = MeshRenderer( - rasterizer=rasterizer, - shader=HardFlatShader( - lights=lights, cameras=cameras, materials=materials - ), - ) - images = renderer(sphere_mesh) - rgb = images[0, ..., :3].squeeze().cpu() - if DEBUG: - filename = "DEBUG_simple_sphere_light_flat%s.png" % postfix - Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( - DATA_DIR / filename - ) - - # Load reference image - image_ref_flat = load_rgb_image( - "test_simple_sphere_light_flat%s.png" % postfix - ) - self.assertTrue(torch.allclose(rgb, image_ref_flat, atol=0.005)) - def test_simple_sphere_elevated_camera(self): """ Test output of phong and gouraud shading matches a reference image using @@ -193,13 +153,13 @@ class TestRenderingMeshes(unittest.TestCase): def test_simple_sphere_batched(self): """ - Test output of phong shading matches a reference image using - the default values for the light sources. + Test a mesh with vertex textures can be extended to form a batch, and + is rendered correctly with Phong, Gouraud and Flat Shaders. """ - batch_size = 5 + batch_size = 20 device = torch.device("cuda:0") - # Init mesh + # Init mesh with vertex textures. sphere_meshes = ico_sphere(5, device).extend(batch_size) verts_padded = sphere_meshes.verts_padded() faces_padded = sphere_meshes.faces_padded() @@ -224,26 +184,24 @@ class TestRenderingMeshes(unittest.TestCase): lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None] # Init renderer - renderer = MeshRenderer( - rasterizer=MeshRasterizer( - cameras=cameras, raster_settings=raster_settings - ), - shader=HardPhongShader( - lights=lights, cameras=cameras, materials=materials - ), + rasterizer = MeshRasterizer( + cameras=cameras, raster_settings=raster_settings ) - images = renderer(sphere_meshes) - - # Load ref image - image_ref = load_rgb_image("test_simple_sphere_light.png") - - for i in range(batch_size): - rgb = images[i, ..., :3].squeeze().cpu() - if DEBUG: - Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( - DATA_DIR / f"DEBUG_simple_sphere_{i}.png" - ) - self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05)) + shaders = { + "phong": HardGouraudShader, + "gouraud": HardGouraudShader, + "flat": HardFlatShader, + } + for (name, shader_init) in shaders.items(): + shader = shader_init( + lights=lights, cameras=cameras, materials=materials + ) + renderer = MeshRenderer(rasterizer=rasterizer, shader=shader) + images = renderer(sphere_meshes) + image_ref = load_rgb_image("test_simple_sphere_light_%s.png" % name) + for i in range(batch_size): + rgb = images[i, ..., :3].squeeze().cpu() + self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05)) def test_silhouette_with_grad(self): """ diff --git a/tests/test_rendering_utils.py b/tests/test_rendering_utils.py index 9f088d34..917d173d 100644 --- a/tests/test_rendering_utils.py +++ b/tests/test_rendering_utils.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import numpy as np import unittest import torch @@ -61,3 +62,28 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase): example = TensorPropertiesTestClass(x=(), y=()) self.assertTrue(len(example) == 0) self.assertTrue(example.isempty()) + + def test_gather_props(self): + N = 4 + x = torch.randn((N, 3, 4)) + y = torch.randn((N, 5)) + test_class = TensorPropertiesTestClass(x=x, y=y) + + S = 15 + idx = torch.tensor(np.random.choice(N, S)) + test_class_gathered = test_class.gather_props(idx) + + self.assertTrue(test_class_gathered.x.shape == (S, 3, 4)) + self.assertTrue(test_class_gathered.y.shape == (S, 5)) + + for i in range(N): + inds = idx == i + if inds.sum() > 0: + # Check the gathered points in the output have the same value from + # the input. + self.assertClose( + test_class_gathered.x[inds].mean(dim=0), x[i, ...] + ) + self.assertClose( + test_class_gathered.y[inds].mean(dim=0), y[i, ...] + ) diff --git a/tests/test_texturing.py b/tests/test_texturing.py index f1cdc6d9..5e235c01 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -12,6 +12,7 @@ from pytorch3d.renderer.mesh.texturing import ( interpolate_vertex_colors, ) from pytorch3d.structures import Meshes, Textures +from pytorch3d.structures.utils import list_to_padded from common_testing import TestCaseMixin from test_meshes import TestMeshes @@ -154,6 +155,108 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): torch.allclose(texels.squeeze(), expected_out.squeeze()) ) + def test_init_rgb_uv_fail(self): + V = 20 + # Maps has wrong shape + with self.assertRaisesRegex(ValueError, "maps"): + Textures( + maps=torch.ones((5, 16, 16, 3, 4)), + faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V), + verts_uvs=torch.ones((5, V, 2)), + ) + # faces_uvs has wrong shape + with self.assertRaisesRegex(ValueError, "faces_uvs"): + Textures( + maps=torch.ones((5, 16, 16, 3)), + faces_uvs=torch.randint(size=(5, 10, 3, 3), low=0, high=V), + verts_uvs=torch.ones((5, V, 2)), + ) + # verts_uvs has wrong shape + with self.assertRaisesRegex(ValueError, "verts_uvs"): + Textures( + maps=torch.ones((5, 16, 16, 3)), + faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V), + verts_uvs=torch.ones((5, V, 2, 3)), + ) + # verts_rgb has wrong shape + with self.assertRaisesRegex(ValueError, "verts_rgb"): + Textures(verts_rgb=torch.ones((5, 16, 16, 3))) + + # maps provided without verts/faces uvs + with self.assertRaisesRegex( + ValueError, "faces_uvs and verts_uvs are required" + ): + Textures(maps=torch.ones((5, 16, 16, 3))) + + def test_padded_to_packed(self): + N = 2 + # Case where each face in the mesh has 3 unique uv vertex indices + # - i.e. even if a vertex is shared between multiple faces it will + # have a unique uv coordinate for each face. + faces_uvs_list = [ + torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]), + torch.tensor([[0, 1, 2], [3, 4, 5]]), + ] # (N, 3, 3) + verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)] + faces_uvs_padded = list_to_padded(faces_uvs_list, pad_value=-1) + verts_uvs_padded = list_to_padded(verts_uvs_list) + tex = Textures( + maps=torch.ones((N, 16, 16, 3)), + faces_uvs=faces_uvs_padded, + verts_uvs=verts_uvs_padded, + ) + + # This is set inside Meshes when textures is passed as an input. + # Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity. + tex1 = tex.clone() + tex1._num_faces_per_mesh = ( + faces_uvs_padded.gt(-1).all(-1).sum(-1).tolist() + ) + tex1._num_verts_per_mesh = torch.tensor([5, 4]) + faces_packed = tex1.faces_uvs_packed() + verts_packed = tex1.verts_uvs_packed() + faces_list = tex1.faces_uvs_list() + verts_list = tex1.verts_uvs_list() + + for f1, f2 in zip(faces_uvs_list, faces_list): + self.assertTrue((f1 == f2).all().item()) + + for f, v1, v2 in zip(faces_list, verts_list, verts_uvs_list): + idx = f.unique() + self.assertTrue((v1[idx] == v2).all().item()) + + self.assertTrue(faces_packed.shape == (3 + 2, 3)) + + # verts_packed is just flattened verts_padded. + # split sizes are not used for verts_uvs. + self.assertTrue(verts_packed.shape == (9 * 2, 2)) + + # Case where num_faces_per_mesh is not set + tex2 = tex.clone() + faces_packed = tex2.faces_uvs_packed() + verts_packed = tex2.verts_uvs_packed() + faces_list = tex2.faces_uvs_list() + verts_list = tex2.verts_uvs_list() + + # Packed is just flattened padded as num_faces_per_mesh + # has not been provided. + self.assertTrue(verts_packed.shape == (9 * 2, 2)) + self.assertTrue(faces_packed.shape == (3 * 2, 3)) + + for i in range(N): + self.assertTrue( + (faces_list[i] == faces_uvs_padded[i, ...].squeeze()) + .all() + .item() + ) + + for i in range(N): + self.assertTrue( + (verts_list[i] == verts_uvs_padded[i, ...].squeeze()) + .all() + .item() + ) + def test_clone(self): V = 20 tex = Textures( @@ -233,13 +336,17 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): mesh = TestMeshes.init_mesh(B, 30, 50) V = mesh._V F = mesh._F - tex = Textures( + + # 1. Texture uvs + tex_uv = Textures( maps=torch.randn((B, 16, 16, 3)), faces_uvs=torch.randint(size=(B, F, 3), low=0, high=V), verts_uvs=torch.randn((B, V, 2)), ) tex_mesh = Meshes( - verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex + verts=mesh.verts_padded(), + faces=mesh.faces_padded(), + textures=tex_uv, ) N = 20 new_mesh = tex_mesh.extend(N) @@ -269,5 +376,43 @@ class TestTexturing(TestCaseMixin, unittest.TestCase): new_tex.maps_padded(), ] ) + + self.assertIsNone(new_tex.verts_rgb_list()) + self.assertIsNone(new_tex.verts_rgb_padded()) + self.assertIsNone(new_tex.verts_rgb_packed()) + + # 2. Texture vertex RGB + tex_rgb = Textures(verts_rgb=torch.randn((B, V, 3))) + tex_mesh_rgb = Meshes( + verts=mesh.verts_padded(), + faces=mesh.faces_padded(), + textures=tex_rgb, + ) + N = 20 + new_mesh_rgb = tex_mesh_rgb.extend(N) + + self.assertEqual(len(tex_mesh_rgb) * N, len(new_mesh_rgb)) + + tex_init = tex_mesh_rgb.textures + new_tex = new_mesh_rgb.textures + + for i in range(len(tex_mesh_rgb)): + for n in range(N): + self.assertClose( + tex_init.verts_rgb_list()[i], + new_tex.verts_rgb_list()[i * N + n], + ) + self.assertAllSeparate( + [tex_init.verts_rgb_padded(), new_tex.verts_rgb_padded()] + ) + + self.assertIsNone(new_tex.verts_uvs_padded()) + self.assertIsNone(new_tex.verts_uvs_list()) + self.assertIsNone(new_tex.verts_uvs_packed()) + self.assertIsNone(new_tex.faces_uvs_padded()) + self.assertIsNone(new_tex.faces_uvs_list()) + self.assertIsNone(new_tex.faces_uvs_packed()) + + # 3. Error with self.assertRaises(ValueError): tex_mesh.extend(N=-1)