VolumeSampler: expose padding_mode for inside out rendering (#1638)

Summary:
This exposes a setting on VolumeSampler so you can change the padding_mode. This is very useful when using cameras inside a volume that doesn't cover the entire world. By setting the value to `border` you can get much better behavior than `zeros` which causes edge effects for things like the sky. Border emulates infinitely tall buildings instead.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1638

Test Plan:
Tested with torchdrive

Example before:
![image](https://github.com/facebookresearch/pytorch3d/assets/909104/e99ffb7c-c4ba-40f8-b15c-ad5d1b53f0df)

Example after:
![image](https://github.com/facebookresearch/pytorch3d/assets/909104/f8d9821b-93d5-44b5-b9d4-c1670711ddce)

Reviewed By: MichaelRamamonjisoa

Differential Revision: D49384383

Pulled By: bottler

fbshipit-source-id: 202b526e07320a18944c39a148beec94c0f5d68c
This commit is contained in:
Tristan Rice 2023-09-20 08:00:02 -07:00 committed by Facebook GitHub Bot
parent 6f2212da46
commit b7f4ba097c

View File

@ -261,19 +261,28 @@ class VolumeSampler(torch.nn.Module):
at 3D points sampled along projection rays. at 3D points sampled along projection rays.
""" """
def __init__(self, volumes: Volumes, sample_mode: str = "bilinear") -> None: def __init__(
self,
volumes: Volumes,
sample_mode: str = "bilinear",
padding_mode: str = "zeros",
) -> None:
""" """
Args: Args:
volumes: An instance of the `Volumes` class representing a volumes: An instance of the `Volumes` class representing a
batch of volumes that are being rendered. batch of volumes that are being rendered.
sample_mode: Defines the algorithm used to sample the volumetric sample_mode: Defines the algorithm used to sample the volumetric
voxel grid. Can be either "bilinear" or "nearest". voxel grid. Can be either "bilinear" or "nearest".
padding_mode: How to handle values outside of the volume.
One of: zeros, border, reflection
See torch.nn.functional.grid_sample for more information.
""" """
super().__init__() super().__init__()
if not isinstance(volumes, Volumes): if not isinstance(volumes, Volumes):
raise ValueError("'volumes' have to be an instance of the 'Volumes' class.") raise ValueError("'volumes' have to be an instance of the 'Volumes' class.")
self._volumes = volumes self._volumes = volumes
self._sample_mode = sample_mode self._sample_mode = sample_mode
self._padding_mode = padding_mode
def _get_ray_directions_transform(self): def _get_ray_directions_transform(self):
""" """
@ -375,6 +384,7 @@ class VolumeSampler(torch.nn.Module):
rays_points_local_flat, rays_points_local_flat,
align_corners=True, align_corners=True,
mode=self._sample_mode, mode=self._sample_mode,
padding_mode=self._padding_mode,
) )
# permute the dimensions & reshape densities after sampling # permute the dimensions & reshape densities after sampling
@ -392,6 +402,7 @@ class VolumeSampler(torch.nn.Module):
rays_points_local_flat, rays_points_local_flat,
align_corners=True, align_corners=True,
mode=self._sample_mode, mode=self._sample_mode,
padding_mode=self._padding_mode,
) )
# permute the dimensions & reshape features after sampling # permute the dimensions & reshape features after sampling