Implicit/Volume renderer

Summary: Implements the `ImplicitRenderer` and `VolumeRenderer`.

Reviewed By: gkioxari

Differential Revision: D24418791

fbshipit-source-id: 127f21186d8e210895db1dcd0681f09f230d81a4
This commit is contained in:
David Novotny
2021-01-06 06:21:50 -08:00
committed by Facebook GitHub Bot
parent e6a32bfc37
commit b466c381da
8 changed files with 1575 additions and 3 deletions

View File

@@ -24,9 +24,12 @@ from .implicit import (
AbsorptionOnlyRaymarcher,
EmissionAbsorptionRaymarcher,
GridRaysampler,
ImplicitRenderer,
MonteCarloRaysampler,
NDCGridRaysampler,
RayBundle,
VolumeRenderer,
VolumeSampler,
ray_bundle_to_ray_points,
ray_bundle_variables_to_ray_points,
)

View File

@@ -2,6 +2,7 @@
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
from .utils import (
RayBundle,
ray_bundle_to_ray_points,

View File

@@ -0,0 +1,372 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Callable, Tuple
import torch
from ...ops.utils import eyes
from ...structures import Volumes
from ...transforms import Transform3d
from ..cameras import CamerasBase
from .raysampling import RayBundle
from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points
# The implicit renderer class should be initialized with a
# function for raysampling and a function for raymarching.
# During the forward pass:
# 1) The raysampler:
# - samples rays from input cameras
# - transforms the rays to world coordinates
# 2) The volumetric_function (which is a callable argument of the forwad pass)
# evaluates ray_densities and ray_features at the sampled ray-points.
# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching
# algorithm to render each ray.
class ImplicitRenderer(torch.nn.Module):
"""
A class for rendering a batch of implicit surfaces. The class should
be initialized with a raysampler and raymarcher class which both have
to be a `Callable`.
VOLUMETRIC_FUNCTION
The `forward` function of the renderer accepts as input the rendering cameras as well
as the `volumetric_function` `Callable`, which defines a field of opacity
and feature vectors over the 3D domain of the scene.
A standard `volumetric_function` has the following signature:
```
def volumetric_function(ray_bundle: RayBundle) -> Tuple[torch.Tensor, torch.Tensor]
```
With the following arguments:
`ray_bundle`: A RayBundle object containing the following variables:
`rays_origins`: A tensor of shape `(minibatch, ..., 3)` denoting
the origins of the rendering rays.
`rays_directions`: A tensor of shape `(minibatch, ..., 3)`
containing the direction vectors of rendering rays.
`rays_lengths`: A tensor of shape
`(minibatch, ..., num_points_per_ray)`containing the
lengths at which the ray points are sampled.
Calling `volumetric_function` then returns the following:
`rays_densities`: A tensor of shape
`(minibatch, ..., num_points_per_ray, opacity_dim)` containing
the an opacity vector for each ray point.
`rays_features`: A tensor of shape
`(minibatch, ..., num_points_per_ray, feature_dim)` containing
the an feature vector for each ray point.
Example:
A simple volumetric function of a 0-centered
RGB sphere with a unit diameter is defined as follows:
```
def volumetric_function(
ray_bundle: RayBundle,
) -> Tuple[torch.Tensor, torch.Tensor]:
# first convert the ray origins, directions and lengths
# to 3D ray point locations in world coords
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
# set the densities as an inverse sigmoid of the
# ray point distance from the sphere centroid
rays_densities = torch.sigmoid(
-100.0 * rays_points_world.norm(dim=-1, keepdim=True)
)
# set the ray features to RGB colors proportional
# to the 3D location of the projection of ray points
# on the sphere surface
rays_features = torch.nn.functional.normalize(
rays_points_world, dim=-1
) * 0.5 + 0.5
return rays_densities, rays_features
```
"""
def __init__(self, raysampler: Callable, raymarcher: Callable):
"""
Args:
raysampler: A `Callable` that takes as input scene cameras
(an instance of `CamerasBase`) and returns a `RayBundle` that
describes the rays emitted from the cameras.
raymarcher: A `Callable` that receives the response of the
`volumetric_function` (an input to `self.forward`) evaluated
along the sampled rays, and renders the rays with a
ray-marching algorithm.
"""
super().__init__()
if not callable(raysampler):
raise ValueError('"raysampler" has to be a "Callable" object.')
if not callable(raymarcher):
raise ValueError('"raymarcher" has to be a "Callable" object.')
self.raysampler = raysampler
self.raymarcher = raymarcher
def forward(
self, cameras: CamerasBase, volumetric_function: Callable, **kwargs
) -> Tuple[torch.Tensor, RayBundle]:
"""
Render a batch of images using a volumetric function
represented as a callable (e.g. a Pytorch module).
Args:
cameras: A batch of cameras that render the scene. A `self.raysampler`
takes the cameras as input and samples rays that pass through the
domain of the volumentric function.
volumetric_function: A `Callable` that accepts the parametrizations
of the rendering rays and returns the densities and features
at the respective 3D of the rendering rays. Please refer to
the main class documentation for details.
Returns:
images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
containing the result of the rendering.
ray_bundle: A `RayBundle` containing the parametrizations of the
sampled rendering rays.
"""
if not callable(volumetric_function):
raise ValueError('"volumetric_function" has to be a "Callable" object.')
# first call the ray sampler that returns the RayBundle parametrizing
# the rendering rays.
ray_bundle = self.raysampler(
cameras=cameras, volumetric_function=volumetric_function, **kwargs
)
# ray_bundle.origins - minibatch x ... x 3
# ray_bundle.directions - minibatch x ... x 3
# ray_bundle.lengths - minibatch x ... x n_pts_per_ray
# ray_bundle.xys - minibatch x ... x 2
# given sampled rays, call the volumetric function that
# evaluates the densities and features at the locations of the
# ray points
rays_densities, rays_features = volumetric_function(
ray_bundle=ray_bundle, cameras=cameras, **kwargs
)
# ray_densities - minibatch x ... x n_pts_per_ray x density_dim
# ray_features - minibatch x ... x n_pts_per_ray x feature_dim
# finally, march along the sampled rays to obtain the renders
images = self.raymarcher(
rays_densities=rays_densities,
rays_features=rays_features,
ray_bundle=ray_bundle,
**kwargs
)
# images - minibatch x ... x (feature_dim + opacity_dim)
return images, ray_bundle
# The volume renderer class should be initialized with a
# function for raysampling and a function for raymarching.
# During the forward pass:
# 1) The raysampler:
# - samples rays from input cameras
# - transforms the rays to world coordinates
# 2) The scene volumes (which are an argument of the forward function)
# are then sampled at the locations of the ray-points to generate
# ray_densities and ray_features.
# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching
# algorithm to render each ray.
class VolumeRenderer(torch.nn.Module):
"""
A class for rendering a batch of Volumes. The class should
be initialized with a raysampler and a raymarcher class which both have
to be a `Callable`.
"""
def __init__(
self, raysampler: Callable, raymarcher: Callable, sample_mode: str = "bilinear"
):
"""
Args:
raysampler: A `Callable` that takes as input scene cameras
(an instance of `CamerasBase`) and returns a `RayBundle` that
describes the rays emitted from the cameras.
raymarcher: A `Callable` that receives the `volumes`
(an instance of `Volumes` input to `self.forward`)
sampled at the ray-points, and renders the rays with a
ray-marching algorithm.
sample_mode: Defines the algorithm used to sample the volumetric
voxel grid. Can be either "bilinear" or "nearest".
"""
super().__init__()
self.renderer = ImplicitRenderer(raysampler, raymarcher)
self._sample_mode = sample_mode
def forward(
self, cameras: CamerasBase, volumes: Volumes, **kwargs
) -> Tuple[torch.Tensor, RayBundle]:
"""
Render a batch of images using raymarching over rays cast through
input `Volumes`.
Args:
cameras: A batch of cameras that render the scene. A `self.raysampler`
takes the cameras as input and samples rays that pass through the
domain of the volumentric function.
volumes: An instance of the `Volumes` class representing a
batch of volumes that are being rendered.
Returns:
images: A tensor of shape `(minibatch, ..., (feature_dim + opacity_dim)`
containing the result of the rendering.
ray_bundle: A `RayBundle` containing the parametrizations of the
sampled rendering rays.
"""
volumetric_function = VolumeSampler(volumes, sample_mode=self._sample_mode)
return self.renderer(
cameras=cameras, volumetric_function=volumetric_function, **kwargs
)
class VolumeSampler(torch.nn.Module):
"""
A class that allows to sample a batch of volumes `Volumes`
at 3D points sampled along projection rays.
"""
def __init__(self, volumes: Volumes, sample_mode: str = "bilinear"):
"""
Args:
volumes: An instance of the `Volumes` class representing a
batch if volumes that are being rendered.
sample_mode: Defines the algorithm used to sample the volumetric
voxel grid. Can be either "bilinear" or "nearest".
"""
super().__init__()
if not isinstance(volumes, Volumes):
raise ValueError("'volumes' have to be an instance of the 'Volumes' class.")
self._volumes = volumes
self._sample_mode = sample_mode
def _get_ray_directions_transform(self):
"""
Compose the ray-directions transform by removing the translation component
from the volume global-to-local coords transform.
"""
world2local = self._volumes.get_world_to_local_coords_transform().get_matrix()
directions_transform_matrix = eyes(
4,
N=world2local.shape[0],
device=world2local.device,
dtype=world2local.dtype,
)
directions_transform_matrix[:, :3, :3] = world2local[:, :3, :3]
directions_transform = Transform3d(matrix=directions_transform_matrix)
return directions_transform
def forward(
self, ray_bundle: RayBundle, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Given an input ray parametrization, the forward function samples
`self._volumes` at the respective 3D ray-points.
Args:
ray_bundle: A RayBundle object with the following fields:
rays_origins_world: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords.
rays_directions_world: A tensor of shape `(minibatch, ..., 3)`
containing the direction vectors of sampling rays in world coords.
rays_lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
containing the lengths at which the rays are sampled.
Returns:
rays_densities: A tensor of shape
`(minibatch, ..., num_points_per_ray, opacity_dim)` containing the
densitity vectors sampled from the volume at the locations of
the ray points.
rays_features: A tensor of shape
`(minibatch, ..., num_points_per_ray, feature_dim)` containing the
feature vectors sampled from the volume at the locations of
the ray points.
"""
# take out the interesting parts of ray_bundle
rays_origins_world = ray_bundle.origins
rays_directions_world = ray_bundle.directions
rays_lengths = ray_bundle.lengths
# validate the inputs
_validate_ray_bundle_variables(
rays_origins_world, rays_directions_world, rays_lengths
)
if self._volumes.densities().shape[0] != rays_origins_world.shape[0]:
raise ValueError("Input volumes have to have the same batch size as rays.")
#########################################################
# 1) convert the origins/directions to the local coords #
#########################################################
# origins are mapped with the world_to_local transform of the volumes
rays_origins_local = self._volumes.world_to_local_coords(rays_origins_world)
# obtain the Transform3d object that transforms ray directions to local coords
directions_transform = self._get_ray_directions_transform()
# transform the directions to the local coords
rays_directions_local = directions_transform.transform_points(
rays_directions_world.view(rays_lengths.shape[0], -1, 3)
).view(rays_directions_world.shape)
############################
# 2) obtain the ray points #
############################
# this op produces a fairly big tensor (minibatch, ..., n_samples_per_ray, 3)
rays_points_local = ray_bundle_variables_to_ray_points(
rays_origins_local, rays_directions_local, rays_lengths
)
########################
# 3) sample the volume #
########################
# generate the tensor for sampling
volumes_densities = self._volumes.densities()
dim_density = volumes_densities.shape[1]
volumes_features = self._volumes.features()
# adjust the volumes_features variable in case we have a feature-less volume
if volumes_features is None:
dim_feature = 0
data_to_sample = volumes_densities
else:
dim_feature = volumes_features.shape[1]
data_to_sample = torch.cat((volumes_densities, volumes_features), dim=1)
# reshape to a size which grid_sample likes
rays_points_local_flat = rays_points_local.view(
rays_points_local.shape[0], -1, 1, 1, 3
)
# run the grid sampler
data_sampled = torch.nn.functional.grid_sample(
data_to_sample,
rays_points_local_flat,
align_corners=True,
mode=self._sample_mode,
)
# permute the dimensions & reshape after sampling
data_sampled = data_sampled.permute(0, 2, 3, 4, 1).view(
*rays_points_local.shape[:-1], data_sampled.shape[1]
)
# split back to densities and features
rays_densities, rays_features = data_sampled.split(
[dim_density, dim_feature], dim=-1
)
return rays_densities, rays_features

View File

@@ -53,12 +53,12 @@ def ray_bundle_variables_to_ray_points(
rays_lengths: torch.Tensor,
) -> torch.Tensor:
"""
Converts rays parametrized with origins, directions
Converts rays parametrized with origins and directions
to 3D points by extending each ray according to the corresponding
ray_length:
ray length:
E.g. for 2 dimensional input tensors `rays_origins`, `rays_directions`
and `rays_lengths`, the ray point at position `[i, j]` is:
and `rays_lengths`, the ray point at position `[i, j]` is:
```
rays_points[i, j, :] = (
rays_origins[i, :]
@@ -80,3 +80,39 @@ def ray_bundle_variables_to_ray_points(
+ rays_lengths[..., :, None] * rays_directions[..., None, :]
)
return rays_points
def _validate_ray_bundle_variables(
rays_origins: torch.Tensor,
rays_directions: torch.Tensor,
rays_lengths: torch.Tensor,
):
"""
Validate the shapes of RayBundle variables
`rays_origins`, `rays_directions`, and `rays_lengths`.
"""
ndim = rays_origins.ndim
if any(r.ndim != ndim for r in (rays_directions, rays_lengths)):
raise ValueError(
"rays_origins, rays_directions and rays_lengths"
+ " have to have the same number of dimensions."
)
if ndim <= 2:
raise ValueError(
"rays_origins, rays_directions and rays_lengths"
+ " have to have at least 3 dimensions."
)
spatial_size = rays_origins.shape[:-1]
if any(spatial_size != r.shape[:-1] for r in (rays_directions, rays_lengths)):
raise ValueError(
"The shapes of rays_origins, rays_directions and rays_lengths"
+ " may differ only in the last dimension."
)
if any(r.shape[-1] != 3 for r in (rays_origins, rays_directions)):
raise ValueError(
"The size of the last dimension of rays_origins/rays_directions"
+ "has to be 3."
)