mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Add integrated position encoding based on MIPNerf implementation.
Summary: Add a new implicit module Integral Position Encoding based on [MIP-NeRF](https://arxiv.org/abs/2103.13415). Reviewed By: shapovalov Differential Revision: D46352730 fbshipit-source-id: c6a56134c975d80052b3a11f5e92fd7d95cbff1e
This commit is contained in:
parent
29b8ebd802
commit
ccf860f1db
@ -361,6 +361,7 @@ model_factory_ImplicitronModelFactory_args:
|
||||
n_hidden_neurons_dir: 128
|
||||
input_xyz: true
|
||||
xyz_ray_dir_in_camera_coords: false
|
||||
use_integrated_positional_encoding: false
|
||||
transformer_dim_down_factor: 2.0
|
||||
n_hidden_neurons_xyz: 80
|
||||
n_layers_xyz: 2
|
||||
@ -372,6 +373,7 @@ model_factory_ImplicitronModelFactory_args:
|
||||
n_hidden_neurons_dir: 128
|
||||
input_xyz: true
|
||||
xyz_ray_dir_in_camera_coords: false
|
||||
use_integrated_positional_encoding: false
|
||||
transformer_dim_down_factor: 1.0
|
||||
n_hidden_neurons_xyz: 256
|
||||
n_layers_xyz: 8
|
||||
@ -741,6 +743,7 @@ model_factory_ImplicitronModelFactory_args:
|
||||
n_hidden_neurons_dir: 128
|
||||
input_xyz: true
|
||||
xyz_ray_dir_in_camera_coords: false
|
||||
use_integrated_positional_encoding: false
|
||||
transformer_dim_down_factor: 2.0
|
||||
n_hidden_neurons_xyz: 80
|
||||
n_layers_xyz: 2
|
||||
@ -752,6 +755,7 @@ model_factory_ImplicitronModelFactory_args:
|
||||
n_hidden_neurons_dir: 128
|
||||
input_xyz: true
|
||||
xyz_ray_dir_in_camera_coords: false
|
||||
use_integrated_positional_encoding: false
|
||||
transformer_dim_down_factor: 1.0
|
||||
n_hidden_neurons_xyz: 256
|
||||
n_layers_xyz: 8
|
||||
@ -979,6 +983,7 @@ model_factory_ImplicitronModelFactory_args:
|
||||
n_hidden_neurons_dir: 128
|
||||
input_xyz: true
|
||||
xyz_ray_dir_in_camera_coords: false
|
||||
use_integrated_positional_encoding: false
|
||||
transformer_dim_down_factor: 2.0
|
||||
n_hidden_neurons_xyz: 80
|
||||
n_layers_xyz: 2
|
||||
@ -990,6 +995,7 @@ model_factory_ImplicitronModelFactory_args:
|
||||
n_hidden_neurons_dir: 128
|
||||
input_xyz: true
|
||||
xyz_ray_dir_in_camera_coords: false
|
||||
use_integrated_positional_encoding: false
|
||||
transformer_dim_down_factor: 1.0
|
||||
n_hidden_neurons_xyz: 256
|
||||
n_layers_xyz: 8
|
||||
|
@ -9,11 +9,14 @@ from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||
from pytorch3d.implicitron.models.renderer.base import (
|
||||
conical_frustum_to_gaussian,
|
||||
ImplicitronRayBundle,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
|
||||
from pytorch3d.renderer import ray_bundle_to_ray_points
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||
from pytorch3d.renderer.implicit.utils import ray_bundle_to_ray_points
|
||||
|
||||
from .base import ImplicitFunctionBase
|
||||
|
||||
@ -36,6 +39,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
input_xyz: bool = True
|
||||
xyz_ray_dir_in_camera_coords: bool = False
|
||||
color_dim: int = 3
|
||||
use_integrated_positional_encoding: bool = False
|
||||
"""
|
||||
Args:
|
||||
n_harmonic_functions_xyz: The number of harmonic functions
|
||||
@ -53,6 +57,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
n_layers_xyz: The number of layers of the MLP that outputs the
|
||||
occupancy field.
|
||||
append_xyz: The list of indices of the skip layers of the occupancy MLP.
|
||||
use_integrated_positional_encoding: If True, use integrated positional enoding
|
||||
as defined in `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
|
||||
If False, use the classical harmonic embedding
|
||||
defined in `NeRF <https://arxiv.org/abs/2003.08934>`_.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
@ -149,6 +157,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
containing the direction vectors of sampling rays in world coords.
|
||||
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
|
||||
containing the lengths at which the rays are sampled.
|
||||
bins: An optional tensor of shape `(minibatch,..., num_points_per_ray + 1)`
|
||||
containing the bins at which the rays are sampled. In this case
|
||||
lengths is equal to the midpoints of bins.
|
||||
|
||||
fun_viewpool: an optional callback with the signature
|
||||
fun_fiewpool(points) -> pooled_features
|
||||
where points is a [N_TGT x N x 3] tensor of world coords,
|
||||
@ -160,11 +172,22 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
denoting the opacitiy of each ray point.
|
||||
rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
|
||||
denoting the color of each ray point.
|
||||
|
||||
Raises:
|
||||
ValueError: If `use_integrated_positional_encoding` is True and
|
||||
`ray_bundle.bins` is None.
|
||||
"""
|
||||
# We first convert the ray parametrizations to world
|
||||
# coordinates with `ray_bundle_to_ray_points`.
|
||||
# pyre-ignore[6]
|
||||
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
|
||||
if self.use_integrated_positional_encoding and ray_bundle.bins is None:
|
||||
raise ValueError(
|
||||
"When use_integrated_positional_encoding is True, ray_bundle.bins must be set."
|
||||
"Have you set to True `AbstractMaskRaySampler.use_bins_for_ray_sampling`?"
|
||||
)
|
||||
|
||||
rays_points_world, diag_cov = (
|
||||
conical_frustum_to_gaussian(ray_bundle)
|
||||
if self.use_integrated_positional_encoding
|
||||
else (ray_bundle_to_ray_points(ray_bundle), None) # pyre-ignore
|
||||
)
|
||||
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
|
||||
|
||||
embeds = create_embeddings_for_implicit_function(
|
||||
@ -177,6 +200,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
fun_viewpool=fun_viewpool,
|
||||
xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,
|
||||
camera=camera,
|
||||
diag_cov=diag_cov,
|
||||
)
|
||||
|
||||
# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
|
||||
|
@ -36,6 +36,7 @@ def create_embeddings_for_implicit_function(
|
||||
camera: Optional[CamerasBase],
|
||||
fun_viewpool: Optional[Callable],
|
||||
xyz_embedding_function: Optional[Callable],
|
||||
diag_cov: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
bs, *spatial_size, pts_per_ray, _ = xyz_world.shape
|
||||
@ -59,11 +60,11 @@ def create_embeddings_for_implicit_function(
|
||||
prod(spatial_size),
|
||||
pts_per_ray,
|
||||
0,
|
||||
dtype=xyz_world.dtype,
|
||||
device=xyz_world.device,
|
||||
)
|
||||
else:
|
||||
embeds = xyz_embedding_function(ray_points_for_embed).reshape(
|
||||
|
||||
embeds = xyz_embedding_function(ray_points_for_embed, diag_cov=diag_cov)
|
||||
embeds = embeds.reshape(
|
||||
bs,
|
||||
1,
|
||||
prod(spatial_size),
|
||||
|
@ -4,6 +4,8 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@ -16,8 +18,18 @@ class HarmonicEmbedding(torch.nn.Module):
|
||||
append_input: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Given an input tensor `x` of shape [minibatch, ... , dim],
|
||||
the harmonic embedding layer converts each feature
|
||||
The harmonic embedding layer supports the classical
|
||||
Nerf positional encoding described in
|
||||
`NeRF <https://arxiv.org/abs/2003.08934>`_
|
||||
and the integrated position encoding in
|
||||
`MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
|
||||
|
||||
During, the inference you can provide the extra argument `diag_cov`.
|
||||
|
||||
If `diag_cov is None`, it converts
|
||||
rays parametrized with a `ray_bundle` to 3D points by
|
||||
extending each ray according to the corresponding length.
|
||||
Then it converts each feature
|
||||
(i.e. vector along the last dimension) in `x`
|
||||
into a series of harmonic features `embedding`,
|
||||
where for each i in range(dim) the following are present
|
||||
@ -38,6 +50,31 @@ class HarmonicEmbedding(torch.nn.Module):
|
||||
where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar
|
||||
denoting the i-th frequency of the harmonic embedding.
|
||||
|
||||
|
||||
If `diag_cov is not None`, it approximates
|
||||
conical frustums following a ray bundle as gaussians,
|
||||
defined by x, the means of the gaussians and diag_cov,
|
||||
the diagonal covariances.
|
||||
Then it converts each gaussian
|
||||
into a series of harmonic features `embedding`,
|
||||
where for each i in range(dim) the following are present
|
||||
in embedding[...]::
|
||||
|
||||
[
|
||||
sin(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]),
|
||||
sin(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]),
|
||||
...
|
||||
sin(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]),
|
||||
cos(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]),
|
||||
cos(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]),,
|
||||
...
|
||||
cos(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]),
|
||||
x[..., i], # only present if append_input is True.
|
||||
]
|
||||
|
||||
where N equals `n_harmonic_functions-1`, and f_i is a scalar
|
||||
denoting the i-th frequency of the harmonic embedding.
|
||||
|
||||
If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
|
||||
powers of 2:
|
||||
`f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
|
||||
@ -59,8 +96,7 @@ class HarmonicEmbedding(torch.nn.Module):
|
||||
logspace or linear space
|
||||
append_input: bool, whether to concat the original
|
||||
input to the harmonic embedding. If true the
|
||||
output is of the form (x, embed.sin(), embed.cos()
|
||||
|
||||
output is of the form (embed.sin(), embed.cos(), x)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@ -78,23 +114,42 @@ class HarmonicEmbedding(torch.nn.Module):
|
||||
)
|
||||
|
||||
self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
|
||||
self.register_buffer(
|
||||
"_zero_half_pi", torch.tensor([0.0, 0.5 * torch.pi]), persistent=False
|
||||
)
|
||||
self.append_input = append_input
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: tensor of shape [..., dim]
|
||||
diag_cov: An optional tensor of shape `(..., dim)`
|
||||
representing the diagonal covariance matrices of our Gaussians, joined with x
|
||||
as means of the Gaussians.
|
||||
|
||||
Returns:
|
||||
embedding: a harmonic embedding of `x`
|
||||
of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
|
||||
embedding: a harmonic embedding of `x` of shape
|
||||
[..., (n_harmonic_functions * 2 + int(append_input)) * num_points_per_ray]
|
||||
"""
|
||||
embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1)
|
||||
embed = torch.cat(
|
||||
(embed.sin(), embed.cos(), x)
|
||||
if self.append_input
|
||||
else (embed.sin(), embed.cos()),
|
||||
dim=-1,
|
||||
)
|
||||
# [..., dim, n_harmonic_functions]
|
||||
embed = x[..., None] * self._frequencies
|
||||
# [..., 1, dim, n_harmonic_functions] + [2, 1, 1] => [..., 2, dim, n_harmonic_functions]
|
||||
embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
|
||||
# Use the trig identity cos(x) = sin(x + pi/2)
|
||||
# and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)).
|
||||
embed = embed.sin()
|
||||
if diag_cov is not None:
|
||||
x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
|
||||
exp_var = torch.exp(-0.5 * x_var)
|
||||
# [..., 2, dim, n_harmonic_functions]
|
||||
embed = embed * exp_var[..., None, :, :]
|
||||
|
||||
embed = embed.reshape(*x.shape[:-1], -1)
|
||||
|
||||
if self.append_input:
|
||||
return torch.cat([embed, x], dim=-1)
|
||||
return embed
|
||||
|
||||
@staticmethod
|
||||
|
@ -0,0 +1,66 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.implicit_function.base import ImplicitronRayBundle
|
||||
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (
|
||||
NeuralRadianceFieldImplicitFunction,
|
||||
)
|
||||
|
||||
|
||||
class TestNeuralRadianceFieldImplicitFunction(unittest.TestCase):
|
||||
def setUp(self):
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_forward_with_integrated_positionial_embedding(self):
|
||||
shape = [2, 4, 4]
|
||||
ray_bundle = ImplicitronRayBundle(
|
||||
origins=torch.randn(*shape, 3),
|
||||
directions=torch.randn(*shape, 3),
|
||||
bins=torch.randn(*shape, 6 + 1),
|
||||
lengths=torch.randn(*shape, 6),
|
||||
pixel_radii_2d=torch.randn(*shape, 1),
|
||||
xys=None,
|
||||
)
|
||||
model = NeuralRadianceFieldImplicitFunction(
|
||||
n_hidden_neurons_dir=32, use_integrated_positional_encoding=True
|
||||
)
|
||||
raw_densities, ray_colors, _ = model(ray_bundle=ray_bundle)
|
||||
|
||||
self.assertEqual(raw_densities.shape, (*shape, ray_bundle.lengths.shape[-1], 1))
|
||||
self.assertEqual(ray_colors.shape, (*shape, ray_bundle.lengths.shape[-1], 3))
|
||||
|
||||
def test_forward_with_integrated_positionial_embedding_raise_exception(self):
|
||||
shape = [2, 4, 4]
|
||||
ray_bundle = ImplicitronRayBundle(
|
||||
origins=torch.randn(*shape, 3),
|
||||
directions=torch.randn(*shape, 3),
|
||||
bins=None,
|
||||
lengths=torch.randn(*shape, 6),
|
||||
pixel_radii_2d=torch.randn(*shape, 1),
|
||||
xys=None,
|
||||
)
|
||||
model = NeuralRadianceFieldImplicitFunction(
|
||||
n_hidden_neurons_dir=32, use_integrated_positional_encoding=True
|
||||
)
|
||||
with self.assertRaises(ValueError):
|
||||
_ = model(ray_bundle=ray_bundle)
|
||||
|
||||
def test_forward(self):
|
||||
shape = [2, 4, 4]
|
||||
ray_bundle = ImplicitronRayBundle(
|
||||
origins=torch.randn(*shape, 3),
|
||||
directions=torch.randn(*shape, 3),
|
||||
lengths=torch.randn(*shape, 6),
|
||||
pixel_radii_2d=torch.randn(*shape, 1),
|
||||
xys=None,
|
||||
)
|
||||
model = NeuralRadianceFieldImplicitFunction(n_hidden_neurons_dir=32)
|
||||
raw_densities, ray_colors, _ = model(ray_bundle=ray_bundle)
|
||||
self.assertEqual(raw_densities.shape, (*shape, 6, 1))
|
||||
self.assertEqual(ray_colors.shape, (*shape, 6, 3))
|
@ -8,6 +8,7 @@ import unittest
|
||||
|
||||
import torch
|
||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||
from torch.distributions import MultivariateNormal
|
||||
|
||||
from .common_testing import TestCaseMixin
|
||||
|
||||
@ -36,16 +37,117 @@ class TestHarmonicEmbedding(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(embed_fun_lin._frequencies, torch.FloatTensor((1.0, 2.5, 4.0)))
|
||||
|
||||
def test_correct_embed_out(self):
|
||||
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False)
|
||||
n_harmonic_functions = 2
|
||||
x = torch.randn((1, 5))
|
||||
D = 5 * 4
|
||||
D = 5 * n_harmonic_functions * 2 # sin + cos
|
||||
|
||||
embed_fun = HarmonicEmbedding(
|
||||
n_harmonic_functions=n_harmonic_functions, append_input=False
|
||||
)
|
||||
embed_out = embed_fun(x)
|
||||
|
||||
self.assertEqual(embed_out.shape, (1, D))
|
||||
# Sum the squares of the respective frequencies
|
||||
# cos^2(x) + sin^2(x) = 1
|
||||
sum_squares = embed_out[0, : D // 2] ** 2 + embed_out[0, D // 2 :] ** 2
|
||||
self.assertClose(sum_squares, torch.ones((D // 2)))
|
||||
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=True)
|
||||
embed_out = embed_fun(x)
|
||||
self.assertClose(embed_out.shape, torch.tensor((1, 5 * 5)))
|
||||
|
||||
# Test append input
|
||||
embed_fun = HarmonicEmbedding(
|
||||
n_harmonic_functions=n_harmonic_functions, append_input=True
|
||||
)
|
||||
embed_out_appended_input = embed_fun(x)
|
||||
self.assertClose(
|
||||
embed_out_appended_input.shape, torch.tensor((1, D + x.shape[-1]))
|
||||
)
|
||||
# Last plane in output is the input
|
||||
self.assertClose(embed_out[..., -5:], x)
|
||||
self.assertClose(embed_out_appended_input[..., -x.shape[-1] :], x)
|
||||
self.assertClose(embed_out_appended_input[..., : -x.shape[-1]], embed_out)
|
||||
|
||||
def test_correct_embed_out_with_diag_cov(self):
|
||||
n_harmonic_functions = 2
|
||||
x = torch.randn((1, 3))
|
||||
diag_cov = torch.randn((1, 3))
|
||||
D = 3 * n_harmonic_functions * 2 # sin + cos
|
||||
|
||||
embed_fun = HarmonicEmbedding(
|
||||
n_harmonic_functions=n_harmonic_functions, append_input=False
|
||||
)
|
||||
embed_out = embed_fun(x, diag_cov=diag_cov)
|
||||
|
||||
self.assertEqual(embed_out.shape, (1, D))
|
||||
|
||||
# Compute the scaling factor introduce in MipNerf
|
||||
scale_factor = (
|
||||
-0.5 * diag_cov[..., None] * torch.pow(embed_fun._frequencies[None, :], 2)
|
||||
)
|
||||
scale_factor = torch.exp(scale_factor).reshape(1, -1).tile((1, 2))
|
||||
# If we remove this scaling factor, we should go back to the
|
||||
# classical harmonic embedding:
|
||||
# Sum the squares of the respective frequencies
|
||||
# cos^2(x) + sin^2(x) = 1
|
||||
embed_out_without_cov = embed_out / scale_factor
|
||||
sum_squares = (
|
||||
embed_out_without_cov[0, : D // 2] ** 2
|
||||
+ embed_out_without_cov[0, D // 2 :] ** 2
|
||||
)
|
||||
self.assertClose(sum_squares, torch.ones((D // 2)))
|
||||
|
||||
# Test append input
|
||||
embed_fun = HarmonicEmbedding(
|
||||
n_harmonic_functions=n_harmonic_functions, append_input=True
|
||||
)
|
||||
embed_out_appended_input = embed_fun(x, diag_cov=diag_cov)
|
||||
self.assertClose(
|
||||
embed_out_appended_input.shape, torch.tensor((1, D + x.shape[-1]))
|
||||
)
|
||||
# Last plane in output is the input
|
||||
self.assertClose(embed_out_appended_input[..., -x.shape[-1] :], x)
|
||||
self.assertClose(embed_out_appended_input[..., : -x.shape[-1]], embed_out)
|
||||
|
||||
def test_correct_behavior_between_ipe_and_its_estimation_from_harmonic_embedding(
|
||||
self,
|
||||
):
|
||||
"""
|
||||
Check that the HarmonicEmbedding with integrated_position_encoding (IPE) set to
|
||||
True is coherent with the HarmonicEmbedding.
|
||||
|
||||
What is the idea behind this test?
|
||||
|
||||
We wish to produce an IPE that is the expectation
|
||||
of our lifted multivariate gaussian, modulated by the sine and cosine of
|
||||
the coordinates. These expectation has a closed-form
|
||||
(see equations 11, 12, 13, 14 of [1]).
|
||||
|
||||
We sample N elements from the multivariate gaussian defined by its mean and covariance
|
||||
and compute the HarmonicEmbedding. The expected value of those embeddings should be
|
||||
equal to our IPE.
|
||||
|
||||
Inspired from:
|
||||
https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/mip_test.py#L359
|
||||
|
||||
References:
|
||||
[1] `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
|
||||
"""
|
||||
num_dims = 3
|
||||
n_harmonic_functions = 6
|
||||
mean = torch.randn(num_dims)
|
||||
diag_cov = torch.rand(num_dims)
|
||||
|
||||
he_fun = HarmonicEmbedding(
|
||||
n_harmonic_functions=n_harmonic_functions, logspace=True, append_input=False
|
||||
)
|
||||
ipe_fun = HarmonicEmbedding(
|
||||
n_harmonic_functions=n_harmonic_functions,
|
||||
append_input=False,
|
||||
)
|
||||
|
||||
embedding_ipe = ipe_fun(mean, diag_cov=diag_cov)
|
||||
|
||||
rand_mvn = MultivariateNormal(mean, torch.eye(num_dims) * diag_cov)
|
||||
|
||||
# Providing a large enough number of samples
|
||||
# we should obtain an estimation close to our IPE
|
||||
num_samples = 100000
|
||||
embedding_he = he_fun(rand_mvn.sample_n(num_samples))
|
||||
self.assertClose(embedding_he.mean(0), embedding_ipe, rtol=1e-2, atol=1e-2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user