diff --git a/pytorch3d/structures/textures.py b/pytorch3d/structures/textures.py index fecece4b..fa32ae38 100644 --- a/pytorch3d/structures/textures.py +++ b/pytorch3d/structures/textures.py @@ -3,7 +3,7 @@ from typing import List, Optional, Union import torch -import torchvision.transforms as T +from torch.nn.functional import interpolate from .utils import padded_to_list, padded_to_packed @@ -18,10 +18,10 @@ def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor: Pad all texture images so they have the same height and width. Args: - images: list of N tensors of shape (H, W) + images: list of N tensors of shape (H, W, 3) Returns: - tex_maps: Tensor of shape (N, max_H, max_W) + tex_maps: Tensor of shape (N, max_H, max_W, 3) """ tex_maps = [] max_H = 0 @@ -35,15 +35,13 @@ def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor: tex_maps.append(im) max_shape = (max_H, max_W) - # If all texture images are not the same size then resize to the - # largest size. - resize = T.Compose([T.ToPILImage(), T.Resize(size=max_shape), T.ToTensor()]) - for i, image in enumerate(tex_maps): - if image.shape != max_shape: - # ToPIL takes and returns a C x H x W tensor - image = resize(image.permute(2, 0, 1)).permute(1, 2, 0) - tex_maps[i] = image + if image.shape[:2] != max_shape: + image_BCHW = image.permute(2, 0, 1)[None] + new_image_BCHW = interpolate( + image_BCHW, size=max_shape, mode="bilinear", align_corners=False + ) + tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0) tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3) return tex_maps diff --git a/tests/test_obj_io.py b/tests/test_obj_io.py index 1084c7eb..3a40ba7f 100644 --- a/tests/test_obj_io.py +++ b/tests/test_obj_io.py @@ -607,6 +607,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): check_triple(mesh, mesh3) self.assertTupleEqual(mesh.textures.maps_padded().shape, (1, 1024, 1024, 3)) + # Try mismatched texture map sizes, which needs a call to interpolate() + mesh2048 = mesh.clone() + maps = mesh.textures.maps_padded() + mesh2048.textures._maps_padded = torch.cat([maps, maps], dim=1) + join_meshes_as_batch([mesh.to("cuda:0"), mesh2048.to("cuda:0")]) + mesh_notex = load_objs_as_meshes([obj_filename], load_textures=False) mesh3_notex = load_objs_as_meshes( [obj_filename, obj_filename, obj_filename], load_textures=False