Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-02 03:42:50 +08:00
CUDA streams for color/density in NeRF
Summary: Inside the implicit function, the color and density calculations are independent, so time is saved by running them on separate CUDA streams. (In fact, the color calculation is slower than the density calculation, and the raymarcher does some work with the densities before it needs the colors. So in theory we could go further and not join the streams until the colors are actually needed. The code would be more complicated, and the profile suggests that the raymarcher is quick, so this wouldn't be expected to make a big difference.) In inference, using two streams might increase memory usage, so it isn't an obvious win. That is why I have added a flag.

Reviewed By: nikhilaravi

Differential Revision: D28648549

fbshipit-source-id: c087de80d8ccfce1dad3a13e71df2f305a36952e
parent f63e49d245
commit 280fed3c76
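The summary describes launching the density and color branches on two CUDA streams and joining them before the results are consumed. The following is a minimal standalone sketch of that pattern, not part of the commit; fn_a and fn_b are hypothetical stand-ins for the density and color heads of the shared MLP.

import torch


def run_on_two_streams(x, fn_a, fn_b):
    # Run two independent computations on `x`, overlapping them on separate
    # CUDA streams, and join before returning.
    if not x.is_cuda:
        # No streams on CPU; just run serially.
        return fn_a(x), fn_b(x)

    current_stream = torch.cuda.current_stream(x.device)
    other_stream = torch.cuda.Stream(x.device)
    # The side stream must wait for the work that produced `x`,
    # which is already queued on the current stream.
    other_stream.wait_stream(current_stream)

    with torch.cuda.stream(other_stream):
        out_a = fn_a(x)  # runs on the side stream

    out_b = fn_b(x)  # runs concurrently on the current stream

    # Join: nothing queued on the current stream after this point may start
    # before the side stream has finished producing `out_a`.
    current_stream.wait_stream(other_stream)
    return out_a, out_b


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1024, 256, device=device)
    a, b = run_on_two_streams(
        x, lambda t: torch.relu(t).sum(-1), lambda t: torch.sigmoid(t).mean(-1)
    )
    print(a.shape, b.shape)

The commit applies exactly this join-immediately form inside the implicit function, as the diff below shows.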
@@ -1,5 +1,5 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-from typing import List
+from typing import Tuple

 import torch
 from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points
@@ -23,7 +23,8 @@ class NeuralRadianceField(torch.nn.Module):
         n_hidden_neurons_xyz: int = 256,
         n_hidden_neurons_dir: int = 128,
         n_layers_xyz: int = 8,
-        append_xyz: List[int] = (5,),
+        append_xyz: Tuple[int] = (5,),
+        use_multiple_streams: bool = True,
         **kwargs,
     ):
         """
@@ -43,6 +44,8 @@ class NeuralRadianceField(torch.nn.Module):
             n_layers_xyz: The number of layers of the MLP that outputs the
                 occupancy field.
             append_xyz: The list of indices of the skip layers of the occupancy MLP.
+            use_multiple_streams: Whether density and color should be calculated on
+                separate CUDA streams.
         """
         super().__init__()

@@ -83,13 +86,14 @@ class NeuralRadianceField(torch.nn.Module):
             torch.nn.Linear(n_hidden_neurons_dir, 3),
             torch.nn.Sigmoid(),
         )
+        self.use_multiple_streams = use_multiple_streams

     def _get_densities(
         self,
         features: torch.Tensor,
         depth_values: torch.Tensor,
         density_noise_std: float,
-    ):
+    ) -> torch.Tensor:
         """
         This function takes `features` predicted by `self.mlp_xyz`
         and converts them to `raw_densities` with `self.density_layer`.
@@ -111,7 +115,9 @@ class NeuralRadianceField(torch.nn.Module):
         densities = 1 - (-deltas * torch.relu(raw_densities)).exp()
         return densities

-    def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
+    def _get_colors(
+        self, features: torch.Tensor, rays_directions: torch.Tensor
+    ) -> torch.Tensor:
         """
         This function takes per-point `features` predicted by `self.mlp_xyz`
         and evaluates the color model in order to attach to each
@@ -125,12 +131,54 @@ class NeuralRadianceField(torch.nn.Module):

         return self.color_layer((self.intermediate_linear(features), rays_embedding))

+    def _get_densities_and_colors(
+        self, features: torch.Tensor, ray_bundle: RayBundle, density_noise_std: float
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        The second part of the forward calculation.
+
+        Args:
+            features: the output of the common mlp (the prior part of the
+                calculation), shape
+                (minibatch x ... x self.n_hidden_neurons_xyz).
+            ray_bundle: As for forward().
+            density_noise_std: As for forward().
+
+        Returns:
+            rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
+                denoting the opacity of each ray point.
+            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
+                denoting the color of each ray point.
+        """
+        if self.use_multiple_streams and features.is_cuda:
+            current_stream = torch.cuda.current_stream(features.device)
+            other_stream = torch.cuda.Stream(features.device)
+            other_stream.wait_stream(current_stream)
+
+            with torch.cuda.stream(other_stream):
+                rays_densities = self._get_densities(
+                    features, ray_bundle.lengths, density_noise_std
+                )
+                # rays_densities.shape = [minibatch x ... x 1] in [0-1]
+
+            rays_colors = self._get_colors(features, ray_bundle.directions)
+            # rays_colors.shape = [minibatch x ... x 3] in [0-1]
+
+            current_stream.wait_stream(other_stream)
+        else:
+            # Same calculation as above, just serial.
+            rays_densities = self._get_densities(
+                features, ray_bundle.lengths, density_noise_std
+            )
+            rays_colors = self._get_colors(features, ray_bundle.directions)
+        return rays_densities, rays_colors
+
     def forward(
         self,
         ray_bundle: RayBundle,
         density_noise_std: float = 0.0,
         **kwargs,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         The forward function accepts the parametrizations of
         3D points sampled along projection rays. The forward
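The hunk above joins the streams immediately after both branches are launched. The summary also mentions an unimplemented variant: since the raymarcher works only with the densities at first, the join could be deferred until the colors are actually consumed. The following is a hedged sketch of that idea, not what the commit does; compute_colors is a hypothetical callable and the join is left to the caller.

import torch


def colors_on_side_stream(features, compute_colors):
    # Launch the (slower) color computation on a side stream and return the
    # stream so the caller can keep doing density-only work and call
    # torch.cuda.current_stream().wait_stream(color_stream) only when the
    # colors are actually needed.
    current_stream = torch.cuda.current_stream(features.device)
    color_stream = torch.cuda.Stream(features.device)
    color_stream.wait_stream(current_stream)

    with torch.cuda.stream(color_stream):
        rays_colors = compute_colors(features)  # runs on the side stream

    # Not yet synchronized: the caller must wait on color_stream before
    # reading rays_colors.
    return rays_colors, color_stream

A production version would also need to manage tensor lifetime across streams (for example via Tensor.record_stream), which adds to the complication the summary mentions.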
@@ -169,14 +217,9 @@ class NeuralRadianceField(torch.nn.Module):
         features = self.mlp_xyz(embeds_xyz, embeds_xyz)
         # features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]

-        rays_densities = self._get_densities(
-            features, ray_bundle.lengths, density_noise_std
+        rays_densities, rays_colors = self._get_densities_and_colors(
+            features, ray_bundle, density_noise_std
         )
-        # rays_densities.shape = [minibatch x ... x 1] in [0-1]
-
-        rays_colors = self._get_colors(features, ray_bundle.directions)
-        # rays_colors.shape = [minibatch x ... x 3] in [0-1]
-
         return rays_densities, rays_colors


@@ -207,7 +250,7 @@ class MLPWithInputSkips(torch.nn.Module):
         output_dim: int,
         skip_dim: int,
         hidden_dim: int,
-        input_skips: List[int] = (),
+        input_skips: Tuple[int] = (),
     ):
         """
         Args:
@@ -238,7 +281,7 @@ class MLPWithInputSkips(torch.nn.Module):
         self.mlp = torch.nn.ModuleList(layers)
         self._input_skips = set(input_skips)

-    def forward(self, x, z):
+    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
         """
         Args:
             x: The input tensor of shape `(..., input_dim)`.

@@ -62,7 +62,7 @@ class RadianceFieldRenderer(torch.nn.Module):
         n_hidden_neurons_xyz: int = 256,
         n_hidden_neurons_dir: int = 128,
         n_layers_xyz: int = 8,
-        append_xyz: List[int] = (5,),
+        append_xyz: Tuple[int] = (5,),
         density_noise_std: float = 0.0,
     ):
         """
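Finally, a hypothetical usage sketch of the new flag. The import path, the ray bundle shapes, and the assumption that constructor parameters not shown in the diff keep default values are all assumptions; only use_multiple_streams comes from the change itself.

import torch
from pytorch3d.renderer import RayBundle

# Assumed import path for the NeRF project code shown in the diff.
from nerf.implicit_function import NeuralRadianceField

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# use_multiple_streams=False forces the serial path, which the summary
# suggests may be preferable when inference memory is tight.
model = NeuralRadianceField(use_multiple_streams=True).to(device)

batch, n_rays, n_pts = 2, 128, 64
ray_bundle = RayBundle(
    origins=torch.randn(batch, n_rays, 3, device=device),
    directions=torch.randn(batch, n_rays, 3, device=device),
    lengths=torch.linspace(0.1, 2.0, n_pts, device=device).expand(batch, n_rays, n_pts),
    xys=torch.zeros(batch, n_rays, 2, device=device),
)

rays_densities, rays_colors = model(ray_bundle)
# Per the docstring in the diff:
# rays_densities: (batch, n_rays, n_pts, 1); rays_colors: (batch, n_rays, n_pts, 3)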