mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Add integrated position encoding based on MIPNerf implementation.
Summary: Add a new implicit module Integral Position Encoding based on [MIP-NeRF](https://arxiv.org/abs/2103.13415). Reviewed By: shapovalov Differential Revision: D46352730 fbshipit-source-id: c6a56134c975d80052b3a11f5e92fd7d95cbff1e
This commit is contained in:
		
							parent
							
								
									29b8ebd802
								
							
						
					
					
						commit
						ccf860f1db
					
				@ -361,6 +361,7 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
      n_hidden_neurons_dir: 128
 | 
					      n_hidden_neurons_dir: 128
 | 
				
			||||||
      input_xyz: true
 | 
					      input_xyz: true
 | 
				
			||||||
      xyz_ray_dir_in_camera_coords: false
 | 
					      xyz_ray_dir_in_camera_coords: false
 | 
				
			||||||
 | 
					      use_integrated_positional_encoding: false
 | 
				
			||||||
      transformer_dim_down_factor: 2.0
 | 
					      transformer_dim_down_factor: 2.0
 | 
				
			||||||
      n_hidden_neurons_xyz: 80
 | 
					      n_hidden_neurons_xyz: 80
 | 
				
			||||||
      n_layers_xyz: 2
 | 
					      n_layers_xyz: 2
 | 
				
			||||||
@ -372,6 +373,7 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
      n_hidden_neurons_dir: 128
 | 
					      n_hidden_neurons_dir: 128
 | 
				
			||||||
      input_xyz: true
 | 
					      input_xyz: true
 | 
				
			||||||
      xyz_ray_dir_in_camera_coords: false
 | 
					      xyz_ray_dir_in_camera_coords: false
 | 
				
			||||||
 | 
					      use_integrated_positional_encoding: false
 | 
				
			||||||
      transformer_dim_down_factor: 1.0
 | 
					      transformer_dim_down_factor: 1.0
 | 
				
			||||||
      n_hidden_neurons_xyz: 256
 | 
					      n_hidden_neurons_xyz: 256
 | 
				
			||||||
      n_layers_xyz: 8
 | 
					      n_layers_xyz: 8
 | 
				
			||||||
@ -741,6 +743,7 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
      n_hidden_neurons_dir: 128
 | 
					      n_hidden_neurons_dir: 128
 | 
				
			||||||
      input_xyz: true
 | 
					      input_xyz: true
 | 
				
			||||||
      xyz_ray_dir_in_camera_coords: false
 | 
					      xyz_ray_dir_in_camera_coords: false
 | 
				
			||||||
 | 
					      use_integrated_positional_encoding: false
 | 
				
			||||||
      transformer_dim_down_factor: 2.0
 | 
					      transformer_dim_down_factor: 2.0
 | 
				
			||||||
      n_hidden_neurons_xyz: 80
 | 
					      n_hidden_neurons_xyz: 80
 | 
				
			||||||
      n_layers_xyz: 2
 | 
					      n_layers_xyz: 2
 | 
				
			||||||
@ -752,6 +755,7 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
      n_hidden_neurons_dir: 128
 | 
					      n_hidden_neurons_dir: 128
 | 
				
			||||||
      input_xyz: true
 | 
					      input_xyz: true
 | 
				
			||||||
      xyz_ray_dir_in_camera_coords: false
 | 
					      xyz_ray_dir_in_camera_coords: false
 | 
				
			||||||
 | 
					      use_integrated_positional_encoding: false
 | 
				
			||||||
      transformer_dim_down_factor: 1.0
 | 
					      transformer_dim_down_factor: 1.0
 | 
				
			||||||
      n_hidden_neurons_xyz: 256
 | 
					      n_hidden_neurons_xyz: 256
 | 
				
			||||||
      n_layers_xyz: 8
 | 
					      n_layers_xyz: 8
 | 
				
			||||||
@ -979,6 +983,7 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
      n_hidden_neurons_dir: 128
 | 
					      n_hidden_neurons_dir: 128
 | 
				
			||||||
      input_xyz: true
 | 
					      input_xyz: true
 | 
				
			||||||
      xyz_ray_dir_in_camera_coords: false
 | 
					      xyz_ray_dir_in_camera_coords: false
 | 
				
			||||||
 | 
					      use_integrated_positional_encoding: false
 | 
				
			||||||
      transformer_dim_down_factor: 2.0
 | 
					      transformer_dim_down_factor: 2.0
 | 
				
			||||||
      n_hidden_neurons_xyz: 80
 | 
					      n_hidden_neurons_xyz: 80
 | 
				
			||||||
      n_layers_xyz: 2
 | 
					      n_layers_xyz: 2
 | 
				
			||||||
@ -990,6 +995,7 @@ model_factory_ImplicitronModelFactory_args:
 | 
				
			|||||||
      n_hidden_neurons_dir: 128
 | 
					      n_hidden_neurons_dir: 128
 | 
				
			||||||
      input_xyz: true
 | 
					      input_xyz: true
 | 
				
			||||||
      xyz_ray_dir_in_camera_coords: false
 | 
					      xyz_ray_dir_in_camera_coords: false
 | 
				
			||||||
 | 
					      use_integrated_positional_encoding: false
 | 
				
			||||||
      transformer_dim_down_factor: 1.0
 | 
					      transformer_dim_down_factor: 1.0
 | 
				
			||||||
      n_hidden_neurons_xyz: 256
 | 
					      n_hidden_neurons_xyz: 256
 | 
				
			||||||
      n_layers_xyz: 8
 | 
					      n_layers_xyz: 8
 | 
				
			||||||
 | 
				
			|||||||
@ -9,11 +9,14 @@ from typing import Optional, Tuple
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
 | 
					from pytorch3d.common.linear_with_repeat import LinearWithRepeat
 | 
				
			||||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
 | 
					from pytorch3d.implicitron.models.renderer.base import (
 | 
				
			||||||
 | 
					    conical_frustum_to_gaussian,
 | 
				
			||||||
 | 
					    ImplicitronRayBundle,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
 | 
					from pytorch3d.implicitron.tools.config import expand_args_fields, registry
 | 
				
			||||||
from pytorch3d.renderer import ray_bundle_to_ray_points
 | 
					 | 
				
			||||||
from pytorch3d.renderer.cameras import CamerasBase
 | 
					from pytorch3d.renderer.cameras import CamerasBase
 | 
				
			||||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
 | 
					from pytorch3d.renderer.implicit import HarmonicEmbedding
 | 
				
			||||||
 | 
					from pytorch3d.renderer.implicit.utils import ray_bundle_to_ray_points
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .base import ImplicitFunctionBase
 | 
					from .base import ImplicitFunctionBase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -36,6 +39,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
 | 
				
			|||||||
    input_xyz: bool = True
 | 
					    input_xyz: bool = True
 | 
				
			||||||
    xyz_ray_dir_in_camera_coords: bool = False
 | 
					    xyz_ray_dir_in_camera_coords: bool = False
 | 
				
			||||||
    color_dim: int = 3
 | 
					    color_dim: int = 3
 | 
				
			||||||
 | 
					    use_integrated_positional_encoding: bool = False
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Args:
 | 
					    Args:
 | 
				
			||||||
        n_harmonic_functions_xyz: The number of harmonic functions
 | 
					        n_harmonic_functions_xyz: The number of harmonic functions
 | 
				
			||||||
@ -53,6 +57,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
 | 
				
			|||||||
        n_layers_xyz: The number of layers of the MLP that outputs the
 | 
					        n_layers_xyz: The number of layers of the MLP that outputs the
 | 
				
			||||||
            occupancy field.
 | 
					            occupancy field.
 | 
				
			||||||
        append_xyz: The list of indices of the skip layers of the occupancy MLP.
 | 
					        append_xyz: The list of indices of the skip layers of the occupancy MLP.
 | 
				
			||||||
 | 
					        use_integrated_positional_encoding: If True, use integrated positional enoding
 | 
				
			||||||
 | 
					            as defined in `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
 | 
				
			||||||
 | 
					            If False, use the classical harmonic embedding
 | 
				
			||||||
 | 
					            defined in `NeRF <https://arxiv.org/abs/2003.08934>`_.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __post_init__(self):
 | 
					    def __post_init__(self):
 | 
				
			||||||
@ -149,6 +157,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
 | 
				
			|||||||
                    containing the direction vectors of sampling rays in world coords.
 | 
					                    containing the direction vectors of sampling rays in world coords.
 | 
				
			||||||
                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
 | 
					                lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
 | 
				
			||||||
                    containing the lengths at which the rays are sampled.
 | 
					                    containing the lengths at which the rays are sampled.
 | 
				
			||||||
 | 
					                bins: An optional tensor of shape `(minibatch,..., num_points_per_ray + 1)`
 | 
				
			||||||
 | 
					                    containing the bins at which the rays are sampled. In this case
 | 
				
			||||||
 | 
					                    lengths is equal to the midpoints of bins.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            fun_viewpool: an optional callback with the signature
 | 
					            fun_viewpool: an optional callback with the signature
 | 
				
			||||||
                    fun_fiewpool(points) -> pooled_features
 | 
					                    fun_fiewpool(points) -> pooled_features
 | 
				
			||||||
                where points is a [N_TGT x N x 3] tensor of world coords,
 | 
					                where points is a [N_TGT x N x 3] tensor of world coords,
 | 
				
			||||||
@ -160,11 +172,22 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
 | 
				
			|||||||
                denoting the opacitiy of each ray point.
 | 
					                denoting the opacitiy of each ray point.
 | 
				
			||||||
            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
 | 
					            rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
 | 
				
			||||||
                denoting the color of each ray point.
 | 
					                denoting the color of each ray point.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Raises:
 | 
				
			||||||
 | 
					            ValueError: If `use_integrated_positional_encoding` is True and
 | 
				
			||||||
 | 
					                `ray_bundle.bins` is None.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        # We first convert the ray parametrizations to world
 | 
					        if self.use_integrated_positional_encoding and ray_bundle.bins is None:
 | 
				
			||||||
        # coordinates with `ray_bundle_to_ray_points`.
 | 
					            raise ValueError(
 | 
				
			||||||
        # pyre-ignore[6]
 | 
					                "When use_integrated_positional_encoding is True, ray_bundle.bins must be set."
 | 
				
			||||||
        rays_points_world = ray_bundle_to_ray_points(ray_bundle)
 | 
					                "Have you set to True `AbstractMaskRaySampler.use_bins_for_ray_sampling`?"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        rays_points_world, diag_cov = (
 | 
				
			||||||
 | 
					            conical_frustum_to_gaussian(ray_bundle)
 | 
				
			||||||
 | 
					            if self.use_integrated_positional_encoding
 | 
				
			||||||
 | 
					            else (ray_bundle_to_ray_points(ray_bundle), None)  # pyre-ignore
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        # rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
 | 
					        # rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        embeds = create_embeddings_for_implicit_function(
 | 
					        embeds = create_embeddings_for_implicit_function(
 | 
				
			||||||
@ -177,6 +200,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
 | 
				
			|||||||
            fun_viewpool=fun_viewpool,
 | 
					            fun_viewpool=fun_viewpool,
 | 
				
			||||||
            xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,
 | 
					            xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,
 | 
				
			||||||
            camera=camera,
 | 
					            camera=camera,
 | 
				
			||||||
 | 
					            diag_cov=diag_cov,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
 | 
					        # embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
 | 
				
			||||||
 | 
				
			|||||||
@ -36,6 +36,7 @@ def create_embeddings_for_implicit_function(
 | 
				
			|||||||
    camera: Optional[CamerasBase],
 | 
					    camera: Optional[CamerasBase],
 | 
				
			||||||
    fun_viewpool: Optional[Callable],
 | 
					    fun_viewpool: Optional[Callable],
 | 
				
			||||||
    xyz_embedding_function: Optional[Callable],
 | 
					    xyz_embedding_function: Optional[Callable],
 | 
				
			||||||
 | 
					    diag_cov: Optional[torch.Tensor] = None,
 | 
				
			||||||
) -> torch.Tensor:
 | 
					) -> torch.Tensor:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    bs, *spatial_size, pts_per_ray, _ = xyz_world.shape
 | 
					    bs, *spatial_size, pts_per_ray, _ = xyz_world.shape
 | 
				
			||||||
@ -59,11 +60,11 @@ def create_embeddings_for_implicit_function(
 | 
				
			|||||||
            prod(spatial_size),
 | 
					            prod(spatial_size),
 | 
				
			||||||
            pts_per_ray,
 | 
					            pts_per_ray,
 | 
				
			||||||
            0,
 | 
					            0,
 | 
				
			||||||
            dtype=xyz_world.dtype,
 | 
					 | 
				
			||||||
            device=xyz_world.device,
 | 
					 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        embeds = xyz_embedding_function(ray_points_for_embed).reshape(
 | 
					
 | 
				
			||||||
 | 
					        embeds = xyz_embedding_function(ray_points_for_embed, diag_cov=diag_cov)
 | 
				
			||||||
 | 
					        embeds = embeds.reshape(
 | 
				
			||||||
            bs,
 | 
					            bs,
 | 
				
			||||||
            1,
 | 
					            1,
 | 
				
			||||||
            prod(spatial_size),
 | 
					            prod(spatial_size),
 | 
				
			||||||
 | 
				
			|||||||
@ -4,6 +4,8 @@
 | 
				
			|||||||
# This source code is licensed under the BSD-style license found in the
 | 
					# This source code is licensed under the BSD-style license found in the
 | 
				
			||||||
# LICENSE file in the root directory of this source tree.
 | 
					# LICENSE file in the root directory of this source tree.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -16,8 +18,18 @@ class HarmonicEmbedding(torch.nn.Module):
 | 
				
			|||||||
        append_input: bool = True,
 | 
					        append_input: bool = True,
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> None:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Given an input tensor `x` of shape [minibatch, ... , dim],
 | 
					        The harmonic embedding layer supports the classical
 | 
				
			||||||
        the harmonic embedding layer converts each feature
 | 
					        Nerf positional encoding described in
 | 
				
			||||||
 | 
					        `NeRF <https://arxiv.org/abs/2003.08934>`_
 | 
				
			||||||
 | 
					        and the integrated position encoding in
 | 
				
			||||||
 | 
					        `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        During, the inference you can provide the extra argument `diag_cov`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        If `diag_cov is None`, it converts
 | 
				
			||||||
 | 
					        rays parametrized with a `ray_bundle` to 3D points by
 | 
				
			||||||
 | 
					        extending each ray according to the corresponding length.
 | 
				
			||||||
 | 
					        Then it converts each feature
 | 
				
			||||||
        (i.e. vector along the last dimension) in `x`
 | 
					        (i.e. vector along the last dimension) in `x`
 | 
				
			||||||
        into a series of harmonic features `embedding`,
 | 
					        into a series of harmonic features `embedding`,
 | 
				
			||||||
        where for each i in range(dim) the following are present
 | 
					        where for each i in range(dim) the following are present
 | 
				
			||||||
@ -38,6 +50,31 @@ class HarmonicEmbedding(torch.nn.Module):
 | 
				
			|||||||
        where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar
 | 
					        where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar
 | 
				
			||||||
        denoting the i-th frequency of the harmonic embedding.
 | 
					        denoting the i-th frequency of the harmonic embedding.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        If `diag_cov is not None`, it approximates
 | 
				
			||||||
 | 
					        conical frustums following a ray bundle as gaussians,
 | 
				
			||||||
 | 
					        defined by x, the means of the gaussians and diag_cov,
 | 
				
			||||||
 | 
					        the diagonal covariances.
 | 
				
			||||||
 | 
					        Then it converts each gaussian
 | 
				
			||||||
 | 
					        into a series of harmonic features `embedding`,
 | 
				
			||||||
 | 
					        where for each i in range(dim) the following are present
 | 
				
			||||||
 | 
					        in embedding[...]::
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            [
 | 
				
			||||||
 | 
					                sin(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]),
 | 
				
			||||||
 | 
					                sin(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]),
 | 
				
			||||||
 | 
					                ...
 | 
				
			||||||
 | 
					                sin(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]),
 | 
				
			||||||
 | 
					                cos(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]),
 | 
				
			||||||
 | 
					                cos(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]),,
 | 
				
			||||||
 | 
					                ...
 | 
				
			||||||
 | 
					                cos(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]),
 | 
				
			||||||
 | 
					                x[..., i],              # only present if append_input is True.
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        where N equals `n_harmonic_functions-1`, and f_i is a scalar
 | 
				
			||||||
 | 
					        denoting the i-th frequency of the harmonic embedding.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
 | 
					        If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
 | 
				
			||||||
        powers of 2:
 | 
					        powers of 2:
 | 
				
			||||||
            `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
 | 
					            `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
 | 
				
			||||||
@ -59,8 +96,7 @@ class HarmonicEmbedding(torch.nn.Module):
 | 
				
			|||||||
                logspace or linear space
 | 
					                logspace or linear space
 | 
				
			||||||
            append_input: bool, whether to concat the original
 | 
					            append_input: bool, whether to concat the original
 | 
				
			||||||
                input to the harmonic embedding. If true the
 | 
					                input to the harmonic embedding. If true the
 | 
				
			||||||
                output is of the form (x, embed.sin(), embed.cos()
 | 
					                output is of the form (embed.sin(), embed.cos(), x)
 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -78,23 +114,42 @@ class HarmonicEmbedding(torch.nn.Module):
 | 
				
			|||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
 | 
					        self.register_buffer("_frequencies", frequencies * omega_0, persistent=False)
 | 
				
			||||||
 | 
					        self.register_buffer(
 | 
				
			||||||
 | 
					            "_zero_half_pi", torch.tensor([0.0, 0.5 * torch.pi]), persistent=False
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        self.append_input = append_input
 | 
					        self.append_input = append_input
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self, x: torch.Tensor) -> torch.Tensor:
 | 
					    def forward(
 | 
				
			||||||
 | 
					        self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
 | 
				
			||||||
 | 
					    ) -> torch.Tensor:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Args:
 | 
					        Args:
 | 
				
			||||||
            x: tensor of shape [..., dim]
 | 
					            x: tensor of shape [..., dim]
 | 
				
			||||||
 | 
					            diag_cov: An optional tensor of shape `(..., dim)`
 | 
				
			||||||
 | 
					                representing the diagonal covariance matrices of our Gaussians, joined with x
 | 
				
			||||||
 | 
					                as means of the Gaussians.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Returns:
 | 
					        Returns:
 | 
				
			||||||
            embedding: a harmonic embedding of `x`
 | 
					            embedding: a harmonic embedding of `x` of shape
 | 
				
			||||||
                of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
 | 
					            [..., (n_harmonic_functions * 2 + int(append_input)) * num_points_per_ray]
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1)
 | 
					        # [..., dim, n_harmonic_functions]
 | 
				
			||||||
        embed = torch.cat(
 | 
					        embed = x[..., None] * self._frequencies
 | 
				
			||||||
            (embed.sin(), embed.cos(), x)
 | 
					        # [..., 1, dim, n_harmonic_functions] + [2, 1, 1] => [..., 2, dim, n_harmonic_functions]
 | 
				
			||||||
            if self.append_input
 | 
					        embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
 | 
				
			||||||
            else (embed.sin(), embed.cos()),
 | 
					        # Use the trig identity cos(x) = sin(x + pi/2)
 | 
				
			||||||
            dim=-1,
 | 
					        # and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)).
 | 
				
			||||||
        )
 | 
					        embed = embed.sin()
 | 
				
			||||||
 | 
					        if diag_cov is not None:
 | 
				
			||||||
 | 
					            x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
 | 
				
			||||||
 | 
					            exp_var = torch.exp(-0.5 * x_var)
 | 
				
			||||||
 | 
					            # [..., 2, dim, n_harmonic_functions]
 | 
				
			||||||
 | 
					            embed = embed * exp_var[..., None, :, :]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        embed = embed.reshape(*x.shape[:-1], -1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.append_input:
 | 
				
			||||||
 | 
					            return torch.cat([embed, x], dim=-1)
 | 
				
			||||||
        return embed
 | 
					        return embed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
 | 
				
			|||||||
@ -0,0 +1,66 @@
 | 
				
			|||||||
 | 
					# Copyright (c) Meta Platforms, Inc. and affiliates.
 | 
				
			||||||
 | 
					# All rights reserved.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# This source code is licensed under the BSD-style license found in the
 | 
				
			||||||
 | 
					# LICENSE file in the root directory of this source tree.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import unittest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from pytorch3d.implicitron.models.implicit_function.base import ImplicitronRayBundle
 | 
				
			||||||
 | 
					from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (
 | 
				
			||||||
 | 
					    NeuralRadianceFieldImplicitFunction,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestNeuralRadianceFieldImplicitFunction(unittest.TestCase):
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        torch.manual_seed(42)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_forward_with_integrated_positionial_embedding(self):
 | 
				
			||||||
 | 
					        shape = [2, 4, 4]
 | 
				
			||||||
 | 
					        ray_bundle = ImplicitronRayBundle(
 | 
				
			||||||
 | 
					            origins=torch.randn(*shape, 3),
 | 
				
			||||||
 | 
					            directions=torch.randn(*shape, 3),
 | 
				
			||||||
 | 
					            bins=torch.randn(*shape, 6 + 1),
 | 
				
			||||||
 | 
					            lengths=torch.randn(*shape, 6),
 | 
				
			||||||
 | 
					            pixel_radii_2d=torch.randn(*shape, 1),
 | 
				
			||||||
 | 
					            xys=None,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        model = NeuralRadianceFieldImplicitFunction(
 | 
				
			||||||
 | 
					            n_hidden_neurons_dir=32, use_integrated_positional_encoding=True
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        raw_densities, ray_colors, _ = model(ray_bundle=ray_bundle)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(raw_densities.shape, (*shape, ray_bundle.lengths.shape[-1], 1))
 | 
				
			||||||
 | 
					        self.assertEqual(ray_colors.shape, (*shape, ray_bundle.lengths.shape[-1], 3))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_forward_with_integrated_positionial_embedding_raise_exception(self):
 | 
				
			||||||
 | 
					        shape = [2, 4, 4]
 | 
				
			||||||
 | 
					        ray_bundle = ImplicitronRayBundle(
 | 
				
			||||||
 | 
					            origins=torch.randn(*shape, 3),
 | 
				
			||||||
 | 
					            directions=torch.randn(*shape, 3),
 | 
				
			||||||
 | 
					            bins=None,
 | 
				
			||||||
 | 
					            lengths=torch.randn(*shape, 6),
 | 
				
			||||||
 | 
					            pixel_radii_2d=torch.randn(*shape, 1),
 | 
				
			||||||
 | 
					            xys=None,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        model = NeuralRadianceFieldImplicitFunction(
 | 
				
			||||||
 | 
					            n_hidden_neurons_dir=32, use_integrated_positional_encoding=True
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        with self.assertRaises(ValueError):
 | 
				
			||||||
 | 
					            _ = model(ray_bundle=ray_bundle)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_forward(self):
 | 
				
			||||||
 | 
					        shape = [2, 4, 4]
 | 
				
			||||||
 | 
					        ray_bundle = ImplicitronRayBundle(
 | 
				
			||||||
 | 
					            origins=torch.randn(*shape, 3),
 | 
				
			||||||
 | 
					            directions=torch.randn(*shape, 3),
 | 
				
			||||||
 | 
					            lengths=torch.randn(*shape, 6),
 | 
				
			||||||
 | 
					            pixel_radii_2d=torch.randn(*shape, 1),
 | 
				
			||||||
 | 
					            xys=None,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        model = NeuralRadianceFieldImplicitFunction(n_hidden_neurons_dir=32)
 | 
				
			||||||
 | 
					        raw_densities, ray_colors, _ = model(ray_bundle=ray_bundle)
 | 
				
			||||||
 | 
					        self.assertEqual(raw_densities.shape, (*shape, 6, 1))
 | 
				
			||||||
 | 
					        self.assertEqual(ray_colors.shape, (*shape, 6, 3))
 | 
				
			||||||
@ -8,6 +8,7 @@ import unittest
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
 | 
					from pytorch3d.renderer.implicit import HarmonicEmbedding
 | 
				
			||||||
 | 
					from torch.distributions import MultivariateNormal
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .common_testing import TestCaseMixin
 | 
					from .common_testing import TestCaseMixin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -36,16 +37,117 @@ class TestHarmonicEmbedding(TestCaseMixin, unittest.TestCase):
 | 
				
			|||||||
        self.assertClose(embed_fun_lin._frequencies, torch.FloatTensor((1.0, 2.5, 4.0)))
 | 
					        self.assertClose(embed_fun_lin._frequencies, torch.FloatTensor((1.0, 2.5, 4.0)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_correct_embed_out(self):
 | 
					    def test_correct_embed_out(self):
 | 
				
			||||||
        embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False)
 | 
					        n_harmonic_functions = 2
 | 
				
			||||||
        x = torch.randn((1, 5))
 | 
					        x = torch.randn((1, 5))
 | 
				
			||||||
        D = 5 * 4
 | 
					        D = 5 * n_harmonic_functions * 2  # sin + cos
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        embed_fun = HarmonicEmbedding(
 | 
				
			||||||
 | 
					            n_harmonic_functions=n_harmonic_functions, append_input=False
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        embed_out = embed_fun(x)
 | 
					        embed_out = embed_fun(x)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.assertEqual(embed_out.shape, (1, D))
 | 
					        self.assertEqual(embed_out.shape, (1, D))
 | 
				
			||||||
        # Sum the squares of the respective frequencies
 | 
					        # Sum the squares of the respective frequencies
 | 
				
			||||||
 | 
					        # cos^2(x) + sin^2(x) = 1
 | 
				
			||||||
        sum_squares = embed_out[0, : D // 2] ** 2 + embed_out[0, D // 2 :] ** 2
 | 
					        sum_squares = embed_out[0, : D // 2] ** 2 + embed_out[0, D // 2 :] ** 2
 | 
				
			||||||
        self.assertClose(sum_squares, torch.ones((D // 2)))
 | 
					        self.assertClose(sum_squares, torch.ones((D // 2)))
 | 
				
			||||||
        embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=True)
 | 
					
 | 
				
			||||||
        embed_out = embed_fun(x)
 | 
					        # Test append input
 | 
				
			||||||
        self.assertClose(embed_out.shape, torch.tensor((1, 5 * 5)))
 | 
					        embed_fun = HarmonicEmbedding(
 | 
				
			||||||
 | 
					            n_harmonic_functions=n_harmonic_functions, append_input=True
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        embed_out_appended_input = embed_fun(x)
 | 
				
			||||||
 | 
					        self.assertClose(
 | 
				
			||||||
 | 
					            embed_out_appended_input.shape, torch.tensor((1, D + x.shape[-1]))
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        # Last plane in output is the input
 | 
					        # Last plane in output is the input
 | 
				
			||||||
        self.assertClose(embed_out[..., -5:], x)
 | 
					        self.assertClose(embed_out_appended_input[..., -x.shape[-1] :], x)
 | 
				
			||||||
 | 
					        self.assertClose(embed_out_appended_input[..., : -x.shape[-1]], embed_out)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_correct_embed_out_with_diag_cov(self):
 | 
				
			||||||
 | 
					        n_harmonic_functions = 2
 | 
				
			||||||
 | 
					        x = torch.randn((1, 3))
 | 
				
			||||||
 | 
					        diag_cov = torch.randn((1, 3))
 | 
				
			||||||
 | 
					        D = 3 * n_harmonic_functions * 2  # sin + cos
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        embed_fun = HarmonicEmbedding(
 | 
				
			||||||
 | 
					            n_harmonic_functions=n_harmonic_functions, append_input=False
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        embed_out = embed_fun(x, diag_cov=diag_cov)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(embed_out.shape, (1, D))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Compute the scaling factor introduce in MipNerf
 | 
				
			||||||
 | 
					        scale_factor = (
 | 
				
			||||||
 | 
					            -0.5 * diag_cov[..., None] * torch.pow(embed_fun._frequencies[None, :], 2)
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        scale_factor = torch.exp(scale_factor).reshape(1, -1).tile((1, 2))
 | 
				
			||||||
 | 
					        # If we remove this scaling factor, we should go back to the
 | 
				
			||||||
 | 
					        # classical harmonic embedding:
 | 
				
			||||||
 | 
					        # Sum the squares of the respective frequencies
 | 
				
			||||||
 | 
					        # cos^2(x) + sin^2(x) = 1
 | 
				
			||||||
 | 
					        embed_out_without_cov = embed_out / scale_factor
 | 
				
			||||||
 | 
					        sum_squares = (
 | 
				
			||||||
 | 
					            embed_out_without_cov[0, : D // 2] ** 2
 | 
				
			||||||
 | 
					            + embed_out_without_cov[0, D // 2 :] ** 2
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertClose(sum_squares, torch.ones((D // 2)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Test append input
 | 
				
			||||||
 | 
					        embed_fun = HarmonicEmbedding(
 | 
				
			||||||
 | 
					            n_harmonic_functions=n_harmonic_functions, append_input=True
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        embed_out_appended_input = embed_fun(x, diag_cov=diag_cov)
 | 
				
			||||||
 | 
					        self.assertClose(
 | 
				
			||||||
 | 
					            embed_out_appended_input.shape, torch.tensor((1, D + x.shape[-1]))
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # Last plane in output is the input
 | 
				
			||||||
 | 
					        self.assertClose(embed_out_appended_input[..., -x.shape[-1] :], x)
 | 
				
			||||||
 | 
					        self.assertClose(embed_out_appended_input[..., : -x.shape[-1]], embed_out)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_correct_behavior_between_ipe_and_its_estimation_from_harmonic_embedding(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        Check that the HarmonicEmbedding with integrated_position_encoding (IPE) set to
 | 
				
			||||||
 | 
					        True is coherent with the HarmonicEmbedding.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        What is the idea behind this test?
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        We wish to produce an IPE that is the expectation
 | 
				
			||||||
 | 
					        of our lifted multivariate gaussian, modulated by the sine and cosine of
 | 
				
			||||||
 | 
					        the coordinates. These expectation has a closed-form
 | 
				
			||||||
 | 
					        (see equations 11, 12, 13, 14 of [1]).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        We sample N elements from the multivariate gaussian defined by its mean and covariance
 | 
				
			||||||
 | 
					        and compute the HarmonicEmbedding. The expected value of those embeddings should be
 | 
				
			||||||
 | 
					        equal to our IPE.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Inspired from:
 | 
				
			||||||
 | 
					        https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/mip_test.py#L359
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        References:
 | 
				
			||||||
 | 
					            [1] `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        num_dims = 3
 | 
				
			||||||
 | 
					        n_harmonic_functions = 6
 | 
				
			||||||
 | 
					        mean = torch.randn(num_dims)
 | 
				
			||||||
 | 
					        diag_cov = torch.rand(num_dims)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        he_fun = HarmonicEmbedding(
 | 
				
			||||||
 | 
					            n_harmonic_functions=n_harmonic_functions, logspace=True, append_input=False
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        ipe_fun = HarmonicEmbedding(
 | 
				
			||||||
 | 
					            n_harmonic_functions=n_harmonic_functions,
 | 
				
			||||||
 | 
					            append_input=False,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        embedding_ipe = ipe_fun(mean, diag_cov=diag_cov)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        rand_mvn = MultivariateNormal(mean, torch.eye(num_dims) * diag_cov)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Providing a large enough number of samples
 | 
				
			||||||
 | 
					        # we should obtain an estimation close to our IPE
 | 
				
			||||||
 | 
					        num_samples = 100000
 | 
				
			||||||
 | 
					        embedding_he = he_fun(rand_mvn.sample_n(num_samples))
 | 
				
			||||||
 | 
					        self.assertClose(embedding_he.mean(0), embedding_ipe, rtol=1e-2, atol=1e-2)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user