diff --git a/pytorch3d/csrc/utils/pytorch3d_cutils.h b/pytorch3d/csrc/utils/pytorch3d_cutils.h index c88b7c53..f5f0c4a8 100644 --- a/pytorch3d/csrc/utils/pytorch3d_cutils.h +++ b/pytorch3d/csrc/utils/pytorch3d_cutils.h @@ -3,9 +3,9 @@ #pragma once #include -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x "must be a CUDA tensor.") +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.") #define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x "must be contiguous.") + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.") #define CHECK_CONTIGUOUS_CUDA(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) diff --git a/pytorch3d/datasets/shapenet_base.py b/pytorch3d/datasets/shapenet_base.py index dea66c15..2de233d0 100644 --- a/pytorch3d/datasets/shapenet_base.py +++ b/pytorch3d/datasets/shapenet_base.py @@ -7,10 +7,10 @@ from typing import Dict, List, Optional import torch from pytorch3d.io import load_objs_as_meshes from pytorch3d.renderer import ( + FoVPerspectiveCameras, HardPhongShader, MeshRasterizer, MeshRenderer, - FoVPerspectiveCameras, PointLights, RasterizationSettings, TexturesVertex, diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index 1eb184fe..3fbdd5c5 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -45,7 +45,10 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: # Mask for the background. is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W) - background_color = colors.new_tensor(blend_params.background_color) # (3) + if torch.is_tensor(blend_params.background_color): + background_color = blend_params.background_color + else: + background_color = colors.new_tensor(blend_params.background_color) # (3) # Find out how much background_color needs to be expanded to be used for masked_scatter. num_background_pixels = is_background.sum() diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index dab5aaff..d06e0619 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -137,7 +137,7 @@ def _pad_texture_maps( # This is also useful to have so that inside `Meshes` # we can allow the input textures to be any texture # type which is an instance of the base class. -class TexturesBase(object): +class TexturesBase: def __init__(self): self._N = 0 self.valid = None @@ -262,9 +262,6 @@ class TexturesBase(object): """ raise NotImplementedError() - def __repr__(self): - return "TexturesBase" - def Textures( maps: Union[List, torch.Tensor, None] = None, @@ -385,14 +382,6 @@ class TexturesAtlas(TexturesBase): # refer to the __init__ of Meshes. self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device) - # This is a hack to allow the child classes to also have the same representation - # as the parent. In meshes.py we check that the input textures have the correct - # type. However due to circular imports issues, we can't import the texture - # classes into any files in pytorch3d.structures. Instead we check - # for repr(textures) == "TexturesBase". - def __repr__(self): - return super().__repr__() - def clone(self): tex = self.__class__(atlas=self.atlas_padded().clone()) if self._atlas_list is not None: @@ -556,10 +545,7 @@ class TexturesUV(TexturesBase): [(H, W, 3)] or a padded tensor of shape (N, H, W, 3) faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each face verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex - - Note: only the padded and list representation of the textures are stored - and the packed representations is computed on the fly and - not cached. + (a FloatTensor with values between 0 and 1) """ super().__init__() if isinstance(faces_uvs, (list, tuple)): @@ -611,9 +597,6 @@ class TexturesUV(TexturesBase): "verts_uvs and faces_uvs must have the same batch dimension" ) if not all(v.device == self.device for v in verts_uvs): - import pdb - - pdb.set_trace() raise ValueError("verts_uvs and faces_uvs must be on the same device") # These values may be overridden when textures is @@ -669,9 +652,6 @@ class TexturesUV(TexturesBase): self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device) - def __repr__(self): - return super().__repr__() - def clone(self): tex = self.__class__( self.maps_padded().clone(), @@ -759,12 +739,6 @@ class TexturesUV(TexturesBase): ) return self._faces_uvs_list - def faces_uvs_packed(self) -> torch.Tensor: - if self.isempty(): - return torch.zeros((self._N, 3), dtype=torch.float32, device=self.device) - faces_uvs_list = self.faces_uvs_list() - return list_to_packed(faces_uvs_list)[0] - def verts_uvs_padded(self) -> torch.Tensor: if self._verts_uvs_padded is None: if self.isempty(): @@ -789,12 +763,6 @@ class TexturesUV(TexturesBase): ) return self._verts_uvs_list - def verts_uvs_packed(self) -> torch.Tensor: - if self.isempty(): - return torch.zeros((self._N, 2), dtype=torch.float32, device=self.device) - verts_uvs_list = self.verts_uvs_list() - return list_to_packed(verts_uvs_list)[0] - # Currently only the padded maps are used. def maps_padded(self) -> torch.Tensor: return self._maps_padded @@ -850,9 +818,15 @@ class TexturesUV(TexturesBase): texels: tensor of shape (N, H, W, K, C) giving the interpolated texture for each pixel in the rasterized image. """ - verts_uvs = self.verts_uvs_packed() - faces_uvs = self.faces_uvs_packed() - faces_verts_uvs = verts_uvs[faces_uvs] + if self.isempty(): + faces_verts_uvs = torch.zeros( + (self._N, 3, 2), dtype=torch.float32, device=self.device + ) + else: + packing_list = [ + i[j] for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list()) + ] + faces_verts_uvs = torch.cat(packing_list) texture_maps = self.maps_padded() # pixel_uvs: (N, H, W, K, 2) @@ -890,6 +864,7 @@ class TexturesUV(TexturesBase): if texture_maps.device != pixel_uvs.device: texture_maps = texture_maps.to(pixel_uvs.device) texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False) + # texels now has shape (NK, C, H_out, W_out) texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2) return texels @@ -990,9 +965,6 @@ class TexturesVertex(TexturesBase): # refer to the __init__ of Meshes. self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device) - def __repr__(self): - return super().__repr__() - def clone(self): tex = self.__class__(self.verts_features_padded().clone()) if self._verts_features_list is not None: @@ -1048,7 +1020,7 @@ class TexturesVertex(TexturesBase): if self._verts_features_list is None: if self.isempty(): self._verts_features_list = [ - torch.empty((0, 3, 0), dtype=torch.float32, device=self.device) + torch.empty((0, 3), dtype=torch.float32, device=self.device) ] * self._N else: self._verts_features_list = padded_to_list( diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index 4e4c1b0a..908839ee 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -233,9 +233,9 @@ class Meshes(object): Refer to comments above for descriptions of List and Padded representations. """ self.device = None - if textures is not None and not repr(textures) == "TexturesBase": + if textures is not None and not hasattr(textures, "sample_textures"): msg = "Expected textures to be an instance of type TexturesBase; got %r" - raise ValueError(msg % repr(textures)) + raise ValueError(msg % type(textures)) self.textures = textures # Indicates whether the meshes in the list/batch have the same number diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index 94e49862..e0535846 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -33,8 +33,9 @@ from pytorch3d.renderer.mesh.shader import ( SoftSilhouetteShader, TexturedSoftPhongShader, ) -from pytorch3d.structures.meshes import Meshes, join_mesh +from pytorch3d.structures.meshes import Meshes, join_mesh, join_meshes_as_batch from pytorch3d.utils.ico_sphere import ico_sphere +from pytorch3d.utils.torus import torus # If DEBUG=True, save out images generated in the tests for debugging. @@ -490,6 +491,86 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase): self.assertClose(rgb, image_ref, atol=0.05) + def test_batch_uvs(self): + """Test that two random tori with TexturesUV render the same as each individually.""" + torch.manual_seed(1) + device = torch.device("cuda:0") + plain_torus = torus(r=1, R=4, sides=10, rings=10, device=device) + [verts] = plain_torus.verts_list() + [faces] = plain_torus.faces_list() + nocolor = torch.zeros((100, 100), device=device) + color_gradient = torch.linspace(0, 1, steps=100, device=device) + color_gradient1 = color_gradient[None].expand_as(nocolor) + color_gradient2 = color_gradient[:, None].expand_as(nocolor) + colors1 = torch.stack([nocolor, color_gradient1, color_gradient2], dim=2) + colors2 = torch.stack([color_gradient1, color_gradient2, nocolor], dim=2) + verts_uvs1 = torch.rand(size=(verts.shape[0], 2), device=device) + verts_uvs2 = torch.rand(size=(verts.shape[0], 2), device=device) + + textures1 = TexturesUV( + maps=[colors1], faces_uvs=[faces], verts_uvs=[verts_uvs1] + ) + textures2 = TexturesUV( + maps=[colors2], faces_uvs=[faces], verts_uvs=[verts_uvs2] + ) + mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1) + mesh2 = Meshes(verts=[verts], faces=[faces], textures=textures2) + mesh_both = join_meshes_as_batch([mesh1, mesh2]) + + R, T = look_at_view_transform(10, 10, 0) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T) + + raster_settings = RasterizationSettings( + image_size=128, blur_radius=0.0, faces_per_pixel=1 + ) + + # Init shader settings + lights = PointLights(device=device) + lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None] + + blend_params = BlendParams( + sigma=1e-1, + gamma=1e-4, + background_color=torch.tensor([1.0, 1.0, 1.0], device=device), + ) + # Init renderer + renderer = MeshRenderer( + rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), + shader=HardPhongShader( + device=device, lights=lights, cameras=cameras, blend_params=blend_params + ), + ) + + outputs = [] + for meshes in [mesh_both, mesh1, mesh2]: + outputs.append(renderer(meshes)) + + if DEBUG: + Image.fromarray( + (outputs[0][0, ..., :3].cpu().numpy() * 255).astype(np.uint8) + ).save(DATA_DIR / "test_batch_uvs0.png") + Image.fromarray( + (outputs[1][0, ..., :3].cpu().numpy() * 255).astype(np.uint8) + ).save(DATA_DIR / "test_batch_uvs1.png") + Image.fromarray( + (outputs[0][1, ..., :3].cpu().numpy() * 255).astype(np.uint8) + ).save(DATA_DIR / "test_batch_uvs2.png") + Image.fromarray( + (outputs[2][0, ..., :3].cpu().numpy() * 255).astype(np.uint8) + ).save(DATA_DIR / "test_batch_uvs3.png") + + diff = torch.abs(outputs[0][0, ..., :3] - outputs[1][0, ..., :3]) + Image.fromarray(((diff > 1e-5).cpu().numpy().astype(np.uint8) * 255)).save( + DATA_DIR / "test_batch_uvs01.png" + ) + diff = torch.abs(outputs[0][1, ..., :3] - outputs[2][0, ..., :3]) + Image.fromarray(((diff > 1e-5).cpu().numpy().astype(np.uint8) * 255)).save( + DATA_DIR / "test_batch_uvs23.png" + ) + + self.assertClose(outputs[0][0, ..., :3], outputs[1][0, ..., :3], atol=1e-5) + self.assertClose(outputs[0][1, ..., :3], outputs[2][0, ..., :3], atol=1e-5) + def test_joined_spheres(self): """ Test a list of Meshes can be joined as a single mesh and diff --git a/tests/test_texturing.py b/tests/test_texturing.py index bcab7ac6..44d03ae4 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -29,8 +29,8 @@ def tryindex(self, index, tex, meshes, source): basic = basic[None] if len(basic) == 0: - self.assertEquals(len(from_texture), 0) - self.assertEquals(len(from_meshes), 0) + self.assertEqual(len(from_texture), 0) + self.assertEqual(len(from_meshes), 0) else: self.assertClose(basic, from_texture) self.assertClose(basic, from_meshes) @@ -608,12 +608,8 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): [ tex_init.faces_uvs_padded(), new_tex.faces_uvs_padded(), - tex_init.faces_uvs_packed(), - new_tex.faces_uvs_packed(), tex_init.verts_uvs_padded(), new_tex.verts_uvs_padded(), - tex_init.verts_uvs_packed(), - new_tex.verts_uvs_packed(), tex_init.maps_padded(), new_tex.maps_padded(), ] @@ -646,11 +642,9 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): tex1 = tex.clone() tex1._num_faces_per_mesh = num_faces_per_mesh tex1._num_verts_per_mesh = num_verts_per_mesh - verts_packed = tex1.verts_uvs_packed() verts_list = tex1.verts_uvs_list() verts_padded = tex1.verts_uvs_padded() - faces_packed = tex1.faces_uvs_packed() faces_list = tex1.faces_uvs_list() faces_padded = tex1.faces_uvs_padded() @@ -660,9 +654,7 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): for f1, f2 in zip(verts_list, verts_uvs_list): self.assertTrue((f1 == f2).all().item()) - self.assertTrue(faces_packed.shape == (3 + 2, 3)) self.assertTrue(faces_padded.shape == (2, 3, 3)) - self.assertTrue(verts_packed.shape == (9 + 6, 2)) self.assertTrue(verts_padded.shape == (2, 9, 2)) # Case where num_faces_per_mesh is not set and faces_verts_uvs @@ -672,16 +664,9 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): verts_uvs=verts_padded, faces_uvs=faces_padded, ) - faces_packed = tex2.faces_uvs_packed() faces_list = tex2.faces_uvs_list() - verts_packed = tex2.verts_uvs_packed() verts_list = tex2.verts_uvs_list() - # Packed is just flattened padded as num_faces_per_mesh - # has not been provided. - self.assertTrue(faces_packed.shape == (3 * 2, 3)) - self.assertTrue(verts_packed.shape == (9 * 2, 2)) - for i, (f1, f2) in enumerate(zip(faces_list, faces_uvs_list)): n = num_faces_per_mesh[i] self.assertTrue((f1[:n] == f2).all().item())