mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
LinearWithRepeat layer for NeRF
Summary: Add custom layer to avoid repeating copied data for every ray position. This should also save time in the backward pass because there are fewer multiplies with the weights. Reviewed By: theschnitz Differential Revision: D28382412 fbshipit-source-id: 1ba7356cd8520ebd598568ae503e47d31d3989eb
This commit is contained in:
parent
fe39cc7b80
commit
ab73f8c3fd
@ -5,6 +5,7 @@ import torch
|
||||
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points
|
||||
|
||||
from .harmonic_embedding import HarmonicEmbedding
|
||||
from .linear_with_repeat import LinearWithRepeat
|
||||
|
||||
|
||||
def _xavier_init(linear):
|
||||
@ -75,7 +76,7 @@ class NeuralRadianceField(torch.nn.Module):
|
||||
self.density_layer.bias.data[:] = 0.0 # fixme: Sometimes this is not enough
|
||||
|
||||
self.color_layer = torch.nn.Sequential(
|
||||
torch.nn.Linear(
|
||||
LinearWithRepeat(
|
||||
n_hidden_neurons_xyz + embedding_dim_dir, n_hidden_neurons_dir
|
||||
),
|
||||
torch.nn.ReLU(True),
|
||||
@ -116,7 +117,6 @@ class NeuralRadianceField(torch.nn.Module):
|
||||
and evaluates the color model in order to attach to each
|
||||
point a 3D vector of its RGB color.
|
||||
"""
|
||||
spatial_size = features.shape[:-1]
|
||||
# Normalize the ray_directions to unit l2 norm.
|
||||
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
|
||||
|
||||
@ -129,18 +129,7 @@ class NeuralRadianceField(torch.nn.Module):
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# Expand the ray directions tensor so that its spatial size
|
||||
# is equal to the size of features.
|
||||
rays_embedding_expand = rays_embedding[..., None, :].expand(
|
||||
*spatial_size, rays_embedding.shape[-1]
|
||||
)
|
||||
|
||||
# Concatenate ray direction embeddings with
|
||||
# features and evaluate the color model.
|
||||
color_layer_input = torch.cat(
|
||||
(self.intermediate_linear(features), rays_embedding_expand), dim=-1
|
||||
)
|
||||
return self.color_layer(color_layer_input)
|
||||
return self.color_layer((self.intermediate_linear(features), rays_embedding))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
52
projects/nerf/nerf/linear_with_repeat.py
Normal file
52
projects/nerf/nerf/linear_with_repeat.py
Normal file
@ -0,0 +1,52 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class LinearWithRepeat(torch.nn.Linear):
|
||||
"""
|
||||
if x has shape (..., k, n1)
|
||||
and y has shape (..., n2)
|
||||
then
|
||||
LinearWithRepeat(n1 + n2, out_features).forward((x,y))
|
||||
is equivalent to
|
||||
Linear(n1 + n2, out_features).forward(
|
||||
torch.cat([x, y.unsqueeze(-2).expand(..., k, n2)], dim=-1)
|
||||
)
|
||||
|
||||
Or visually:
|
||||
Given the following, for each ray,
|
||||
|
||||
feature ->
|
||||
|
||||
ray xxxxxxxx
|
||||
position xxxxxxxx
|
||||
| xxxxxxxx
|
||||
v xxxxxxxx
|
||||
|
||||
|
||||
and
|
||||
yyyyyyyy
|
||||
|
||||
where the y's do not depend on the position
|
||||
but only on the ray,
|
||||
we want to evaluate a Linear layer on both
|
||||
types of data at every position.
|
||||
|
||||
It's as if we constructed
|
||||
|
||||
xxxxxxxxyyyyyyyy
|
||||
xxxxxxxxyyyyyyyy
|
||||
xxxxxxxxyyyyyyyy
|
||||
xxxxxxxxyyyyyyyy
|
||||
|
||||
and sent that through the Linear.
|
||||
"""
|
||||
|
||||
def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
||||
n1 = input[0].shape[-1]
|
||||
output1 = F.linear(input[0], self.weight[:, :n1], self.bias)
|
||||
output2 = F.linear(input[1], self.weight[:, n1:], None)
|
||||
return output1 + output2.unsqueeze(-2)
|
Loading…
x
Reference in New Issue
Block a user