From 9eaed4c495e42527352c288053daa37b9c898af1 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Thu, 31 Oct 2024 16:05:10 -0700 Subject: [PATCH] Fix K>1 in multimap UV sampling Summary: Fixes https://github.com/facebookresearch/pytorch3d/issues/1897 "Wrong dimension on gather". Reviewed By: cijose Differential Revision: D65280675 fbshipit-source-id: 1d587036887972bb2a2ea56d40df19cbf1aeb6cc --- pytorch3d/renderer/mesh/textures.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index 736be41d..ed9d6b73 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -1248,8 +1248,9 @@ class TexturesUV(TexturesBase): pixel_to_map_ids = ( maps_ids_padded.flatten() .gather(0, pix_to_face.flatten()) - .view(N, K, H_out, W_out) - ) + .view(N, H_out, W_out, K, 1) + .permute(0, 3, 1, 2, 4) + ) # N x H_out x W_out x K x 1 # Normalize between -1 and 1 with M (number of maps) pixel_to_map_ids = (2.0 * pixel_to_map_ids.float() / float(M - 1)) - 1 @@ -1258,10 +1259,10 @@ class TexturesUV(TexturesBase): pixel_uvs.new_tensor([-1.0, 1.0]), pixel_uvs.new_tensor([1.0, -1.0]), pixel_uvs, - ) + ) # N x H_out x W_out x K x 2 # N x H_out x W_out x K x 3 - pixel_uvms = torch.cat((pixel_uvs, pixel_to_map_ids.unsqueeze(4)), dim=4) + pixel_uvms = torch.cat((pixel_uvs, pixel_to_map_ids), dim=4) # (N, M, H, W, C) -> (N, C, M, H, W) texture_maps = texture_maps.permute(0, 4, 1, 2, 3) if texture_maps.device != pixel_uvs.device: