mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Move Harmonic embedding to core pytorch3d
Summary: Moved `HarmonicEmbedding` function in core PyTorch3D. In the next diff will update the NeRF project. Reviewed By: bottler Differential Revision: D32833808 fbshipit-source-id: 0a12ccd1627c0ce024463c796544c91eb8d4d122
This commit is contained in:
parent
d67662d13c
commit
f9a26a22fc
@ -37,6 +37,7 @@ from .implicit import (
|
||||
VolumeSampler,
|
||||
ray_bundle_to_ray_points,
|
||||
ray_bundle_variables_to_ray_points,
|
||||
HarmonicEmbedding,
|
||||
)
|
||||
from .lighting import AmbientLights, DirectionalLights, PointLights, diffuse, specular
|
||||
from .materials import Materials
|
||||
|
@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .harmonic_embedding import HarmonicEmbedding
|
||||
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
|
||||
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
|
||||
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
|
||||
@ -13,5 +14,4 @@ from .utils import (
|
||||
ray_bundle_variables_to_ray_points,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
127
pytorch3d/renderer/implicit/harmonic_embedding.py
Normal file
127
pytorch3d/renderer/implicit/harmonic_embedding.py
Normal file
@ -0,0 +1,127 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class HarmonicEmbedding(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_harmonic_functions: int = 6,
|
||||
omega_0: float = 1.0,
|
||||
logspace: bool = True,
|
||||
append_input: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Given an input tensor `x` of shape [minibatch, ... , dim],
|
||||
the harmonic embedding layer converts each feature
|
||||
(i.e. vector along the last dimension) in `x`
|
||||
into a series of harmonic features `embedding`,
|
||||
where for each i in range(dim) the following are present
|
||||
in embedding[...]:
|
||||
```
|
||||
[
|
||||
sin(f_1*x[..., i]),
|
||||
sin(f_2*x[..., i]),
|
||||
...
|
||||
sin(f_N * x[..., i]),
|
||||
cos(f_1*x[..., i]),
|
||||
cos(f_2*x[..., i]),
|
||||
...
|
||||
cos(f_N * x[..., i]),
|
||||
x[..., i], # only present if append_input is True.
|
||||
]
|
||||
```
|
||||
where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar
|
||||
denoting the i-th frequency of the harmonic embedding.
|
||||
|
||||
If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
|
||||
powers of 2:
|
||||
`f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
|
||||
|
||||
If `logspace==False`, frequencies are linearly spaced between
|
||||
`1.0` and `2**(n_harmonic_functions-1)`:
|
||||
`f_1, ..., f_N = torch.linspace(
|
||||
1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
|
||||
)`
|
||||
|
||||
Note that `x` is also premultiplied by the base frequency `omega_0`
|
||||
before evaluating the harmonic functions.
|
||||
|
||||
Args:
|
||||
n_harmonic_functions: int, number of harmonic
|
||||
features
|
||||
omega_0: float, base frequency
|
||||
logspace: bool, Whether to space the frequencies in
|
||||
logspace or linear space
|
||||
append_input: bool, whether to concat the original
|
||||
input to the harmonic embedding. If true the
|
||||
output is of the form (x, embed.sin(), embed.cos()
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if logspace:
|
||||
frequencies = 2.0 ** torch.arange(
|
||||
n_harmonic_functions,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
else:
|
||||
frequencies = torch.linspace(
|
||||
1.0,
|
||||
2.0 ** (n_harmonic_functions - 1),
|
||||
n_harmonic_functions,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
|
||||
self.append_input = append_input
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: tensor of shape [..., dim]
|
||||
Returns:
|
||||
embedding: a harmonic embedding of `x`
|
||||
of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
|
||||
"""
|
||||
embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1)
|
||||
embed = torch.cat(
|
||||
(embed.sin(), embed.cos(), x)
|
||||
if self.append_input
|
||||
else (embed.sin(), embed.cos()),
|
||||
dim=-1,
|
||||
)
|
||||
return embed
|
||||
|
||||
@staticmethod
|
||||
def get_output_dim_static(
|
||||
input_dims: int,
|
||||
n_harmonic_functions: int,
|
||||
append_input: bool,
|
||||
) -> int:
|
||||
"""
|
||||
Utility to help predict the shape of the output of `forward`.
|
||||
|
||||
Args:
|
||||
input_dims: length of the last dimension of the input tensor
|
||||
n_harmonic_functions: number of embedding frequencies
|
||||
append_input: whether or not to concat the original
|
||||
input to the harmonic embedding
|
||||
Returns:
|
||||
int: the length of the last dimension of the output tensor
|
||||
"""
|
||||
return input_dims * (2 * n_harmonic_functions + int(append_input))
|
||||
|
||||
def get_output_dim(self, input_dims: int = 3) -> int:
|
||||
"""
|
||||
Same as above. The default for input_dims is 3 for 3D applications
|
||||
which use harmonic embedding for positional encoding,
|
||||
so the input might be xyz.
|
||||
"""
|
||||
return self.get_output_dim_static(
|
||||
input_dims, len(self._frequencies), self.append_input
|
||||
)
|
50
tests/test_harmonic_embedding.py
Normal file
50
tests/test_harmonic_embedding.py
Normal file
@ -0,0 +1,50 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||
|
||||
|
||||
class TestHarmonicEmbedding(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(1)
|
||||
|
||||
def test_correct_output_dim(self):
|
||||
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False)
|
||||
# input_dims * (2 * n_harmonic_functions + int(append_input))
|
||||
output_dim = 3 * (2 * 2 + int(False))
|
||||
self.assertEqual(
|
||||
output_dim,
|
||||
embed_fun.get_output_dim_static(
|
||||
input_dims=3, n_harmonic_functions=2, append_input=False
|
||||
),
|
||||
)
|
||||
self.assertEqual(output_dim, embed_fun.get_output_dim())
|
||||
|
||||
def test_correct_frequency_range(self):
|
||||
embed_fun_log = HarmonicEmbedding(n_harmonic_functions=3)
|
||||
embed_fun_lin = HarmonicEmbedding(n_harmonic_functions=3, logspace=False)
|
||||
self.assertClose(embed_fun_log._frequencies, torch.FloatTensor((1.0, 2.0, 4.0)))
|
||||
self.assertClose(embed_fun_lin._frequencies, torch.FloatTensor((1.0, 2.5, 4.0)))
|
||||
|
||||
def test_correct_embed_out(self):
|
||||
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False)
|
||||
x = torch.randn((1, 5))
|
||||
D = 5 * 4
|
||||
embed_out = embed_fun(x)
|
||||
self.assertEqual(embed_out.shape, (1, D))
|
||||
# Sum the squares of the respective frequencies
|
||||
sum_squares = embed_out[0, : D // 2] ** 2 + embed_out[0, D // 2 :] ** 2
|
||||
self.assertClose(sum_squares, torch.ones((D // 2)))
|
||||
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=True)
|
||||
embed_out = embed_fun(x)
|
||||
self.assertClose(embed_out.shape, torch.tensor((1, 5 * 5)))
|
||||
# Last plane in output is the input
|
||||
self.assertClose(embed_out[..., -5:], x)
|
Loading…
x
Reference in New Issue
Block a user