mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
Update Harmonic embedding in NeRF
Summary: Removed harmonic embedding function from projects/nerf and changed import to be from core pytorch3d. Reviewed By: patricklabatut Differential Revision: D33142358 fbshipit-source-id: 3004247d50392dbd04ea72e9cd4bace0dc03606b
This commit is contained in:
parent
f9a26a22fc
commit
52c71b8816
@ -1,88 +0,0 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the BSD-style license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class HarmonicEmbedding(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
n_harmonic_functions: int = 6,
|
|
||||||
omega0: float = 1.0,
|
|
||||||
logspace: bool = True,
|
|
||||||
include_input: bool = True,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Given an input tensor `x` of shape [minibatch, ... , dim],
|
|
||||||
the harmonic embedding layer converts each feature
|
|
||||||
in `x` into a series of harmonic features `embedding`,
|
|
||||||
where for each i in range(dim) the following are present
|
|
||||||
in embedding[...]:
|
|
||||||
```
|
|
||||||
[
|
|
||||||
sin(x[..., i]),
|
|
||||||
sin(f_1*x[..., i]),
|
|
||||||
sin(f_2*x[..., i]),
|
|
||||||
...
|
|
||||||
sin(f_N * x[..., i]),
|
|
||||||
cos(x[..., i]),
|
|
||||||
cos(f_1*x[..., i]),
|
|
||||||
cos(f_2*x[..., i]),
|
|
||||||
...
|
|
||||||
cos(f_N * x[..., i]),
|
|
||||||
x[..., i] # only present if include_input is True.
|
|
||||||
]
|
|
||||||
```
|
|
||||||
where N corresponds to `n_harmonic_functions`, and f_i is a scalar
|
|
||||||
denoting the i-th frequency of the harmonic embedding.
|
|
||||||
The shape of the output is [minibatch, ... , dim * (2 * N + 1)] if
|
|
||||||
include_input is True, otherwise [minibatch, ... , dim * (2 * N)].
|
|
||||||
|
|
||||||
If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
|
|
||||||
powers of 2:
|
|
||||||
`f_1 = 1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
|
|
||||||
|
|
||||||
If `logspace==False`, frequencies are linearly spaced between
|
|
||||||
`1.0` and `2**(n_harmonic_functions-1)`:
|
|
||||||
`f_1, ..., f_N = torch.linspace(
|
|
||||||
1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
|
|
||||||
)`
|
|
||||||
|
|
||||||
Note that `x` is also premultiplied by the base frequency `omega0`
|
|
||||||
before evaluating the harmonic functions.
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if logspace:
|
|
||||||
frequencies = 2.0 ** torch.arange(
|
|
||||||
n_harmonic_functions,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
frequencies = torch.linspace(
|
|
||||||
1.0,
|
|
||||||
2.0 ** (n_harmonic_functions - 1),
|
|
||||||
n_harmonic_functions,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.register_buffer("_frequencies", omega0 * frequencies, persistent=False)
|
|
||||||
self.include_input = include_input
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x: tensor of shape [..., dim]
|
|
||||||
Returns:
|
|
||||||
embedding: a harmonic embedding of `x` of shape
|
|
||||||
[..., dim * (n_harmonic_functions * 2 + T)] where
|
|
||||||
T is 1 if include_input is True and 0 otherwise.
|
|
||||||
"""
|
|
||||||
embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1)
|
|
||||||
if self.include_input:
|
|
||||||
return torch.cat((embed.sin(), embed.cos(), x), dim=-1)
|
|
||||||
else:
|
|
||||||
return torch.cat((embed.sin(), embed.cos()), dim=-1)
|
|
@ -7,9 +7,8 @@
|
|||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points
|
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points, HarmonicEmbedding
|
||||||
|
|
||||||
from .harmonic_embedding import HarmonicEmbedding
|
|
||||||
from .linear_with_repeat import LinearWithRepeat
|
from .linear_with_repeat import LinearWithRepeat
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user