voxel grid implicit function

Summary: The implicit function, its members, and its internal workings

Reviewed By: kjchalup

Differential Revision: D38829764

fbshipit-source-id: 28394fe7819e311ed52c9defc9a1b29f37fbc495
Darijan Gudelj 2022-09-22 10:56:00 -07:00 committed by Facebook GitHub Bot
parent d6a197be36
commit c2d876c9e8
3 changed files with 40 additions and 7 deletions


@@ -18,6 +18,8 @@ from typing import Optional, Tuple
import torch
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import (
Configurable,
registry,
@@ -179,8 +181,11 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
class MLPDecoder(DecoderFunctionBase):
"""
Decoding function which uses `MLPWithInputSkips` to convert the embedding to the output.
If the Implicitron config system is used, the `input_dim` of the `network` is changed to
the value of the `input_dim` member and `input_skips` is removed.
"""
input_dim: int = 3
network: MLPWithInputSkips
def __post_init__(self):
@@ -192,6 +197,20 @@
) -> torch.Tensor:
return self.network(features, z)
@classmethod
def network_tweak_args(cls, type, args: DictConfig) -> None:
"""
Special method to stop `get_default_args` from exposing the member's `input_dim`.
"""
args.pop("input_dim", None)
def create_network_impl(self, type, args: DictConfig) -> None:
"""
Set the input dimension of the `network` to the input dimension of the
decoding function.
"""
self.network = MLPWithInputSkips(input_dim=self.input_dim, **args)
class TransformerWithInputSkips(torch.nn.Module):
def __init__(
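
The tweak/create pair above is what lets the Implicitron config system hide the nested `input_dim`. A minimal usage sketch follows; the import path for MLPDecoder and the `network_args` key name are assumptions based on the config system's usual conventions, not something shown in this diff.

from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import get_default_args
# Assumed import path for MLPDecoder.
from pytorch3d.implicitron.models.implicit_function.decoding_functions import MLPDecoder

cfg: DictConfig = get_default_args(MLPDecoder)
# `network_tweak_args` pops `input_dim` from the nested network args, so only the
# decoder-level `input_dim` should remain configurable here.
assert "input_dim" not in cfg.network_args
cfg.input_dim = 39  # hypothetical value, e.g. a harmonic embedding of 3D points
decoder = MLPDecoder(**cfg)  # `create_network_impl` builds the MLP with input_dim=39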


@@ -65,7 +65,7 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
padding: str = "zeros"
mode: str = "bilinear"
n_features: int = 1
-resolution: Tuple[int, int, int] = (64, 64, 64)
+resolution: Tuple[int, int, int] = (128, 128, 128)
def __post_init__(self):
super().__init__()
@@ -507,8 +507,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
voxel_grid_class_type: str = "FullResolutionVoxelGrid"
voxel_grid: VoxelGridBase
-# pyre-fixme[8]: Attribute has type `Tuple[float, float, float]`; used as `float`.
-extents: Tuple[float, float, float] = 1.0
+extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
init_std: float = 0.1
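
The switch to a proper 3-tuple feeds the locator call in the next hunk, which uses a dummy 2 x 2 x 2 grid with a voxel size equal to `extents`. A brief sketch of why that spans the intended box, using hypothetical numbers and assuming the locator's align-corners convention (the local [-1, 1] cube spans (grid_size - 1) * voxel_size world units per axis):

extents = (2.0, 1.0, 1.0)  # hypothetical world-space side lengths of the voxel grid
grid_sizes = (2, 2, 2)     # dummy locator grid: only the corner voxel centers matter
voxel_size = extents       # what VoxelGridModule passes as the locator's voxel size
# With two voxels per axis there is exactly one voxel between corner centers, so the
# world span of the local [-1, 1] cube equals voxel_size, i.e. the requested extents.
world_span = tuple((g - 1) * v for g, v in zip(grid_sizes, voxel_size))
assert world_span == extents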
@@ -552,9 +551,9 @@ class VoxelGridModule(Configurable, torch.nn.Module):
grid_sizes=(2, 2, 2),
# The locator object uses (x, y, z) convention for the
# voxel size and translation.
-voxel_size=self.extents,
-volume_translation=self.translation,
-# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase...
+voxel_size=tuple(self.extents),
+volume_translation=tuple(self.translation),
+# pyre-ignore[29]
device=next(self.params.values()).device,
)
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
@@ -562,3 +561,18 @@ class VoxelGridModule(Configurable, torch.nn.Module):
grid_values = self.voxel_grid.values_type(**self.params)
# voxel grids operate with an extra n_grids dimension, which we fix to one
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
@staticmethod
def get_output_dim(args: DictConfig) -> int:
"""
Utility to help predict the shape of the output of `forward`.
Args:
args: DictConfig which would be used to initialize the object
Returns:
int: the length of the last dimension of the output tensor
"""
grid = registry.get(VoxelGridBase, args["voxel_grid_class_type"])
return grid.get_output_dim(
args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"]
)
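
A usage sketch for the new helper; the import path is an assumption, and the expectation that a full-resolution grid outputs `n_features` channels is inferred rather than shown in this diff.

from pytorch3d.implicitron.tools.config import get_default_args
# Assumed import path for VoxelGridModule.
from pytorch3d.implicitron.models.implicit_function.voxel_grid import VoxelGridModule

cfg = get_default_args(VoxelGridModule)
cfg.voxel_grid_class_type = "FullResolutionVoxelGrid"
cfg.voxel_grid_FullResolutionVoxelGrid_args.n_features = 4
# Dispatches through the registry to the chosen VoxelGridBase subclass, letting
# callers size downstream layers before the module itself is instantiated.
out_dim = VoxelGridModule.get_output_dim(cfg)  # expected to equal n_features, i.e. 4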


@@ -88,7 +88,7 @@ class HarmonicEmbedding(torch.nn.Module):
embedding: a harmonic embedding of `x`
of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim]
"""
-embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1)
+embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1)
embed = torch.cat(
(embed.sin(), embed.cos(), x)
if self.append_input