Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-11-04 18:02:14 +08:00
Update Harmonic embedding in NeRF

Summary: Removed the harmonic embedding module from projects/nerf and changed the import to come from core pytorch3d.

Reviewed By: patricklabatut

Differential Revision: D33142358

fbshipit-source-id: 3004247d50392dbd04ea72e9cd4bace0dc03606b
parent f9a26a22fc
commit 52c71b8816
@@ -1,88 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-
-
-class HarmonicEmbedding(torch.nn.Module):
-    def __init__(
-        self,
-        n_harmonic_functions: int = 6,
-        omega0: float = 1.0,
-        logspace: bool = True,
-        include_input: bool = True,
-    ) -> None:
-        """
-        Given an input tensor `x` of shape [minibatch, ... , dim],
-        the harmonic embedding layer converts each feature
-        in `x` into a series of harmonic features `embedding`,
-        where for each i in range(dim) the following are present
-        in embedding[...]:
-            ```
-            [
-                sin(x[..., i]),
-                sin(f_1*x[..., i]),
-                sin(f_2*x[..., i]),
-                ...
-                sin(f_N * x[..., i]),
-                cos(x[..., i]),
-                cos(f_1*x[..., i]),
-                cos(f_2*x[..., i]),
-                ...
-                cos(f_N * x[..., i]),
-                x[..., i]     # only present if include_input is True.
-            ]
-            ```
-        where N corresponds to `n_harmonic_functions`, and f_i is a scalar
-        denoting the i-th frequency of the harmonic embedding.
-        The shape of the output is [minibatch, ... , dim * (2 * N + 1)] if
-        include_input is True, otherwise [minibatch, ... , dim * (2 * N)].
-
-        If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
-        powers of 2:
-            `f_1 = 1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
-
-        If `logspace==False`, frequencies are linearly spaced between
-        `1.0` and `2**(n_harmonic_functions-1)`:
-            `f_1, ..., f_N = torch.linspace(
-                1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
-            )`
-
-        Note that `x` is also premultiplied by the base frequency `omega0`
-        before evaluating the harmonic functions.
-        """
-        super().__init__()
-
-        if logspace:
-            frequencies = 2.0 ** torch.arange(
-                n_harmonic_functions,
-                dtype=torch.float32,
-            )
-        else:
-            frequencies = torch.linspace(
-                1.0,
-                2.0 ** (n_harmonic_functions - 1),
-                n_harmonic_functions,
-                dtype=torch.float32,
-            )
-
-        self.register_buffer("_frequencies", omega0 * frequencies, persistent=False)
-        self.include_input = include_input
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Args:
-            x: tensor of shape [..., dim]
-        Returns:
-            embedding: a harmonic embedding of `x` of shape
-                [..., dim * (n_harmonic_functions * 2 + T)] where
-                T is 1 if include_input is True and 0 otherwise.
-        """
-        embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1)
-        if self.include_input:
-            return torch.cat((embed.sin(), embed.cos(), x), dim=-1)
-        else:
-            return torch.cat((embed.sin(), embed.cos()), dim=-1)
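For context on what the removed module computed, here is a minimal sketch in plain torch that mirrors its forward pass with the default logspace=True and include_input=True (a worked example for illustration, not part of the diff):

import torch

# Mirrors the removed HarmonicEmbedding.forward above, for illustration only.
n_harmonic_functions = 6
frequencies = 2.0 ** torch.arange(n_harmonic_functions, dtype=torch.float32)  # logspace=True

x = torch.rand(2, 1024, 3)                                     # [minibatch, ..., dim]
embed = (x[..., None] * frequencies).view(*x.shape[:-1], -1)   # [..., dim * N]
embedding = torch.cat((embed.sin(), embed.cos(), x), dim=-1)   # include_input=True appends x
print(embedding.shape)  # torch.Size([2, 1024, 39]) since dim * (2 * N + 1) = 3 * 13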
@@ -7,9 +7,8 @@
 from typing import Tuple
 
 import torch
-from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points
+from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points, HarmonicEmbedding
 
-from .harmonic_embedding import HarmonicEmbedding
 from .linear_with_repeat import LinearWithRepeat
 
 
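With the import now coming from the core library, a minimal usage sketch follows, assuming the pytorch3d.renderer class keeps the same defaults and output layout as the removed projects/nerf module:

import torch
from pytorch3d.renderer import HarmonicEmbedding  # import location after this diff

# Defaults assumed to match the removed module (N = 6, raw input appended).
embedder = HarmonicEmbedding(n_harmonic_functions=6)
points = torch.rand(2, 1024, 3)           # e.g. sampled 3D ray points
print(embedder(points).shape)             # expected: torch.Size([2, 1024, 39])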