mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Let harmonic embedding layer include input (NeRF)
Summary: When harmonic embedding is used, we always cat its input onto its output before proceeding. Avoid an intermediate tensor by making the module do that for itself. Reviewed By: davnov134 Differential Revision: D28185791 fbshipit-source-id: 98d92c94a918dd42e16cdadcaac71dabbc7de5c3
This commit is contained in:
parent
ab73f8c3fd
commit
f63e49d245
@ -8,14 +8,16 @@ class HarmonicEmbedding(torch.nn.Module):
|
||||
n_harmonic_functions: int = 6,
|
||||
omega0: float = 1.0,
|
||||
logspace: bool = True,
|
||||
include_input: bool = True,
|
||||
):
|
||||
"""
|
||||
Given an input tensor `x` of shape [minibatch, ... , dim],
|
||||
the harmonic embedding layer converts each feature
|
||||
in `x` into a series of harmonic features `embedding`
|
||||
as follows:
|
||||
in `x` into a series of harmonic features `embedding`,
|
||||
where for each i in range(dim) the following are present
|
||||
in embedding[...]:
|
||||
```
|
||||
embedding[..., i*dim:(i+1)*dim] = [
|
||||
[
|
||||
sin(x[..., i]),
|
||||
sin(f_1*x[..., i]),
|
||||
sin(f_2*x[..., i]),
|
||||
@ -25,17 +27,20 @@ class HarmonicEmbedding(torch.nn.Module):
|
||||
cos(f_1*x[..., i]),
|
||||
cos(f_2*x[..., i]),
|
||||
...
|
||||
cos(f_N * x[..., i])
|
||||
cos(f_N * x[..., i]),
|
||||
x[..., i] # only present if include_input is True.
|
||||
]
|
||||
```
|
||||
where N corresponds to `n_harmonic_functions`, and f_i is a scalar
|
||||
denoting the i-th frequency of the harmonic embedding.
|
||||
The shape of the output is [minibatch, ... , dim * (2 * N + 1)] if
|
||||
include_input is True, otherwise [minibatch, ... , dim * (2 * N)].
|
||||
|
||||
If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
|
||||
either powers of 2:
|
||||
`f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
|
||||
powers of 2:
|
||||
`f_1 = 1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
|
||||
|
||||
If `logspace==False`, frequencies are linearly spaced between
|
||||
If `logspace==False`, frequencies are linearly spaced between
|
||||
`1.0` and `2**(n_harmonic_functions-1)`:
|
||||
`f_1, ..., f_N = torch.linspace(
|
||||
1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
|
||||
@ -60,14 +65,19 @@ class HarmonicEmbedding(torch.nn.Module):
|
||||
)
|
||||
|
||||
self.register_buffer("_frequencies", omega0 * frequencies)
|
||||
self.include_input = include_input
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Args:
|
||||
x: tensor of shape [..., dim]
|
||||
Returns:
|
||||
embedding: a harmonic embedding of `x`
|
||||
of shape [..., n_harmonic_functions * dim * 2]
|
||||
embedding: a harmonic embedding of `x` of shape
|
||||
[..., dim * (n_harmonic_functions * 2 + T)] where
|
||||
T is 1 if include_input is True and 0 otherwise.
|
||||
"""
|
||||
embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1)
|
||||
return torch.cat((embed.sin(), embed.cos()), dim=-1)
|
||||
if self.include_input:
|
||||
return torch.cat((embed.sin(), embed.cos(), x), dim=-1)
|
||||
else:
|
||||
return torch.cat((embed.sin(), embed.cos()), dim=-1)
|
||||
|
@ -121,13 +121,7 @@ class NeuralRadianceField(torch.nn.Module):
|
||||
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
|
||||
|
||||
# Obtain the harmonic embedding of the normalized ray directions.
|
||||
rays_embedding = torch.cat(
|
||||
(
|
||||
self.harmonic_embedding_dir(rays_directions_normed),
|
||||
rays_directions_normed,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)
|
||||
|
||||
return self.color_layer((self.intermediate_linear(features), rays_embedding))
|
||||
|
||||
@ -168,10 +162,7 @@ class NeuralRadianceField(torch.nn.Module):
|
||||
# rays_points_world.shape = [minibatch x ... x 3]
|
||||
|
||||
# For each 3D world coordinate, we obtain its harmonic embedding.
|
||||
embeds_xyz = torch.cat(
|
||||
(self.harmonic_embedding_xyz(rays_points_world), rays_points_world),
|
||||
dim=-1,
|
||||
)
|
||||
embeds_xyz = self.harmonic_embedding_xyz(rays_points_world)
|
||||
# embeds_xyz.shape = [minibatch x ... x self.n_harmonic_functions*6 + 3]
|
||||
|
||||
# self.mlp maps each harmonic embedding to a latent feature space.
|
||||
|
Loading…
x
Reference in New Issue
Block a user