From c2d876c9e88af35e43dc7acd14dec83467020c9f Mon Sep 17 00:00:00 2001 From: Darijan Gudelj Date: Thu, 22 Sep 2022 10:56:00 -0700 Subject: [PATCH] voxel grid implicit function Summary: The implicit function and its members and internal working Reviewed By: kjchalup Differential Revision: D38829764 fbshipit-source-id: 28394fe7819e311ed52c9defc9a1b29f37fbc495 --- .../implicit_function/decoding_functions.py | 19 ++++++++++++++ .../models/implicit_function/voxel_grid.py | 26 ++++++++++++++----- .../renderer/implicit/harmonic_embedding.py | 2 +- 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py index 2b6fe969..6e99b5ab 100644 --- a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py +++ b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py @@ -18,6 +18,8 @@ from typing import Optional, Tuple import torch +from omegaconf import DictConfig + from pytorch3d.implicitron.tools.config import ( Configurable, registry, @@ -179,8 +181,11 @@ class MLPWithInputSkips(Configurable, torch.nn.Module): class MLPDecoder(DecoderFunctionBase): """ Decoding function which uses `MLPWithIputSkips` to convert the embedding to output. + If using Implicitron config system `input_dim` of the `network` is changed to the + value of `input_dim` member and `input_dim` is removed from the `network` args. """ + input_dim: int = 3 network: MLPWithInputSkips def __post_init__(self): @@ -192,6 +197,20 @@ class MLPDecoder(DecoderFunctionBase): ) -> torch.Tensor: return self.network(features, z) + @classmethod + def network_tweak_args(cls, type, args: DictConfig) -> None: + """ + Special method to stop get_default_args exposing member's `input_dim`. 
+ """ + args.pop("input_dim", None) + + def create_network_impl(self, type, args: DictConfig) -> None: + """ + Set the input dimension of the `network` to the input dimension of the + decoding function. + """ + self.network = MLPWithInputSkips(input_dim=self.input_dim, **args) + class TransformerWithInputSkips(torch.nn.Module): def __init__( diff --git a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py index fc97c69b..8ea21859 100644 --- a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py +++ b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py @@ -65,7 +65,7 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): padding: str = "zeros" mode: str = "bilinear" n_features: int = 1 - resolution: Tuple[int, int, int] = (64, 64, 64) + resolution: Tuple[int, int, int] = (128, 128, 128) def __post_init__(self): super().__init__() @@ -507,8 +507,7 @@ class VoxelGridModule(Configurable, torch.nn.Module): voxel_grid_class_type: str = "FullResolutionVoxelGrid" voxel_grid: VoxelGridBase - # pyre-fixme[8]: Attribute has type `Tuple[float, float, float]`; used as `float`. - extents: Tuple[float, float, float] = 1.0 + extents: Tuple[float, float, float] = (1.0, 1.0, 1.0) translation: Tuple[float, float, float] = (0.0, 0.0, 0.0) init_std: float = 0.1 @@ -552,9 +551,9 @@ class VoxelGridModule(Configurable, torch.nn.Module): grid_sizes=(2, 2, 2), # The locator object uses (x, y, z) convention for the # voxel size and translation. - voxel_size=self.extents, - volume_translation=self.translation, - # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase... 
+ voxel_size=tuple(self.extents), + volume_translation=tuple(self.translation), + # pyre-ignore[29] device=next(self.params.values()).device, ) # pyre-fixme[29]: `Union[torch._tensor.Tensor, @@ -562,3 +561,18 @@ class VoxelGridModule(Configurable, torch.nn.Module): grid_values = self.voxel_grid.values_type(**self.params) # voxel grids operate with extra n_grids dimension, which we fix to one return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0] + + @staticmethod + def get_output_dim(args: DictConfig) -> int: + """ + Utility to help predict the shape of the output of `forward`. + + Args: + args: DictConfig which would be used to initialize the object + Returns: + int: the length of the last dimension of the output tensor + """ + grid = registry.get(VoxelGridBase, args["voxel_grid_class_type"]) + return grid.get_output_dim( + args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"] + ) diff --git a/pytorch3d/renderer/implicit/harmonic_embedding.py b/pytorch3d/renderer/implicit/harmonic_embedding.py index 4dadd548..63d28e5f 100644 --- a/pytorch3d/renderer/implicit/harmonic_embedding.py +++ b/pytorch3d/renderer/implicit/harmonic_embedding.py @@ -88,7 +88,7 @@ class HarmonicEmbedding(torch.nn.Module): embedding: a harmonic embedding of `x` of shape [..., (n_harmonic_functions * 2 + int(append_input)) * dim] """ - embed = (x[..., None] * self._frequencies).view(*x.shape[:-1], -1) + embed = (x[..., None] * self._frequencies).reshape(*x.shape[:-1], -1) embed = torch.cat( (embed.sin(), embed.cos(), x) if self.append_input