pytorch3d/projects/nerf/nerf/harmonic_embedding.py
Jeremy Reizenstein 1b8d86a104 (breaking) image_size-agnostic GridRaySampler
Summary:
As suggested in #802: by not persisting the `_xy_grid` buffer, we allow (in some cases) a model constructed with one image_size to be loaded from a checkpoint that was trained at a different resolution.

Also avoid persisting _frequencies in HarmonicEmbedding for similar reasons.

BC-break: old checkpoints still contain these buffers, so `load_state_dict` in strict mode will complain (about unexpected keys) if you try to load an old model with the new code.

Reviewed By: patricklabatut

Differential Revision: D30349234

fbshipit-source-id: d6061d1e51c9f79a78d61a9f732c9a5dfadbbb47
2021-08-31 14:30:24 -07:00

89 lines
3.1 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
class HarmonicEmbedding(torch.nn.Module):
    """Harmonic (sin/cos) embedding of the last axis of a tensor."""

    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega0: float = 1.0,
        logspace: bool = True,
        include_input: bool = True,
    ) -> None:
        """
        Given an input tensor `x` of shape [minibatch, ... , dim],
        the harmonic embedding layer converts each feature
        in `x` into a series of harmonic features `embedding`,
        where for each i in range(dim) the following are present
        in embedding[...]:
        ```
        [
            sin(omega0 * f_1 * x[..., i]),
            sin(omega0 * f_2 * x[..., i]),
            ...
            sin(omega0 * f_N * x[..., i]),
            cos(omega0 * f_1 * x[..., i]),
            cos(omega0 * f_2 * x[..., i]),
            ...
            cos(omega0 * f_N * x[..., i]),
            x[..., i],  # only present if include_input is True.
        ]
        ```
        where N corresponds to `n_harmonic_functions`, and f_k is a scalar
        denoting the k-th frequency of the harmonic embedding.

        Along the last axis the output is laid out as
        [all sin terms, all cos terms, x (only if include_input)].
        Within the sin group (and likewise the cos group), the term for
        input feature i and frequency f_k is at index `i * N + (k - 1)`.

        The shape of the output is [minibatch, ... , dim * (2 * N + 1)] if
        include_input is True, otherwise [minibatch, ... , dim * (2 * N)].

        If `logspace==True`, the frequencies are powers of 2:
            `[f_1, ..., f_N] = 2**torch.arange(n_harmonic_functions)`,
        i.e. f_1 = 1, f_2 = 2, ..., f_N = 2**(N-1).
        If `logspace==False`, frequencies are linearly spaced between
        `1.0` and `2**(n_harmonic_functions-1)`:
            `[f_1, ..., f_N] = torch.linspace(
                1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
            )`
        In both cases `x` is also premultiplied by the base frequency
        `omega0` before evaluating the harmonic functions.

        Args:
            n_harmonic_functions: N, the number of frequencies.
            omega0: base frequency by which all frequencies are scaled.
            logspace: if True, use power-of-2 frequencies; otherwise use
                linearly spaced frequencies.
            include_input: if True, append the raw input `x` to the output.
        """
        super().__init__()
        if logspace:
            frequencies = 2.0 ** torch.arange(
                n_harmonic_functions,
                dtype=torch.float32,
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )
        # persistent=False keeps the frequencies out of state_dict(): they are
        # fully determined by the constructor arguments.
        self.register_buffer("_frequencies", omega0 * frequencies, persistent=False)
        self.include_input = include_input

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [..., dim]
        Returns:
            embedding: a harmonic embedding of `x` of shape
                [..., dim * (n_harmonic_functions * 2 + T)] where
                T is 1 if include_input is True and 0 otherwise.
        """
        # [..., dim, N]
        embed = x[..., None] * self._frequencies
        if embed.dim() > 1:
            # Merge the last two axes -> [..., dim * N]. Unlike
            # `view(..., -1)`, `flatten` does not require compatible strides
            # and does not need to infer a size, so it also handles inputs
            # with zero elements (e.g. an empty batch).
            # (A 0-d `x` yields a 1-D `embed` that is already flat.)
            embed = embed.flatten(start_dim=-2)
        if self.include_input:
            return torch.cat((embed.sin(), embed.cos(), x), dim=-1)
        return torch.cat((embed.sin(), embed.cos()), dim=-1)