diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index b0a938ed..0d69450f 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -98,13 +98,14 @@ def _padded_to_list_wrapper( def _pad_texture_maps( - images: Union[Tuple[torch.Tensor], List[torch.Tensor]] + images: Union[Tuple[torch.Tensor], List[torch.Tensor]], align_corners: bool ) -> torch.Tensor: """ Pad all texture images so they have the same height and width. Args: - images: list of N tensors of shape (H, W, 3) + images: list of N tensors of shape (H_i, W_i, 3) + align_corners: used for interpolation Returns: tex_maps: Tensor of shape (N, max_H, max_W, 3) @@ -125,7 +126,7 @@ def _pad_texture_maps( if image.shape[:2] != max_shape: image_BCHW = image.permute(2, 0, 1)[None] new_image_BCHW = interpolate( - image_BCHW, size=max_shape, mode="bilinear", align_corners=False + image_BCHW, size=max_shape, mode="bilinear", align_corners=align_corners ) tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0) tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3) @@ -535,6 +536,8 @@ class TexturesUV(TexturesBase): maps: Union[torch.Tensor, List[torch.Tensor]], faces_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], + padding_mode: str = "border", + align_corners: bool = True, ): """ Textures are represented as a per mesh texture map and uv coordinates for each @@ -543,11 +546,42 @@ class TexturesUV(TexturesBase): Args: maps: texture map per mesh. This can either be a list of maps [(H, W, 3)] or a padded tensor of shape (N, H, W, 3) - faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each face + faces_uvs: (N, F, 3) LongTensor giving the index into verts_uvs + for each face verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex - (a FloatTensor with values between 0 and 1) + (a FloatTensor with values between 0 and 1). + align_corners: If true, the extreme values 0 and 1 for verts_uvs + indicate the centers of the edge pixels in the maps. + padding_mode: padding mode for outside grid values + ("zeros", "border" or "reflection"). + + The align_corners and padding_mode arguments correspond to the arguments + of the `grid_sample` torch function. There is an informative illustration of + the two align_corners options at + https://discuss.pytorch.org/t/22663/9 . + + An example of how the indexing into the maps, with align_corners=True + works is as follows. + If maps[i] has shape [101, 1001] and the value of verts_uvs[i][j] + is [0.4, 0.3], then a value of j in faces_uvs[i] means a vertex + whose color is given by maps[i][700, 40]. padding_mode affects what + happens if a value in verts_uvs is less than 0 or greater than 1. + Note that increasing a value in verts_uvs[..., 0] increases an index + in maps, whereas increasing a value in verts_uvs[..., 1] _decreases_ + an _earlier_ index in maps. + + If align_corners=False, an example would be as follows. + If maps[i] has shape [100, 1000] and the value of verts_uvs[i][j] + is [0.405, 0.2995], then a value of j in faces_uvs[i] means a vertex + whose color is given by maps[i][700, 40]. + In this case, padding_mode even matters for values in verts_uvs + slightly above 0 or slightly below 1. In this case, it matters if the + first value is outside the interval [0.0005, 0.9995] or if the second + is outside the interval [0.005, 0.995]. """ super().__init__() + self.padding_mode = padding_mode + self.align_corners = align_corners if isinstance(faces_uvs, (list, tuple)): for fv in faces_uvs: # pyre-fixme[16]: `Tensor` has no attribute `ndim`. @@ -632,7 +666,7 @@ class TexturesUV(TexturesBase): raise ValueError("Expected one texture map per mesh in the batch.") self._maps_list = maps if self._N > 0: - maps = _pad_texture_maps(maps) + maps = _pad_texture_maps(maps, align_corners=self.align_corners) else: maps = torch.empty( (self._N, 0, 0, 3), dtype=torch.float32, device=self.device @@ -698,11 +732,19 @@ class TexturesUV(TexturesBase): # if index has multiple values then faces/verts/maps may be a list of tensors if all(isinstance(f, (list, tuple)) for f in [faces_uvs, verts_uvs, maps]): new_tex = self.__class__( - faces_uvs=faces_uvs, verts_uvs=verts_uvs, maps=maps + faces_uvs=faces_uvs, + verts_uvs=verts_uvs, + maps=maps, + padding_mode=self.padding_mode, + align_corners=self.align_corners, ) elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]): new_tex = self.__class__( - faces_uvs=[faces_uvs], verts_uvs=[verts_uvs], maps=[maps] + faces_uvs=[faces_uvs], + verts_uvs=[verts_uvs], + maps=[maps], + padding_mode=self.padding_mode, + align_corners=self.align_corners, ) else: raise ValueError("Not all values are provided in the correct format") @@ -785,6 +827,8 @@ class TexturesUV(TexturesBase): maps=new_props["maps_padded"], faces_uvs=new_props["faces_uvs_padded"], verts_uvs=new_props["verts_uvs_padded"], + padding_mode=self.padding_mode, + align_corners=self.align_corners, ) new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"] @@ -859,7 +903,12 @@ class TexturesUV(TexturesBase): texture_maps = torch.flip(texture_maps, [2]) # flip y axis of the texture map if texture_maps.device != pixel_uvs.device: texture_maps = texture_maps.to(pixel_uvs.device) - texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False) + texels = F.grid_sample( + texture_maps, + pixel_uvs, + align_corners=self.align_corners, + padding_mode=self.padding_mode, + ) # texels now has shape (NK, C, H_out, W_out) texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2) return texels @@ -881,6 +930,17 @@ class TexturesUV(TexturesBase): if not tex_types_same: raise ValueError("All textures must be of type TexturesUV.") + padding_modes_same = all( + tex.padding_mode == self.padding_mode for tex in textures + ) + if not padding_modes_same: + raise ValueError("All textures must have the same padding_mode.") + align_corners_same = all( + tex.align_corners == self.align_corners for tex in textures + ) + if not align_corners_same: + raise ValueError("All textures must have the same align_corners value.") + verts_uvs_list = [] faces_uvs_list = [] maps_list = [] @@ -896,7 +956,11 @@ class TexturesUV(TexturesBase): maps_list += tex_map_list new_tex = self.__class__( - maps=maps_list, verts_uvs=verts_uvs_list, faces_uvs=faces_uvs_list + maps=maps_list, + verts_uvs=verts_uvs_list, + faces_uvs=faces_uvs_list, + padding_mode=self.padding_mode, + align_corners=self.align_corners, ) new_tex._num_faces_per_mesh = num_faces_per_mesh return new_tex diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index 908839ee..c621b6a7 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -1505,9 +1505,13 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True): Merge multiple Meshes objects, i.e. concatenate the meshes objects. They must all be on the same device. If include_textures is true, they must all be compatible, either all or none having textures, and all the Textures - objects having the same members. If include_textures is False, textures are + objects being the same type. If include_textures is False, textures are ignored. + If the textures are TexturesAtlas then being the same type includes having + the same resolution. If they are TexturesUV then it includes having the same + align_corners and padding_mode. + Args: meshes: list of meshes. include_textures: (bool) whether to try to join the textures. diff --git a/tests/data/test_blurry_textured_rendering.png b/tests/data/test_blurry_textured_rendering.png index 30a870ad..a71726b0 100644 Binary files a/tests/data/test_blurry_textured_rendering.png and b/tests/data/test_blurry_textured_rendering.png differ diff --git a/tests/data/test_texture_map_back.png b/tests/data/test_texture_map_back.png index 4052ce99..985d6962 100644 Binary files a/tests/data/test_texture_map_back.png and b/tests/data/test_texture_map_back.png differ diff --git a/tests/data/test_texture_map_front.png b/tests/data/test_texture_map_front.png index 44156560..e280182e 100644 Binary files a/tests/data/test_texture_map_front.png and b/tests/data/test_texture_map_front.png differ diff --git a/tests/test_meshes.py b/tests/test_meshes.py index 2117fcb9..7a6446f1 100644 --- a/tests/test_meshes.py +++ b/tests/test_meshes.py @@ -373,10 +373,10 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): self.assertFalse(new_mesh.verts_packed().requires_grad) self.assertClose(new_mesh.verts_packed(), mesh.verts_packed()) - self.assertTrue(new_mesh.verts_padded().requires_grad == False) + self.assertFalse(new_mesh.verts_padded().requires_grad) self.assertClose(new_mesh.verts_padded(), mesh.verts_padded()) for v, newv in zip(mesh.verts_list(), new_mesh.verts_list()): - self.assertTrue(newv.requires_grad == False) + self.assertFalse(newv.requires_grad) self.assertClose(newv, v) def test_laplacian_packed(self): diff --git a/tests/test_pointclouds.py b/tests/test_pointclouds.py index e1b77d2e..701254c1 100644 --- a/tests/test_pointclouds.py +++ b/tests/test_pointclouds.py @@ -411,11 +411,11 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): new_clouds = clouds.detach() for cloud in new_clouds.points_list(): - self.assertTrue(cloud.requires_grad == False) + self.assertFalse(cloud.requires_grad) for normal in new_clouds.normals_list(): - self.assertTrue(normal.requires_grad == False) + self.assertFalse(normal.requires_grad) for feats in new_clouds.features_list(): - self.assertTrue(feats.requires_grad == False) + self.assertFalse(feats.requires_grad) for attrib in [ "points_packed", @@ -425,9 +425,7 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): "normals_padded", "features_padded", ]: - self.assertTrue( - getattr(new_clouds, attrib)().requires_grad == False - ) + self.assertFalse(getattr(new_clouds, attrib)().requires_grad) self.assertCloudsEqual(clouds, new_clouds) diff --git a/tests/test_texturing.py b/tests/test_texturing.py index 3ed2950d..5e847073 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -443,19 +443,26 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): dists=pix_to_face, ) - tex = TexturesUV(maps=tex_map, faces_uvs=[face_uvs], verts_uvs=[vert_uvs]) - meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex) - mesh_textures = meshes.textures - texels = mesh_textures.sample_textures(fragments) + for align_corners in [True, False]: + tex = TexturesUV( + maps=tex_map, + faces_uvs=[face_uvs], + verts_uvs=[vert_uvs], + align_corners=align_corners, + ) + meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex) + mesh_textures = meshes.textures + texels = mesh_textures.sample_textures(fragments) - # Expected output - pixel_uvs = interpolated_uvs * 2.0 - 1.0 - pixel_uvs = pixel_uvs.view(2, 1, 1, 2) - tex_map = torch.flip(tex_map, [1]) - tex_map = tex_map.permute(0, 3, 1, 2) - tex_map = torch.cat([tex_map, tex_map], dim=0) - expected_out = F.grid_sample(tex_map, pixel_uvs, align_corners=False) - self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze())) + # Expected output + pixel_uvs = interpolated_uvs * 2.0 - 1.0 + pixel_uvs = pixel_uvs.view(2, 1, 1, 2) + tex_map_ = torch.flip(tex_map, [1]).permute(0, 3, 1, 2) + tex_map_ = torch.cat([tex_map_, tex_map_], dim=0) + expected_out = F.grid_sample( + tex_map_, pixel_uvs, align_corners=align_corners, padding_mode="border" + ) + self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze())) def test_textures_uv_init_fail(self): # Maps has wrong shape