LinearWithRepeat layer for NeRF

Summary:
Add a custom layer that avoids copying the per-ray data to every position along the ray.

This should also save time in the backward pass because there are fewer multiplies with the weights.

Reviewed By: theschnitz

Differential Revision: D28382412

fbshipit-source-id: 1ba7356cd8520ebd598568ae503e47d31d3989eb
This commit is contained in:
Jeremy Reizenstein 2021-06-02 05:42:15 -07:00 committed by Facebook GitHub Bot
parent fe39cc7b80
commit ab73f8c3fd
2 changed files with 55 additions and 14 deletions

View File

@ -5,6 +5,7 @@ import torch
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points
from .harmonic_embedding import HarmonicEmbedding
from .linear_with_repeat import LinearWithRepeat
def _xavier_init(linear):
@ -75,7 +76,7 @@ class NeuralRadianceField(torch.nn.Module):
self.density_layer.bias.data[:] = 0.0 # fixme: Sometimes this is not enough
self.color_layer = torch.nn.Sequential(
torch.nn.Linear(
LinearWithRepeat(
n_hidden_neurons_xyz + embedding_dim_dir, n_hidden_neurons_dir
),
torch.nn.ReLU(True),
@ -116,7 +117,6 @@ class NeuralRadianceField(torch.nn.Module):
and evaluates the color model in order to attach to each
point a 3D vector of its RGB color.
"""
spatial_size = features.shape[:-1]
# Normalize the ray_directions to unit l2 norm.
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
@ -129,18 +129,7 @@ class NeuralRadianceField(torch.nn.Module):
dim=-1,
)
# Expand the ray directions tensor so that its spatial size
# is equal to the size of features.
rays_embedding_expand = rays_embedding[..., None, :].expand(
*spatial_size, rays_embedding.shape[-1]
)
# Concatenate ray direction embeddings with
# features and evaluate the color model.
color_layer_input = torch.cat(
(self.intermediate_linear(features), rays_embedding_expand), dim=-1
)
return self.color_layer(color_layer_input)
return self.color_layer((self.intermediate_linear(features), rays_embedding))
def forward(
self,

View File

@ -0,0 +1,52 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Tuple
import torch
import torch.nn.functional as F
class LinearWithRepeat(torch.nn.Linear):
    """
    Linear layer over per-position data joined with per-ray data.

    The layer takes a pair (x, y) where
        x has shape (..., k, n1)
        y has shape (..., n2)
    and
        LinearWithRepeat(n1 + n2, out_features).forward((x, y))
    is equivalent to
        Linear(n1 + n2, out_features).forward(
            torch.cat([x, y.unsqueeze(-2).expand(..., k, n2)], dim=-1)
        )
    but the expanded copy of y is never built.

    Picture, for a single ray, one row of x per position along the ray:

                 feature ->

        ray      xxxxxxxx
        position xxxxxxxx
          |      xxxxxxxx
          v      xxxxxxxx

    together with one row

                 yyyyyyyy

    which depends only on the ray and not on the position. The Linear
    layer has to see both kinds of data at every position, i.e. the
    result is the one we would get from sending

                 xxxxxxxxyyyyyyyy
                 xxxxxxxxyyyyyyyy
                 xxxxxxxxyyyyyyyy
                 xxxxxxxxyyyyyyyy

    through the Linear.
    """

    def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """
        Apply the layer to a (per-position, per-ray) pair of tensors.

        Args:
            input: pair (x, y) with x of shape (..., k, n1) and
                y of shape (..., n2).

        Returns:
            Tensor of shape (..., k, out_features).
        """
        per_position = input[0]
        per_ray = input[1]

        # The first n1 input columns of the weight act on x,
        # the remaining ones act on y.
        split_at = per_position.shape[-1]
        weight_for_position = self.weight[:, :split_at]
        weight_for_ray = self.weight[:, split_at:]

        # The bias is added exactly once, together with the x part.
        position_term = F.linear(per_position, weight_for_position, self.bias)
        ray_term = F.linear(per_ray, weight_for_ray, None)

        # A new axis in the position slot lets the single per-ray result
        # broadcast over all k positions.
        return position_term + ray_term.unsqueeze(-2)