From ab73f8c3fd2a0cc46c46c7f79565eacfc66d19f7 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 2 Jun 2021 05:42:15 -0700 Subject: [PATCH] LinearWithRepeat layer for NeRF Summary: Add custom layer to avoid repeating copied data for every ray position. This should also save time in the backward pass because there are fewer multiplies with the weights. Reviewed By: theschnitz Differential Revision: D28382412 fbshipit-source-id: 1ba7356cd8520ebd598568ae503e47d31d3989eb --- projects/nerf/nerf/implicit_function.py | 17 ++------ projects/nerf/nerf/linear_with_repeat.py | 52 ++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 14 deletions(-) create mode 100644 projects/nerf/nerf/linear_with_repeat.py diff --git a/projects/nerf/nerf/implicit_function.py b/projects/nerf/nerf/implicit_function.py index ed4c5e23..8589e11b 100644 --- a/projects/nerf/nerf/implicit_function.py +++ b/projects/nerf/nerf/implicit_function.py @@ -5,6 +5,7 @@ import torch from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points from .harmonic_embedding import HarmonicEmbedding +from .linear_with_repeat import LinearWithRepeat def _xavier_init(linear): @@ -75,7 +76,7 @@ class NeuralRadianceField(torch.nn.Module): self.density_layer.bias.data[:] = 0.0 # fixme: Sometimes this is not enough self.color_layer = torch.nn.Sequential( - torch.nn.Linear( + LinearWithRepeat( n_hidden_neurons_xyz + embedding_dim_dir, n_hidden_neurons_dir ), torch.nn.ReLU(True), @@ -116,7 +117,6 @@ class NeuralRadianceField(torch.nn.Module): and evaluates the color model in order to attach to each point a 3D vector of its RGB color. """ - spatial_size = features.shape[:-1] # Normalize the ray_directions to unit l2 norm. rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1) @@ -129,18 +129,7 @@ class NeuralRadianceField(torch.nn.Module): dim=-1, ) - # Expand the ray directions tensor so that its spatial size - # is equal to the size of features. - rays_embedding_expand = rays_embedding[..., None, :].expand( - *spatial_size, rays_embedding.shape[-1] - ) - - # Concatenate ray direction embeddings with - # features and evaluate the color model. - color_layer_input = torch.cat( - (self.intermediate_linear(features), rays_embedding_expand), dim=-1 - ) - return self.color_layer(color_layer_input) + return self.color_layer((self.intermediate_linear(features), rays_embedding)) def forward( self, diff --git a/projects/nerf/nerf/linear_with_repeat.py b/projects/nerf/nerf/linear_with_repeat.py new file mode 100644 index 00000000..efdc0321 --- /dev/null +++ b/projects/nerf/nerf/linear_with_repeat.py @@ -0,0 +1,52 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from typing import Tuple + +import torch +import torch.nn.functional as F + + +class LinearWithRepeat(torch.nn.Linear): + """ + if x has shape (..., k, n1) + and y has shape (..., n2) + then + LinearWithRepeat(n1 + n2, out_features).forward((x,y)) + is equivalent to + Linear(n1 + n2, out_features).forward( + torch.cat([x, y.unsqueeze(-2).expand(..., k, n2)], dim=-1) + ) + + Or visually: + Given the following, for each ray, + + feature -> + + ray xxxxxxxx + position xxxxxxxx + | xxxxxxxx + v xxxxxxxx + + + and + yyyyyyyy + + where the y's do not depend on the position + but only on the ray, + we want to evaluate a Linear layer on both + types of data at every position. + + It's as if we constructed + + xxxxxxxxyyyyyyyy + xxxxxxxxyyyyyyyy + xxxxxxxxyyyyyyyy + xxxxxxxxyyyyyyyy + + and sent that through the Linear. + """ + + def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + n1 = input[0].shape[-1] + output1 = F.linear(input[0], self.weight[:, :n1], self.bias) + output2 = F.linear(input[1], self.weight[:, n1:], None) + return output1 + output2.unsqueeze(-2)