diff --git a/projects/nerf/nerf/implicit_function.py b/projects/nerf/nerf/implicit_function.py index ed4c5e23..8589e11b 100644 --- a/projects/nerf/nerf/implicit_function.py +++ b/projects/nerf/nerf/implicit_function.py @@ -5,6 +5,7 @@ import torch from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points from .harmonic_embedding import HarmonicEmbedding +from .linear_with_repeat import LinearWithRepeat def _xavier_init(linear): @@ -75,7 +76,7 @@ class NeuralRadianceField(torch.nn.Module): self.density_layer.bias.data[:] = 0.0 # fixme: Sometimes this is not enough self.color_layer = torch.nn.Sequential( - torch.nn.Linear( + LinearWithRepeat( n_hidden_neurons_xyz + embedding_dim_dir, n_hidden_neurons_dir ), torch.nn.ReLU(True), @@ -116,7 +117,6 @@ class NeuralRadianceField(torch.nn.Module): and evaluates the color model in order to attach to each point a 3D vector of its RGB color. """ - spatial_size = features.shape[:-1] # Normalize the ray_directions to unit l2 norm. rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1) @@ -129,18 +129,7 @@ class NeuralRadianceField(torch.nn.Module): dim=-1, ) - # Expand the ray directions tensor so that its spatial size - # is equal to the size of features. - rays_embedding_expand = rays_embedding[..., None, :].expand( - *spatial_size, rays_embedding.shape[-1] - ) - - # Concatenate ray direction embeddings with - # features and evaluate the color model. - color_layer_input = torch.cat( - (self.intermediate_linear(features), rays_embedding_expand), dim=-1 - ) - return self.color_layer(color_layer_input) + return self.color_layer((self.intermediate_linear(features), rays_embedding)) def forward( self, diff --git a/projects/nerf/nerf/linear_with_repeat.py b/projects/nerf/nerf/linear_with_repeat.py new file mode 100644 index 00000000..efdc0321 --- /dev/null +++ b/projects/nerf/nerf/linear_with_repeat.py @@ -0,0 +1,52 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from typing import Tuple + +import torch +import torch.nn.functional as F + + +class LinearWithRepeat(torch.nn.Linear): + """ + if x has shape (..., k, n1) + and y has shape (..., n2) + then + LinearWithRepeat(n1 + n2, out_features).forward((x,y)) + is equivalent to + Linear(n1 + n2, out_features).forward( + torch.cat([x, y.unsqueeze(-2).expand(..., k, n2)], dim=-1) + ) + + Or visually: + Given the following, for each ray, + + feature -> + + ray xxxxxxxx + position xxxxxxxx + | xxxxxxxx + v xxxxxxxx + + + and + yyyyyyyy + + where the y's do not depend on the position + but only on the ray, + we want to evaluate a Linear layer on both + types of data at every position. + + It's as if we constructed + + xxxxxxxxyyyyyyyy + xxxxxxxxyyyyyyyy + xxxxxxxxyyyyyyyy + xxxxxxxxyyyyyyyy + + and sent that through the Linear. + """ + + def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + n1 = input[0].shape[-1] + output1 = F.linear(input[0], self.weight[:, :n1], self.bias) + output2 = F.linear(input[1], self.weight[:, n1:], None) + return output1 + output2.unsqueeze(-2)