diff --git a/pytorch3d/renderer/implicit/renderer.py b/pytorch3d/renderer/implicit/renderer.py index 56583cdb..635a42c7 100644 --- a/pytorch3d/renderer/implicit/renderer.py +++ b/pytorch3d/renderer/implicit/renderer.py @@ -363,35 +363,40 @@ class VolumeSampler(torch.nn.Module): volumes_densities = self._volumes.densities() dim_density = volumes_densities.shape[1] volumes_features = self._volumes.features() - # adjust the volumes_features variable in case we have a feature-less volume - if volumes_features is None: - dim_feature = 0 - data_to_sample = volumes_densities - else: - dim_feature = volumes_features.shape[1] - data_to_sample = torch.cat((volumes_densities, volumes_features), dim=1) # reshape to a size which grid_sample likes rays_points_local_flat = rays_points_local.view( rays_points_local.shape[0], -1, 1, 1, 3 ) - # run the grid sampler - data_sampled = torch.nn.functional.grid_sample( - data_to_sample, + # run the grid sampler on the volumes densities + rays_densities = torch.nn.functional.grid_sample( + volumes_densities, rays_points_local_flat, align_corners=True, mode=self._sample_mode, ) - # permute the dimensions & reshape after sampling - data_sampled = data_sampled.permute(0, 2, 3, 4, 1).view( - *rays_points_local.shape[:-1], data_sampled.shape[1] + # permute the dimensions & reshape densities after sampling + rays_densities = rays_densities.permute(0, 2, 3, 4, 1).view( + *rays_points_local.shape[:-1], volumes_densities.shape[1] ) - # split back to densities and features - rays_densities, rays_features = data_sampled.split( - [dim_density, dim_feature], dim=-1 - ) + # if features exist, run grid sampler again on the features densities + if volumes_features is None: + dim_feature = 0 + _, rays_features = rays_densities.split([dim_density, dim_feature], dim=-1) + else: + rays_features = torch.nn.functional.grid_sample( + volumes_features, + rays_points_local_flat, + align_corners=True, + mode=self._sample_mode, + ) + + # permute the dimensions & reshape features after sampling + rays_features = rays_features.permute(0, 2, 3, 4, 1).view( + *rays_points_local.shape[:-1], volumes_features.shape[1] + ) return rays_densities, rays_features