From 24f5f4a3e7faec460e25ab0c1690f7d8329f92a6 Mon Sep 17 00:00:00 2001 From: Darijan Gudelj Date: Thu, 25 Aug 2022 09:40:43 -0700 Subject: [PATCH] VoxelGridModule Summary: Simple wrapper around voxel grids to make them a module Reviewed By: bottler Differential Revision: D38829762 fbshipit-source-id: dfee85088fa3c65e396cc7d3bf7ebaaffaadb646 --- .../models/implicit_function/voxel_grid.py | 88 ++++++++++++++++++- tests/implicitron/test_voxel_grids.py | 26 ++++++ 2 files changed, 112 insertions(+), 2 deletions(-) diff --git a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py index 834dcb2c..76dc3ac2 100644 --- a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py +++ b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py @@ -8,14 +8,23 @@ This file contains classes that implement Voxel grids, both in their full resolution as in the factorized form. There are two factorized forms implemented, Tensor rank decomposition or CANDECOMP/PARAFAC (here CP) and Vector Matrix (here VM) factorization from the -https://arxiv.org/abs/2203.09517. +TensoRF (https://arxiv.org/abs/2203.09517) paper. + +In addition, the module VoxelGridModule implements a trainable instance of one of +these classes. + """ from dataclasses import dataclass from typing import ClassVar, Dict, Optional, Tuple, Type import torch -from pytorch3d.implicitron.tools.config import registry, ReplaceableBase +from pytorch3d.implicitron.tools.config import ( + Configurable, + registry, + ReplaceableBase, + run_auto_creation, +) from pytorch3d.structures.volumes import VolumeLocator from .utils import interpolate_line, interpolate_plane, interpolate_volume @@ -426,3 +435,78 @@ class VMFactorizedVoxelGrid(VoxelGridBase): ) return shape_dict + + +class VoxelGridModule(Configurable, torch.nn.Module): + """ + A wrapper torch.nn.Module for the VoxelGrid classes, which + contains parameters that are needed to train the VoxelGrid classes. + + Members: + voxel_grid_class_type: The name of the class to use for voxel_grid, + which must be available in the registry. Default FullResolutionVoxelGrid. + voxel_grid: An instance of `VoxelGridBase`. This is the object which + this class wraps. + extents: 3-tuple of a form (width, height, depth), denotes the size of the grid + in world units. + translation: 3-tuple of float. The center of the volume in world units as (x, y, z). + init_std: Parameters are initialized using the gaussian distribution + with mean=init_mean and std=init_std. Default 0.1 + init_mean: Parameters are initialized using the gaussian distribution + with mean=init_mean and std=init_std. Default 0. + """ + + voxel_grid_class_type: str = "FullResolutionVoxelGrid" + voxel_grid: VoxelGridBase + + extents: Tuple[float, float, float] = 1.0 + translation: Tuple[float, float, float] = (0.0, 0.0, 0.0) + + init_std: float = 0.1 + init_mean: float = 0 + + def __post_init__(self): + super().__init__() + run_auto_creation(self) + n_grids = 1 # Voxel grid objects are batched. We need only a single grid. + shapes = self.voxel_grid.get_shapes() + params = { + name: torch.normal( + mean=torch.zeros((n_grids, *shape)) + self.init_mean, + std=self.init_std, + ) + for name, shape in shapes.items() + } + self.params = torch.nn.ParameterDict(params) + + def forward(self, points: torch.Tensor) -> torch.Tensor: + """ + Evaluates points in the world coordinate frame on the voxel_grid. + + Args: + points (torch.Tensor): tensor of points that you want to query + of a form (n_points, 3) + Returns: + torch.Tensor of shape (n_points, n_features) + """ + locator = VolumeLocator( + batch_size=1, + # The resolution of the voxel grid does not need to be known + # to the locator object. It is easiest to fix the resolution of the locator. + # In particular we fix it to (2,2,2) so that there is exactly one voxel of the + # desired size. The locator object uses (z, y, x) convention for the grid_size, + # and this module uses (x, y, z) convention so the order has to be reversed + # (irrelevant in this case since they are all equal). + # It is (2, 2, 2) because the VolumeLocator object behaves like + # align_corners=True, which means that the points are in the corners of + # the volume. So in the grid of (2, 2, 2) there is only one voxel. + grid_sizes=(2, 2, 2), + # The locator object uses (x, y, z) convention for the + # voxel size and translation. + voxel_size=self.extents, + volume_translation=self.translation, + device=next(self.params.values()).device, + ) + grid_values = self.voxel_grid.values_type(**self.params) + # voxel grids operate with extra n_grids dimension, which we fix to one + return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0] diff --git a/tests/implicitron/test_voxel_grids.py b/tests/implicitron/test_voxel_grids.py index 4e5ce3be..75433936 100644 --- a/tests/implicitron/test_voxel_grids.py +++ b/tests/implicitron/test_voxel_grids.py @@ -19,6 +19,7 @@ from pytorch3d.implicitron.models.implicit_function.voxel_grid import ( CPFactorizedVoxelGrid, FullResolutionVoxelGrid, VMFactorizedVoxelGrid, + VoxelGridModule, ) from pytorch3d.implicitron.tools.config import expand_args_fields @@ -198,6 +199,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): expand_args_fields(FullResolutionVoxelGrid) expand_args_fields(CPFactorizedVoxelGrid) expand_args_fields(VMFactorizedVoxelGrid) + expand_args_fields(VoxelGridModule) def _interpolate_1D( self, points: torch.Tensor, vectors: torch.Tensor @@ -585,3 +587,27 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): n_features=10, n_components=3, ) + + def test_voxel_grid_module_location(self, n_times=10): + """ + This checks the module uses locator correctly etc.. + + If we know that voxel grids work for (x, y, z) in local coordinates + to test if the VoxelGridModule does not have permuted dimensions we + create local coordinates, pass them through verified voxelgrids and + compare the result with the result that we get when we convert + coordinates to world and pass them through the VoxelGridModule + """ + for _ in range(n_times): + extents = tuple(torch.randint(1, 50, size=(3,)).tolist()) + + grid = VoxelGridModule(extents=extents) + local_point = torch.rand(1, 3) * 2 - 1 + world_point = local_point * torch.tensor(extents) / 2 + grid_values = grid.voxel_grid.values_type(**grid.params) + + assert torch.allclose( + grid(world_point)[0, 0], + grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0], + rtol=0.0001, + )