Fix K>1 in multimap UV sampling

Summary:
Fixes https://github.com/facebookresearch/pytorch3d/issues/1897
"Wrong dimension on gather".

Reviewed By: cijose

Differential Revision: D65280675

fbshipit-source-id: 1d587036887972bb2a2ea56d40df19cbf1aeb6cc
This commit is contained in:
Jeremy Reizenstein 2024-10-31 16:05:10 -07:00 committed by Facebook GitHub Bot
parent e13848265d
commit 9eaed4c495

View File

@ -1248,8 +1248,9 @@ class TexturesUV(TexturesBase):
pixel_to_map_ids = (
maps_ids_padded.flatten()
.gather(0, pix_to_face.flatten())
.view(N, K, H_out, W_out)
)
.view(N, H_out, W_out, K, 1)
.permute(0, 3, 1, 2, 4)
) # N x H_out x W_out x K x 1
# Normalize between -1 and 1 with M (number of maps)
pixel_to_map_ids = (2.0 * pixel_to_map_ids.float() / float(M - 1)) - 1
@ -1258,10 +1259,10 @@ class TexturesUV(TexturesBase):
pixel_uvs.new_tensor([-1.0, 1.0]),
pixel_uvs.new_tensor([1.0, -1.0]),
pixel_uvs,
)
) # N x H_out x W_out x K x 2
# N x H_out x W_out x K x 3
pixel_uvms = torch.cat((pixel_uvs, pixel_to_map_ids.unsqueeze(4)), dim=4)
pixel_uvms = torch.cat((pixel_uvs, pixel_to_map_ids), dim=4)
# (N, M, H, W, C) -> (N, C, M, H, W)
texture_maps = texture_maps.permute(0, 4, 1, 2, 3)
if texture_maps.device != pixel_uvs.device: