From 0c595dcf5b715ea1321825e0fe92ffa74e798d4f Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Fri, 1 May 2020 05:17:44 -0700 Subject: [PATCH] Joining mismatched texture maps on CUDA #175 Summary: Use nn.functional.interpolate instead of a TorchVision transform to resize texture maps to a common value. This works on all devices. This fixes issue #175. Also fix the condition so it only happens when needed. Reviewed By: nikhilaravi Differential Revision: D21324510 fbshipit-source-id: c50eb06514984995bd81f2c44079be6e0b4098e4 --- pytorch3d/structures/textures.py | 20 +++++++++----------- tests/test_obj_io.py | 6 ++++++ 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/pytorch3d/structures/textures.py b/pytorch3d/structures/textures.py index fecece4b..fa32ae38 100644 --- a/pytorch3d/structures/textures.py +++ b/pytorch3d/structures/textures.py @@ -3,7 +3,7 @@ from typing import List, Optional, Union import torch -import torchvision.transforms as T +from torch.nn.functional import interpolate from .utils import padded_to_list, padded_to_packed @@ -18,10 +18,10 @@ def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor: Pad all texture images so they have the same height and width. Args: - images: list of N tensors of shape (H, W) + images: list of N tensors of shape (H, W, 3) Returns: - tex_maps: Tensor of shape (N, max_H, max_W) + tex_maps: Tensor of shape (N, max_H, max_W, 3) """ tex_maps = [] max_H = 0 @@ -35,15 +35,13 @@ def _pad_texture_maps(images: List[torch.Tensor]) -> torch.Tensor: tex_maps.append(im) max_shape = (max_H, max_W) - # If all texture images are not the same size then resize to the - # largest size. - resize = T.Compose([T.ToPILImage(), T.Resize(size=max_shape), T.ToTensor()]) - for i, image in enumerate(tex_maps): - if image.shape != max_shape: - # ToPIL takes and returns a C x H x W tensor - image = resize(image.permute(2, 0, 1)).permute(1, 2, 0) - tex_maps[i] = image + if image.shape[:2] != max_shape: + image_BCHW = image.permute(2, 0, 1)[None] + new_image_BCHW = interpolate( + image_BCHW, size=max_shape, mode="bilinear", align_corners=False + ) + tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0) tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3) return tex_maps diff --git a/tests/test_obj_io.py b/tests/test_obj_io.py index 1084c7eb..3a40ba7f 100644 --- a/tests/test_obj_io.py +++ b/tests/test_obj_io.py @@ -607,6 +607,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): check_triple(mesh, mesh3) self.assertTupleEqual(mesh.textures.maps_padded().shape, (1, 1024, 1024, 3)) + # Try mismatched texture map sizes, which needs a call to interpolate() + mesh2048 = mesh.clone() + maps = mesh.textures.maps_padded() + mesh2048.textures._maps_padded = torch.cat([maps, maps], dim=1) + join_meshes_as_batch([mesh.to("cuda:0"), mesh2048.to("cuda:0")]) + mesh_notex = load_objs_as_meshes([obj_filename], load_textures=False) mesh3_notex = load_objs_as_meshes( [obj_filename, obj_filename, obj_filename], load_textures=False