mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 07:10:34 +08:00
Implicit/Volume renderer
Summary: Implements the `ImplicitRenderer` and `VolumeRenderer`. Reviewed By: gkioxari Differential Revision: D24418791 fbshipit-source-id: 127f21186d8e210895db1dcd0681f09f230d81a4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e6a32bfc37
commit
b466c381da
@@ -24,9 +24,12 @@ from .implicit import (
|
||||
AbsorptionOnlyRaymarcher,
|
||||
EmissionAbsorptionRaymarcher,
|
||||
GridRaysampler,
|
||||
ImplicitRenderer,
|
||||
MonteCarloRaysampler,
|
||||
NDCGridRaysampler,
|
||||
RayBundle,
|
||||
VolumeRenderer,
|
||||
VolumeSampler,
|
||||
ray_bundle_to_ray_points,
|
||||
ray_bundle_variables_to_ray_points,
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
|
||||
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
|
||||
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
|
||||
from .utils import (
|
||||
RayBundle,
|
||||
ray_bundle_to_ray_points,
|
||||
|
||||
372
pytorch3d/renderer/implicit/renderer.py
Normal file
372
pytorch3d/renderer/implicit/renderer.py
Normal file
@@ -0,0 +1,372 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
from typing import Callable, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ...ops.utils import eyes
|
||||
from ...structures import Volumes
|
||||
from ...transforms import Transform3d
|
||||
from ..cameras import CamerasBase
|
||||
from .raysampling import RayBundle
|
||||
from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points
|
||||
|
||||
|
||||
# The implicit renderer class should be initialized with a
|
||||
# function for raysampling and a function for raymarching.
|
||||
|
||||
# During the forward pass:
|
||||
# 1) The raysampler:
|
||||
# - samples rays from input cameras
|
||||
# - transforms the rays to world coordinates
|
||||
# 2) The volumetric_function (which is a callable argument of the forward pass)
|
||||
# evaluates ray_densities and ray_features at the sampled ray-points.
|
||||
# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching
|
||||
# algorithm to render each ray.
|
||||
|
||||
|
||||
class ImplicitRenderer(torch.nn.Module):
    """
    Renders a batch of scenes that are described implicitly by a callable.

    The renderer is assembled from two callables supplied at construction:
    a `raysampler`, which emits rendering rays for the input cameras, and
    a `raymarcher`, which turns per-point densities and features into the
    final render of every ray.

    VOLUMETRIC_FUNCTION

    The scene itself is passed to `forward` as the `volumetric_function`
    `Callable`. It defines a field of opacity and feature vectors over
    the 3D domain of the scene.

    `forward` invokes it with the keyword arguments `ray_bundle`, `cameras`
    and any extra `**kwargs` given to `forward`, so a standard
    `volumetric_function` has the following signature:
        ```
        def volumetric_function(
            ray_bundle: RayBundle, **kwargs
        ) -> Tuple[torch.Tensor, torch.Tensor]
        ```
    With the following arguments:
        `ray_bundle`: A RayBundle object containing the following variables:
            `origins`: A tensor of shape `(minibatch, ..., 3)` denoting
                the origins of the rendering rays.
            `directions`: A tensor of shape `(minibatch, ..., 3)`
                containing the direction vectors of rendering rays.
            `lengths`: A tensor of shape
                `(minibatch, ..., num_points_per_ray)` containing the
                lengths at which the ray points are sampled.
    Calling `volumetric_function` then returns the following:
        `rays_densities`: A tensor of shape
            `(minibatch, ..., num_points_per_ray, opacity_dim)` containing
            an opacity vector for each ray point.
        `rays_features`: A tensor of shape
            `(minibatch, ..., num_points_per_ray, feature_dim)` containing
            a feature vector for each ray point.

    Example:
        A simple volumetric function whose density decays with the distance
        from the origin and whose features are RGB colors derived from the
        direction of each ray point:
            ```
            def volumetric_function(
                ray_bundle: RayBundle,
                **kwargs,
            ) -> Tuple[torch.Tensor, torch.Tensor]:

                # first convert the ray origins, directions and lengths
                # to 3D ray point locations in world coords
                rays_points_world = ray_bundle_to_ray_points(ray_bundle)

                # set the densities as an inverse sigmoid of the
                # ray point distance from the origin
                rays_densities = torch.sigmoid(
                    -100.0 * rays_points_world.norm(dim=-1, keepdim=True)
                )

                # set the ray features to RGB colors proportional
                # to the normalized 3D location of the ray points
                rays_features = torch.nn.functional.normalize(
                    rays_points_world, dim=-1
                ) * 0.5 + 0.5

                return rays_densities, rays_features
            ```
    """

    def __init__(self, raysampler: Callable, raymarcher: Callable):
        """
        Args:
            raysampler: A `Callable` that takes as input scene cameras
                (an instance of `CamerasBase`) and returns a `RayBundle` that
                describes the rays emitted from the cameras.
            raymarcher: A `Callable` that receives the response of the
                `volumetric_function` (an input to `self.forward`) evaluated
                along the sampled rays, and renders the rays with a
                ray-marching algorithm.

        Raises:
            ValueError: If `raysampler` or `raymarcher` is not callable.
        """
        super().__init__()

        # Reject non-callables up front so that a misconfiguration surfaces
        # at construction time rather than in the middle of a forward pass.
        sampler_ok = callable(raysampler)
        if not sampler_ok:
            raise ValueError('"raysampler" has to be a "Callable" object.')
        marcher_ok = callable(raymarcher)
        if not marcher_ok:
            raise ValueError('"raymarcher" has to be a "Callable" object.')

        self.raysampler = raysampler
        self.raymarcher = raymarcher

    def forward(
        self, cameras: CamerasBase, volumetric_function: Callable, **kwargs
    ) -> Tuple[torch.Tensor, RayBundle]:
        """
        Render a batch of images using a volumetric function
        represented as a callable (e.g. a Pytorch module).

        Args:
            cameras: A batch of cameras that render the scene. `self.raysampler`
                takes the cameras as input and samples rays that pass through
                the domain of the volumetric function.
            volumetric_function: A `Callable` that accepts the parametrizations
                of the rendering rays and returns the densities and features
                at the respective 3D points of the rendering rays. Please
                refer to the main class documentation for details.
            **kwargs: Forwarded unchanged to the raysampler, to the
                `volumetric_function` and to the raymarcher.

        Returns:
            images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
                containing the result of the rendering.
            ray_bundle: A `RayBundle` containing the parametrizations of the
                sampled rendering rays.

        Raises:
            ValueError: If `volumetric_function` is not callable.
        """
        if not callable(volumetric_function):
            raise ValueError('"volumetric_function" has to be a "Callable" object.')

        # Stage 1: sample the rendering rays.
        #   sampled_rays.origins    - minibatch x ... x 3
        #   sampled_rays.directions - minibatch x ... x 3
        #   sampled_rays.lengths    - minibatch x ... x n_pts_per_ray
        #   sampled_rays.xys        - minibatch x ... x 2
        sampled_rays = self.raysampler(
            cameras=cameras, volumetric_function=volumetric_function, **kwargs
        )

        # Stage 2: query the scene at the ray points.
        #   densities - minibatch x ... x n_pts_per_ray x density_dim
        #   features  - minibatch x ... x n_pts_per_ray x feature_dim
        densities, features = volumetric_function(
            ray_bundle=sampled_rays, cameras=cameras, **kwargs
        )

        # Stage 3: march along the rays to produce the renders.
        #   rendered - minibatch x ... x (feature_dim + opacity_dim)
        rendered = self.raymarcher(
            rays_densities=densities,
            rays_features=features,
            ray_bundle=sampled_rays,
            **kwargs
        )

        return rendered, sampled_rays
|
||||
|
||||
|
||||
# The volume renderer class should be initialized with a
|
||||
# function for raysampling and a function for raymarching.
|
||||
|
||||
# During the forward pass:
|
||||
# 1) The raysampler:
|
||||
# - samples rays from input cameras
|
||||
# - transforms the rays to world coordinates
|
||||
# 2) The scene volumes (which are an argument of the forward function)
|
||||
# are then sampled at the locations of the ray-points to generate
|
||||
# ray_densities and ray_features.
|
||||
# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching
|
||||
# algorithm to render each ray.
|
||||
|
||||
|
||||
class VolumeRenderer(torch.nn.Module):
    """
    Renders a batch of `Volumes`.

    The class is a thin wrapper around `ImplicitRenderer`: on every forward
    pass the input volumes are wrapped in a `VolumeSampler`, which then plays
    the role of the renderer's volumetric function. It has to be initialized
    with a raysampler and a raymarcher, both of which have to be a `Callable`.
    """

    def __init__(
        self, raysampler: Callable, raymarcher: Callable, sample_mode: str = "bilinear"
    ):
        """
        Args:
            raysampler: A `Callable` that takes as input scene cameras
                (an instance of `CamerasBase`) and returns a `RayBundle` that
                describes the rays emitted from the cameras.
            raymarcher: A `Callable` that receives the `volumes`
                (an instance of `Volumes` input to `self.forward`)
                sampled at the ray-points, and renders the rays with a
                ray-marching algorithm.
            sample_mode: Defines the algorithm used to sample the volumetric
                voxel grid. Can be either "bilinear" or "nearest".
        """
        super().__init__()

        # remembered here and handed to every VolumeSampler built in forward
        self._sample_mode = sample_mode
        # the actual sampling -> evaluation -> marching pipeline
        self.renderer = ImplicitRenderer(raysampler, raymarcher)

    def forward(
        self, cameras: CamerasBase, volumes: Volumes, **kwargs
    ) -> Tuple[torch.Tensor, RayBundle]:
        """
        Render a batch of images using raymarching over rays cast through
        input `Volumes`.

        Args:
            cameras: A batch of cameras that render the scene. The raysampler
                takes the cameras as input and samples rays that pass through
                the domain of the volumes.
            volumes: An instance of the `Volumes` class representing a
                batch of volumes that are being rendered.
            **kwargs: Forwarded unchanged to the wrapped `ImplicitRenderer`.

        Returns:
            images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
                containing the result of the rendering.
            ray_bundle: A `RayBundle` containing the parametrizations of the
                sampled rendering rays.
        """
        # a fresh sampler per call, bound to the volumes of this batch
        volume_sampler = VolumeSampler(volumes, sample_mode=self._sample_mode)
        render_output = self.renderer(
            cameras=cameras, volumetric_function=volume_sampler, **kwargs
        )
        return render_output
|
||||
|
||||
|
||||
class VolumeSampler(torch.nn.Module):
    """
    A class that allows to sample a batch of volumes `Volumes`
    at 3D points sampled along projection rays.
    """

    def __init__(self, volumes: Volumes, sample_mode: str = "bilinear"):
        """
        Args:
            volumes: An instance of the `Volumes` class representing a
                batch of volumes that are being rendered.
            sample_mode: Defines the algorithm used to sample the volumetric
                voxel grid. Can be either "bilinear" or "nearest".

        Raises:
            ValueError: If `volumes` is not an instance of `Volumes`.
        """
        super().__init__()
        if not isinstance(volumes, Volumes):
            raise ValueError("'volumes' have to be an instance of the 'Volumes' class.")
        self._volumes = volumes
        self._sample_mode = sample_mode

    def _get_ray_directions_transform(self):
        """
        Compose the ray-directions transform by removing the translation component
        from the volume global-to-local coords transform.

        Returns:
            A `Transform3d` whose matrix is a batch of 4x4 identities with the
            top-left 3x3 block replaced by that of the world-to-local matrix.
        """
        world2local = self._volumes.get_world_to_local_coords_transform().get_matrix()
        # start from identities so that everything outside of the copied
        # 3x3 block (i.e. the translation) is dropped
        directions_transform_matrix = eyes(
            4,
            N=world2local.shape[0],
            device=world2local.device,
            dtype=world2local.dtype,
        )
        directions_transform_matrix[:, :3, :3] = world2local[:, :3, :3]
        directions_transform = Transform3d(matrix=directions_transform_matrix)
        return directions_transform

    def forward(
        self, ray_bundle: RayBundle, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Given an input ray parametrization, the forward function samples
        `self._volumes` at the respective 3D ray-points.

        Args:
            ray_bundle: A RayBundle object with the following fields:
                origins: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                directions: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in
                    world coords.
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            rays_densities: A tensor of shape
                `(minibatch, ..., num_points_per_ray, opacity_dim)` containing the
                density vectors sampled from the volume at the locations of
                the ray points.
            rays_features: A tensor of shape
                `(minibatch, ..., num_points_per_ray, feature_dim)` containing the
                feature vectors sampled from the volume at the locations of
                the ray points. For a feature-less volume the last
                dimension has size 0.

        Raises:
            ValueError: If the ray tensors have inconsistent shapes or if the
                batch size of the volumes differs from that of the rays.
        """

        # take out the interesting parts of ray_bundle
        rays_origins_world = ray_bundle.origins
        rays_directions_world = ray_bundle.directions
        rays_lengths = ray_bundle.lengths

        # validate the inputs
        _validate_ray_bundle_variables(
            rays_origins_world, rays_directions_world, rays_lengths
        )
        if self._volumes.densities().shape[0] != rays_origins_world.shape[0]:
            raise ValueError("Input volumes have to have the same batch size as rays.")

        #########################################################
        # 1) convert the origins/directions to the local coords #
        #########################################################

        # origins are mapped with the world_to_local transform of the volumes
        rays_origins_local = self._volumes.world_to_local_coords(rays_origins_world)

        # obtain the Transform3d object that transforms ray directions to local coords
        directions_transform = self._get_ray_directions_transform()

        # transform the directions to the local coords;
        # `reshape` (rather than `view`) is used so that non-contiguous
        # direction tensors (e.g. expanded or permuted ones) are accepted too
        rays_directions_local = directions_transform.transform_points(
            rays_directions_world.reshape(rays_lengths.shape[0], -1, 3)
        ).reshape(rays_directions_world.shape)

        ############################
        # 2) obtain the ray points #
        ############################

        # this op produces a fairly big tensor (minibatch, ..., n_samples_per_ray, 3)
        rays_points_local = ray_bundle_variables_to_ray_points(
            rays_origins_local, rays_directions_local, rays_lengths
        )

        ########################
        # 3) sample the volume #
        ########################

        # generate the tensor for sampling
        volumes_densities = self._volumes.densities()
        dim_density = volumes_densities.shape[1]
        volumes_features = self._volumes.features()
        # adjust the volumes_features variable in case we have a feature-less volume
        if volumes_features is None:
            dim_feature = 0
            data_to_sample = volumes_densities
        else:
            dim_feature = volumes_features.shape[1]
            # densities and features are stacked along the channel dimension
            # so that a single grid_sample call handles both
            data_to_sample = torch.cat((volumes_densities, volumes_features), dim=1)

        # reshape to a size which grid_sample likes
        rays_points_local_flat = rays_points_local.reshape(
            rays_points_local.shape[0], -1, 1, 1, 3
        )

        # run the grid sampler
        data_sampled = torch.nn.functional.grid_sample(
            data_to_sample,
            rays_points_local_flat,
            align_corners=True,
            mode=self._sample_mode,
        )

        # move the channel dimension last & restore the ray-point layout
        data_sampled = data_sampled.permute(0, 2, 3, 4, 1).reshape(
            *rays_points_local.shape[:-1], data_sampled.shape[1]
        )

        # split back to densities and features
        rays_densities, rays_features = data_sampled.split(
            [dim_density, dim_feature], dim=-1
        )

        return rays_densities, rays_features
|
||||
@@ -53,12 +53,12 @@ def ray_bundle_variables_to_ray_points(
|
||||
rays_lengths: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Converts rays parametrized with origins, directions
|
||||
Converts rays parametrized with origins and directions
|
||||
to 3D points by extending each ray according to the corresponding
|
||||
ray_length:
|
||||
ray length:
|
||||
|
||||
E.g. for 2 dimensional input tensors `rays_origins`, `rays_directions`
|
||||
and `rays_lengths`, the ray point at position `[i, j]` is:
|
||||
and `rays_lengths`, the ray point at position `[i, j]` is:
|
||||
```
|
||||
rays_points[i, j, :] = (
|
||||
rays_origins[i, :]
|
||||
@@ -80,3 +80,39 @@ def ray_bundle_variables_to_ray_points(
|
||||
+ rays_lengths[..., :, None] * rays_directions[..., None, :]
|
||||
)
|
||||
return rays_points
|
||||
|
||||
|
||||
def _validate_ray_bundle_variables(
    rays_origins: torch.Tensor,
    rays_directions: torch.Tensor,
    rays_lengths: torch.Tensor,
):
    """
    Validate the shapes of RayBundle variables
    `rays_origins`, `rays_directions`, and `rays_lengths`.

    The three tensors have to have the same number of dimensions (at least 3),
    identical shapes up to the last dimension, and the last dimension of
    `rays_origins` and `rays_directions` has to be of size 3.

    Args:
        rays_origins: The tensor of ray origins.
        rays_directions: The tensor of ray directions.
        rays_lengths: The tensor of lengths at which the rays are sampled.

    Raises:
        ValueError: If any of the conditions above is violated.
    """
    ndim = rays_origins.ndim
    if any(r.ndim != ndim for r in (rays_directions, rays_lengths)):
        raise ValueError(
            "rays_origins, rays_directions and rays_lengths"
            + " have to have the same number of dimensions."
        )

    if ndim <= 2:
        raise ValueError(
            "rays_origins, rays_directions and rays_lengths"
            + " have to have at least 3 dimensions."
        )

    # everything but the trailing dimension has to agree
    spatial_size = rays_origins.shape[:-1]
    if any(spatial_size != r.shape[:-1] for r in (rays_directions, rays_lengths)):
        raise ValueError(
            "The shapes of rays_origins, rays_directions and rays_lengths"
            + " may differ only in the last dimension."
        )

    if any(r.shape[-1] != 3 for r in (rays_origins, rays_directions)):
        # note the leading space: without it the two parts of the message
        # run together as "rays_directionshas to be 3."
        raise ValueError(
            "The size of the last dimension of rays_origins/rays_directions"
            + " has to be 3."
        )
|
||||
|
||||
Reference in New Issue
Block a user