diff --git a/pytorch3d/renderer/implicit/renderer.py b/pytorch3d/renderer/implicit/renderer.py index 1bed49c9..98fa5d28 100644 --- a/pytorch3d/renderer/implicit/renderer.py +++ b/pytorch3d/renderer/implicit/renderer.py @@ -261,19 +261,28 @@ class VolumeSampler(torch.nn.Module): at 3D points sampled along projection rays. """ - def __init__(self, volumes: Volumes, sample_mode: str = "bilinear") -> None: + def __init__( + self, + volumes: Volumes, + sample_mode: str = "bilinear", + padding_mode: str = "zeros", + ) -> None: """ Args: volumes: An instance of the `Volumes` class representing a batch of volumes that are being rendered. sample_mode: Defines the algorithm used to sample the volumetric voxel grid. Can be either "bilinear" or "nearest". + padding_mode: How to handle values outside of the volume. + One of: zeros, border, reflection + See torch.nn.functional.grid_sample for more information. """ super().__init__() if not isinstance(volumes, Volumes): raise ValueError("'volumes' have to be an instance of the 'Volumes' class.") self._volumes = volumes self._sample_mode = sample_mode + self._padding_mode = padding_mode def _get_ray_directions_transform(self): """ @@ -375,6 +384,7 @@ class VolumeSampler(torch.nn.Module): rays_points_local_flat, align_corners=True, mode=self._sample_mode, + padding_mode=self._padding_mode, ) # permute the dimensions & reshape densities after sampling @@ -392,6 +402,7 @@ class VolumeSampler(torch.nn.Module): rays_points_local_flat, align_corners=True, mode=self._sample_mode, + padding_mode=self._padding_mode, ) # permute the dimensions & reshape features after sampling