diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index 736be41d..ed9d6b73 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -1248,8 +1248,9 @@ class TexturesUV(TexturesBase): pixel_to_map_ids = ( maps_ids_padded.flatten() .gather(0, pix_to_face.flatten()) - .view(N, K, H_out, W_out) - ) + .view(N, H_out, W_out, K, 1) + .permute(0, 3, 1, 2, 4) + ) # N x H_out x W_out x K x 1 # Normalize between -1 and 1 with M (number of maps) pixel_to_map_ids = (2.0 * pixel_to_map_ids.float() / float(M - 1)) - 1 @@ -1258,10 +1259,10 @@ class TexturesUV(TexturesBase): pixel_uvs.new_tensor([-1.0, 1.0]), pixel_uvs.new_tensor([1.0, -1.0]), pixel_uvs, - ) + ) # N x H_out x W_out x K x 2 # N x H_out x W_out x K x 3 - pixel_uvms = torch.cat((pixel_uvs, pixel_to_map_ids.unsqueeze(4)), dim=4) + pixel_uvms = torch.cat((pixel_uvs, pixel_to_map_ids), dim=4) # (N, M, H, W, C) -> (N, C, M, H, W) texture_maps = texture_maps.permute(0, 4, 1, 2, 3) if texture_maps.device != pixel_uvs.device: