mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	voxel grid implicit function
Summary: The implicit function, its members, and its internal workings. Reviewed By: kjchalup Differential Revision: D38829764 fbshipit-source-id: 28394fe7819e311ed52c9defc9a1b29f37fbc495
This commit is contained in:
		
							parent
							
								
									d6a197be36
								
							
						
					
					
						commit
						c2d876c9e8
					
				@ -18,6 +18,8 @@ from typing import Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from omegaconf import DictConfig
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
    Configurable,
 | 
			
		||||
    registry,
 | 
			
		||||
@ -179,8 +181,11 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
 | 
			
		||||
class MLPDecoder(DecoderFunctionBase):
 | 
			
		||||
    """
 | 
			
		||||
    Decoding function which uses `MLPWithInputSkips` to convert the embedding to output.
 | 
			
		||||
    If using the Implicitron config system, `input_dim` of the `network` is changed to the
 | 
			
		||||
    value of `input_dim` member and `input_skips` is removed.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    input_dim: int = 3
 | 
			
		||||
    network: MLPWithInputSkips
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
@ -192,6 +197,20 @@ class MLPDecoder(DecoderFunctionBase):
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        return self.network(features, z)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def network_tweak_args(cls, type, args: DictConfig) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Special method to stop `get_default_args` from exposing the member's `input_dim`.
 | 
			
		||||
        """
 | 
			
		||||
        args.pop("input_dim", None)
 | 
			
		||||
 | 
			
		||||
    def create_network_impl(self, type, args: DictConfig) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Set the input dimension of the `network` to the input dimension of the
 | 
			
		||||
        decoding function.
 | 
			
		||||
        """
 | 
			
		||||
        self.network = MLPWithInputSkips(input_dim=self.input_dim, **args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TransformerWithInputSkips(torch.nn.Module):
 | 
			
		||||
    def __init__(
 | 
			
		||||
 | 
			
		||||
@ -65,7 +65,7 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
 | 
			
		||||
    padding: str = "zeros"
 | 
			
		||||
    mode: str = "bilinear"
 | 
			
		||||
    n_features: int = 1
 | 
			
		||||
    resolution: Tuple[int, int, int] = (64, 64, 64)
 | 
			
		||||
    resolution: Tuple[int, int, int] = (128, 128, 128)
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
@ -507,8 +507,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
 | 
			
		||||
    voxel_grid_class_type: str = "FullResolutionVoxelGrid"
 | 
			
		||||
    voxel_grid: VoxelGridBase
 | 
			
		||||
 | 
			
		||||
    # pyre-fixme[8]: Attribute has type `Tuple[float, float, float]`; used as `float`.
 | 
			
		||||
    extents: Tuple[float, float, float] = 1.0
 | 
			
		||||
    extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
 | 
			
		||||
    translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
 | 
			
		||||
 | 
			
		||||
    init_std: float = 0.1
 | 
			
		||||
@ -552,9 +551,9 @@ class VoxelGridModule(Configurable, torch.nn.Module):
 | 
			
		||||
            grid_sizes=(2, 2, 2),
 | 
			
		||||
            # The locator object uses (x, y, z) convention for the
 | 
			
		||||
            # voxel size and translation.
 | 
			
		||||
            voxel_size=self.extents,
 | 
			
		||||
            volume_translation=self.translation,
 | 
			
		||||
            # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase...
 | 
			
		||||
            voxel_size=tuple(self.extents),
 | 
			
		||||
            volume_translation=tuple(self.translation),
 | 
			
		||||
            # pyre-ignore[29]
 | 
			
		||||
            device=next(self.params.values()).device,
 | 
			
		||||
        )
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
 | 
			
		||||
@ -562,3 +561,18 @@ class VoxelGridModule(Configurable, torch.nn.Module):
 | 
			
		||||
        grid_values = self.voxel_grid.values_type(**self.params)
 | 
			
		||||
        # voxel grids operate with extra n_grids dimension, which we fix to one
 | 
			
		||||
        return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def get_output_dim(args: DictConfig) -> int:
 | 
			
		||||
        """
 | 
			
		||||
        Utility to help predict the shape of the output of `forward`.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            args: DictConfig which would be used to initialize the object
 | 
			
		||||
        Returns:
 | 
			
		||||
            int: the length of the last dimension of the output tensor
 | 
			
		||||
        """
 | 
			
		||||
        grid = registry.get(VoxelGridBase, args["voxel_grid_class_type"])
 | 
			
		||||
        return grid.get_output_dim(
 | 
			
		||||
            args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -88,7 +88,7 @@ class HarmonicEmbedding(torch.nn.Module):
 | 
			
		||||
            embedding: a harmonic embedding of `x`
 | 
			
		||||
                of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
 | 
			
		||||
        """
 | 
			
		||||
        embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1)
 | 
			
		||||
        embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1)
 | 
			
		||||
        embed = torch.cat(
 | 
			
		||||
            (embed.sin(), embed.cos(), x)
 | 
			
		||||
            if self.append_input
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user