diff --git a/projects/nerf/nerf/implicit_function.py b/projects/nerf/nerf/implicit_function.py index db5b37d5..f07cec2c 100644 --- a/projects/nerf/nerf/implicit_function.py +++ b/projects/nerf/nerf/implicit_function.py @@ -1,5 +1,5 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from typing import List +from typing import Tuple import torch from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points @@ -23,7 +23,8 @@ class NeuralRadianceField(torch.nn.Module): n_hidden_neurons_xyz: int = 256, n_hidden_neurons_dir: int = 128, n_layers_xyz: int = 8, - append_xyz: List[int] = (5,), + append_xyz: Tuple[int] = (5,), + use_multiple_streams: bool = True, **kwargs, ): """ @@ -43,6 +44,8 @@ class NeuralRadianceField(torch.nn.Module): n_layers_xyz: The number of layers of the MLP that outputs the occupancy field. append_xyz: The list of indices of the skip layers of the occupancy MLP. + use_multiple_streams: Whether density and color should be calculated on + separate CUDA streams. """ super().__init__() @@ -83,13 +86,14 @@ class NeuralRadianceField(torch.nn.Module): torch.nn.Linear(n_hidden_neurons_dir, 3), torch.nn.Sigmoid(), ) + self.use_multiple_streams = use_multiple_streams def _get_densities( self, features: torch.Tensor, depth_values: torch.Tensor, density_noise_std: float, - ): + ) -> torch.Tensor: """ This function takes `features` predicted by `self.mlp_xyz` and converts them to `raw_densities` with `self.density_layer`. @@ -111,7 +115,9 @@ class NeuralRadianceField(torch.nn.Module): densities = 1 - (-deltas * torch.relu(raw_densities)).exp() return densities - def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor): + def _get_colors( + self, features: torch.Tensor, rays_directions: torch.Tensor + ) -> torch.Tensor: """ This function takes per-point `features` predicted by `self.mlp_xyz` and evaluates the color model in order to attach to each @@ -125,12 +131,54 @@ class NeuralRadianceField(torch.nn.Module): return self.color_layer((self.intermediate_linear(features), rays_embedding)) + def _get_densities_and_colors( + self, features: torch.Tensor, ray_bundle: RayBundle, density_noise_std: float + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + The second part of the forward calculation. + + Args: + features: the output of the common mlp (the prior part of the + calculation), shape + (minibatch x ... x self.n_hidden_neurons_xyz). + ray_bundle: As for forward(). + density_noise_std: As for forward(). + + Returns: + rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)` + denoting the opacity of each ray point. + rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)` + denoting the color of each ray point. + """ + if self.use_multiple_streams and features.is_cuda: + current_stream = torch.cuda.current_stream(features.device) + other_stream = torch.cuda.Stream(features.device) + other_stream.wait_stream(current_stream) + + with torch.cuda.stream(other_stream): + rays_densities = self._get_densities( + features, ray_bundle.lengths, density_noise_std + ) + # rays_densities.shape = [minibatch x ... x 1] in [0-1] + + rays_colors = self._get_colors(features, ray_bundle.directions) + # rays_colors.shape = [minibatch x ... x 3] in [0-1] + + current_stream.wait_stream(other_stream) + else: + # Same calculation as above, just serial. + rays_densities = self._get_densities( + features, ray_bundle.lengths, density_noise_std + ) + rays_colors = self._get_colors(features, ray_bundle.directions) + return rays_densities, rays_colors + def forward( self, ray_bundle: RayBundle, density_noise_std: float = 0.0, **kwargs, - ): + ) -> Tuple[torch.Tensor, torch.Tensor]: """ The forward function accepts the parametrizations of 3D points sampled along projection rays. The forward @@ -169,14 +217,9 @@ class NeuralRadianceField(torch.nn.Module): features = self.mlp_xyz(embeds_xyz, embeds_xyz) # features.shape = [minibatch x ... x self.n_hidden_neurons_xyz] - rays_densities = self._get_densities( - features, ray_bundle.lengths, density_noise_std + rays_densities, rays_colors = self._get_densities_and_colors( + features, ray_bundle, density_noise_std ) - # rays_densities.shape = [minibatch x ... x 1] in [0-1] - - rays_colors = self._get_colors(features, ray_bundle.directions) - # rays_colors.shape = [minibatch x ... x 3] in [0-1] - return rays_densities, rays_colors @@ -207,7 +250,7 @@ class MLPWithInputSkips(torch.nn.Module): output_dim: int, skip_dim: int, hidden_dim: int, - input_skips: List[int] = (), + input_skips: Tuple[int] = (), ): """ Args: @@ -238,7 +281,7 @@ class MLPWithInputSkips(torch.nn.Module): self.mlp = torch.nn.ModuleList(layers) self._input_skips = set(input_skips) - def forward(self, x, z): + def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor: """ Args: x: The input tensor of shape `(..., input_dim)`. diff --git a/projects/nerf/nerf/nerf_renderer.py b/projects/nerf/nerf/nerf_renderer.py index c84afc26..084cecc5 100644 --- a/projects/nerf/nerf/nerf_renderer.py +++ b/projects/nerf/nerf/nerf_renderer.py @@ -62,7 +62,7 @@ class RadianceFieldRenderer(torch.nn.Module): n_hidden_neurons_xyz: int = 256, n_hidden_neurons_dir: int = 128, n_layers_xyz: int = 8, - append_xyz: List[int] = (5,), + append_xyz: Tuple[int] = (5,), density_noise_std: float = 0.0, ): """