mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Let harmonic embedding layer include input (NeRF)
Summary: When harmonic embedding is used, we always concatenate its input onto its output before proceeding. Avoid an intermediate tensor by making the module do that for itself.
Reviewed By: davnov134
Differential Revision: D28185791
fbshipit-source-id: 98d92c94a918dd42e16cdadcaac71dabbc7de5c3
This commit is contained in:
parent
ab73f8c3fd
commit
f63e49d245
@ -8,14 +8,16 @@ class HarmonicEmbedding(torch.nn.Module):
|
|||||||
n_harmonic_functions: int = 6,
|
n_harmonic_functions: int = 6,
|
||||||
omega0: float = 1.0,
|
omega0: float = 1.0,
|
||||||
logspace: bool = True,
|
logspace: bool = True,
|
||||||
|
include_input: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Given an input tensor `x` of shape [minibatch, ... , dim],
|
Given an input tensor `x` of shape [minibatch, ... , dim],
|
||||||
the harmonic embedding layer converts each feature
|
the harmonic embedding layer converts each feature
|
||||||
in `x` into a series of harmonic features `embedding`
|
in `x` into a series of harmonic features `embedding`,
|
||||||
as follows:
|
where for each i in range(dim) the following are present
|
||||||
|
in embedding[...]:
|
||||||
```
|
```
|
||||||
embedding[..., i*dim:(i+1)*dim] = [
|
[
|
||||||
sin(x[..., i]),
|
sin(x[..., i]),
|
||||||
sin(f_1*x[..., i]),
|
sin(f_1*x[..., i]),
|
||||||
sin(f_2*x[..., i]),
|
sin(f_2*x[..., i]),
|
||||||
@ -25,15 +27,18 @@ class HarmonicEmbedding(torch.nn.Module):
|
|||||||
cos(f_1*x[..., i]),
|
cos(f_1*x[..., i]),
|
||||||
cos(f_2*x[..., i]),
|
cos(f_2*x[..., i]),
|
||||||
...
|
...
|
||||||
cos(f_N * x[..., i])
|
cos(f_N * x[..., i]),
|
||||||
|
x[..., i] # only present if include_input is True.
|
||||||
]
|
]
|
||||||
```
|
```
|
||||||
where N corresponds to `n_harmonic_functions`, and f_i is a scalar
|
where N corresponds to `n_harmonic_functions`, and f_i is a scalar
|
||||||
denoting the i-th frequency of the harmonic embedding.
|
denoting the i-th frequency of the harmonic embedding.
|
||||||
|
The shape of the output is [minibatch, ... , dim * (2 * N + 1)] if
|
||||||
|
include_input is True, otherwise [minibatch, ... , dim * (2 * N)].
|
||||||
|
|
||||||
If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
|
If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
|
||||||
either powers of 2:
|
powers of 2:
|
||||||
`f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
|
`f_1 = 1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
|
||||||
|
|
||||||
If `logspace==False`, frequencies are linearly spaced between
|
If `logspace==False`, frequencies are linearly spaced between
|
||||||
`1.0` and `2**(n_harmonic_functions-1)`:
|
`1.0` and `2**(n_harmonic_functions-1)`:
|
||||||
@ -60,14 +65,19 @@ class HarmonicEmbedding(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.register_buffer("_frequencies", omega0 * frequencies)
|
self.register_buffer("_frequencies", omega0 * frequencies)
|
||||||
|
self.include_input = include_input
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x: tensor of shape [..., dim]
|
x: tensor of shape [..., dim]
|
||||||
Returns:
|
Returns:
|
||||||
embedding: a harmonic embedding of `x`
|
embedding: a harmonic embedding of `x` of shape
|
||||||
of shape [..., n_harmonic_functions * dim * 2]
|
[..., dim * (n_harmonic_functions * 2 + T)] where
|
||||||
|
T is 1 if include_input is True and 0 otherwise.
|
||||||
"""
|
"""
|
||||||
embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1)
|
embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1)
|
||||||
|
if self.include_input:
|
||||||
|
return torch.cat((embed.sin(), embed.cos(), x), dim=-1)
|
||||||
|
else:
|
||||||
return torch.cat((embed.sin(), embed.cos()), dim=-1)
|
return torch.cat((embed.sin(), embed.cos()), dim=-1)
|
||||||
|
@ -121,13 +121,7 @@ class NeuralRadianceField(torch.nn.Module):
|
|||||||
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
|
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
|
||||||
|
|
||||||
# Obtain the harmonic embedding of the normalized ray directions.
|
# Obtain the harmonic embedding of the normalized ray directions.
|
||||||
rays_embedding = torch.cat(
|
rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)
|
||||||
(
|
|
||||||
self.harmonic_embedding_dir(rays_directions_normed),
|
|
||||||
rays_directions_normed,
|
|
||||||
),
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.color_layer((self.intermediate_linear(features), rays_embedding))
|
return self.color_layer((self.intermediate_linear(features), rays_embedding))
|
||||||
|
|
||||||
@ -168,10 +162,7 @@ class NeuralRadianceField(torch.nn.Module):
|
|||||||
# rays_points_world.shape = [minibatch x ... x 3]
|
# rays_points_world.shape = [minibatch x ... x 3]
|
||||||
|
|
||||||
# For each 3D world coordinate, we obtain its harmonic embedding.
|
# For each 3D world coordinate, we obtain its harmonic embedding.
|
||||||
embeds_xyz = torch.cat(
|
embeds_xyz = self.harmonic_embedding_xyz(rays_points_world)
|
||||||
(self.harmonic_embedding_xyz(rays_points_world), rays_points_world),
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
# embeds_xyz.shape = [minibatch x ... x self.n_harmonic_functions*6 + 3]
|
# embeds_xyz.shape = [minibatch x ... x self.n_harmonic_functions*6 + 3]
|
||||||
|
|
||||||
# self.mlp maps each harmonic embedding to a latent feature space.
|
# self.mlp maps each harmonic embedding to a latent feature space.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user