From f63e49d2458f30169583c09e30aa47e9cdb02cf8 Mon Sep 17 00:00:00 2001
From: Jeremy Reizenstein
Date: Wed, 2 Jun 2021 05:42:15 -0700
Subject: [PATCH] Let harmonic embedding layer include input (NeRF)

Summary: When harmonic embedding is used, we always cat its input onto its
output before proceeding. Avoid an intermediate tensor by making the module
do that for itself.

Reviewed By: davnov134

Differential Revision: D28185791

fbshipit-source-id: 98d92c94a918dd42e16cdadcaac71dabbc7de5c3
---
 projects/nerf/nerf/harmonic_embedding.py | 30 ++++++++++++++++--------
 projects/nerf/nerf/implicit_function.py  | 13 ++--------
 2 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/projects/nerf/nerf/harmonic_embedding.py b/projects/nerf/nerf/harmonic_embedding.py
index 02db628d..fc47d11b 100644
--- a/projects/nerf/nerf/harmonic_embedding.py
+++ b/projects/nerf/nerf/harmonic_embedding.py
@@ -8,14 +8,16 @@ class HarmonicEmbedding(torch.nn.Module):
         n_harmonic_functions: int = 6,
         omega0: float = 1.0,
         logspace: bool = True,
+        include_input: bool = True,
     ):
         """
         Given an input tensor `x` of shape [minibatch, ... , dim],
         the harmonic embedding layer converts each feature
-        in `x` into a series of harmonic features `embedding`
-        as follows:
+        in `x` into a series of harmonic features `embedding`,
+        where for each i in range(dim) the following are present
+        in embedding[...]:
             ```
-            embedding[..., i*dim:(i+1)*dim] = [
+            [
                 sin(x[..., i]),
                 sin(f_1*x[..., i]),
                 sin(f_2*x[..., i]),
@@ -25,17 +27,20 @@ class HarmonicEmbedding(torch.nn.Module):
                 cos(f_1*x[..., i]),
                 cos(f_2*x[..., i]),
                 ...
-                cos(f_N * x[..., i])
+                cos(f_N * x[..., i]),
+                x[..., i]              # only present if include_input is True.
             ]
             ```
         where N corresponds to `n_harmonic_functions`, and f_i is a scalar
         denoting the i-th frequency of the harmonic embedding.
+        The shape of the output is [minibatch, ... , dim * (2 * N + 1)] if
+        include_input is True, otherwise [minibatch, ... , dim * (2 * N)].

         If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
-        either powers of 2:
-            `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
+        powers of 2:
+            `f_1 = 1, ..., f_N = 2**torch.arange(n_harmonic_functions)`

-        If `logspace==False`, frequencies are linearly spaced between
+        If `logspace==False`, frequencies are linearly spaced between
         `1.0` and `2**(n_harmonic_functions-1)`:
             `f_1, ..., f_N = torch.linspace(
                 1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
@@ -60,14 +65,19 @@ class HarmonicEmbedding(torch.nn.Module):
             )

         self.register_buffer("_frequencies", omega0 * frequencies)
+        self.include_input = include_input

     def forward(self, x: torch.Tensor):
         """
         Args:
             x: tensor of shape [..., dim]
         Returns:
-            embedding: a harmonic embedding of `x`
-                of shape [..., n_harmonic_functions * dim * 2]
+            embedding: a harmonic embedding of `x` of shape
+                [..., dim * (n_harmonic_functions * 2 + T)] where
+                T is 1 if include_input is True and 0 otherwise.
         """
         embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1)
-        return torch.cat((embed.sin(), embed.cos()), dim=-1)
+        if self.include_input:
+            return torch.cat((embed.sin(), embed.cos(), x), dim=-1)
+        else:
+            return torch.cat((embed.sin(), embed.cos()), dim=-1)
diff --git a/projects/nerf/nerf/implicit_function.py b/projects/nerf/nerf/implicit_function.py
index 8589e11b..db5b37d5 100644
--- a/projects/nerf/nerf/implicit_function.py
+++ b/projects/nerf/nerf/implicit_function.py
@@ -121,13 +121,7 @@ class NeuralRadianceField(torch.nn.Module):
         rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)

         # Obtain the harmonic embedding of the normalized ray directions.
-        rays_embedding = torch.cat(
-            (
-                self.harmonic_embedding_dir(rays_directions_normed),
-                rays_directions_normed,
-            ),
-            dim=-1,
-        )
+        rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)

         return self.color_layer((self.intermediate_linear(features), rays_embedding))

@@ -168,10 +162,7 @@ class NeuralRadianceField(torch.nn.Module):
         # rays_points_world.shape = [minibatch x ... x 3]

         # For each 3D world coordinate, we obtain its harmonic embedding.
-        embeds_xyz = torch.cat(
-            (self.harmonic_embedding_xyz(rays_points_world), rays_points_world),
-            dim=-1,
-        )
+        embeds_xyz = self.harmonic_embedding_xyz(rays_points_world)
         # embeds_xyz.shape = [minibatch x ... x self.n_harmonic_functions*6 + 3]

         # self.mlp maps each harmonic embedding to a latent feature space.