Move Harmonic embedding to core pytorch3d

Summary:
Moved `HarmonicEmbedding` function in core PyTorch3D.
In the next diff will update the NeRF project.

Reviewed By: bottler

Differential Revision: D32833808

fbshipit-source-id: 0a12ccd1627c0ce024463c796544c91eb8d4d122
This commit is contained in:
Nikhila Ravi 2021-12-21 15:03:33 -08:00 committed by Facebook GitHub Bot
parent d67662d13c
commit f9a26a22fc
4 changed files with 179 additions and 1 deletions

View File

@ -37,6 +37,7 @@ from .implicit import (
VolumeSampler, VolumeSampler,
ray_bundle_to_ray_points, ray_bundle_to_ray_points,
ray_bundle_variables_to_ray_points, ray_bundle_variables_to_ray_points,
HarmonicEmbedding,
) )
from .lighting import AmbientLights, DirectionalLights, PointLights, diffuse, specular from .lighting import AmbientLights, DirectionalLights, PointLights, diffuse, specular
from .materials import Materials from .materials import Materials

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from .harmonic_embedding import HarmonicEmbedding
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
@ -13,5 +14,4 @@ from .utils import (
ray_bundle_variables_to_ray_points, ray_bundle_variables_to_ray_points,
) )
__all__ = [k for k in globals().keys() if not k.startswith("_")] __all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -0,0 +1,127 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
class HarmonicEmbedding(torch.nn.Module):
def __init__(
self,
n_harmonic_functions: int = 6,
omega_0: float = 1.0,
logspace: bool = True,
append_input: bool = True,
) -> None:
"""
Given an input tensor `x` of shape [minibatch, ... , dim],
the harmonic embedding layer converts each feature
(i.e. vector along the last dimension) in `x`
into a series of harmonic features `embedding`,
where for each i in range(dim) the following are present
in embedding[...]:
```
[
sin(f_1*x[..., i]),
sin(f_2*x[..., i]),
...
sin(f_N * x[..., i]),
cos(f_1*x[..., i]),
cos(f_2*x[..., i]),
...
cos(f_N * x[..., i]),
x[..., i], # only present if append_input is True.
]
```
where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar
denoting the i-th frequency of the harmonic embedding.
If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
powers of 2:
`f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
If `logspace==False`, frequencies are linearly spaced between
`1.0` and `2**(n_harmonic_functions-1)`:
`f_1, ..., f_N = torch.linspace(
1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
)`
Note that `x` is also premultiplied by the base frequency `omega_0`
before evaluating the harmonic functions.
Args:
n_harmonic_functions: int, number of harmonic
features
omega_0: float, base frequency
logspace: bool, Whether to space the frequencies in
logspace or linear space
append_input: bool, whether to concat the original
input to the harmonic embedding. If true the
output is of the form (x, embed.sin(), embed.cos()
"""
super().__init__()
if logspace:
frequencies = 2.0 ** torch.arange(
n_harmonic_functions,
dtype=torch.float32,
)
else:
frequencies = torch.linspace(
1.0,
2.0 ** (n_harmonic_functions - 1),
n_harmonic_functions,
dtype=torch.float32,
)
self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
self.append_input = append_input
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: tensor of shape [..., dim]
Returns:
embedding: a harmonic embedding of `x`
of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
"""
embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1)
embed = torch.cat(
(embed.sin(), embed.cos(), x)
if self.append_input
else (embed.sin(), embed.cos()),
dim=-1,
)
return embed
@staticmethod
def get_output_dim_static(
input_dims: int,
n_harmonic_functions: int,
append_input: bool,
) -> int:
"""
Utility to help predict the shape of the output of `forward`.
Args:
input_dims: length of the last dimension of the input tensor
n_harmonic_functions: number of embedding frequencies
append_input: whether or not to concat the original
input to the harmonic embedding
Returns:
int: the length of the last dimension of the output tensor
"""
return input_dims * (2 * n_harmonic_functions + int(append_input))
def get_output_dim(self, input_dims: int = 3) -> int:
"""
Same as above. The default for input_dims is 3 for 3D applications
which use harmonic embedding for positional encoding,
so the input might be xyz.
"""
return self.get_output_dim_static(
input_dims, len(self._frequencies), self.append_input
)

View File

@ -0,0 +1,50 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer.implicit import HarmonicEmbedding
class TestHarmonicEmbedding(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(1)
def test_correct_output_dim(self):
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False)
# input_dims * (2 * n_harmonic_functions + int(append_input))
output_dim = 3 * (2 * 2 + int(False))
self.assertEqual(
output_dim,
embed_fun.get_output_dim_static(
input_dims=3, n_harmonic_functions=2, append_input=False
),
)
self.assertEqual(output_dim, embed_fun.get_output_dim())
def test_correct_frequency_range(self):
embed_fun_log = HarmonicEmbedding(n_harmonic_functions=3)
embed_fun_lin = HarmonicEmbedding(n_harmonic_functions=3, logspace=False)
self.assertClose(embed_fun_log._frequencies, torch.FloatTensor((1.0, 2.0, 4.0)))
self.assertClose(embed_fun_lin._frequencies, torch.FloatTensor((1.0, 2.5, 4.0)))
def test_correct_embed_out(self):
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False)
x = torch.randn((1, 5))
D = 5 * 4
embed_out = embed_fun(x)
self.assertEqual(embed_out.shape, (1, D))
# Sum the squares of the respective frequencies
sum_squares = embed_out[0, : D // 2] ** 2 + embed_out[0, D // 2 :] ** 2
self.assertClose(sum_squares, torch.ones((D // 2)))
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=True)
embed_out = embed_fun(x)
self.assertClose(embed_out.shape, torch.tensor((1, 5 * 5)))
# Last plane in output is the input
self.assertClose(embed_out[..., -5:], x)