From db3c12abfb32a5065d5163e13e2753274ad72a31 Mon Sep 17 00:00:00 2001 From: Darijan Gudelj Date: Thu, 22 Sep 2022 03:35:11 -0700 Subject: [PATCH] arbitrary shape input to voxel_grids Summary: Add the ability to process arbitrary point shapes `[n_grids, ..., 3]` instead of only `[n_grids, n_points, 3]`. Reviewed By: bottler Differential Revision: D39574373 fbshipit-source-id: 0a9ecafe9ea58cd8f909644de43a1185ecf934f4 --- .../models/implicit_function/voxel_grid.py | 100 ++++++++++++------ tests/implicitron/test_voxel_grids.py | 48 ++++----- 2 files changed, 93 insertions(+), 55 deletions(-) diff --git a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py index 8b7cd932..fc97c69b 100644 --- a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py +++ b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py @@ -19,6 +19,7 @@ from dataclasses import dataclass from typing import ClassVar, Dict, Optional, Tuple, Type import torch +from omegaconf import DictConfig from pytorch3d.implicitron.tools.config import ( Configurable, registry, @@ -81,12 +82,12 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): Arguments: points (torch.Tensor): tensor of points that you want to query - of a form (n_grids, n_points, 3) + of a form (n_grids, ..., 3) grid_values: an object of type Class.values_type which has tensors as members which have shapes derived from the get_shapes() method locator: a VolumeLocator object Returns: - torch.Tensor: shape (n_grids, n_points, n_features) + torch.Tensor: shape (n_grids, ..., n_features) """ points_local = locator.world_to_local_coords(points) return self.evaluate_local(points_local, grid_values) @@ -100,11 +101,11 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): Arguments: points (torch.Tensor): tensor of points that you want to query - of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1]) + of a form (n_grids, ..., 3), in a normalized form (coordinates are in [-1, 1]) grid_values: an object of type VMFactorizedVoxelGrid.values_type which has tensors as members which have shapes derived from the get_shapes() method Returns: - torch.Tensor: shape (n_grids, n_points, n_features) + torch.Tensor: shape (n_grids, ..., n_features) """ raise NotImplementedError() @@ -117,11 +118,28 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods replace the shapes in the dictionary with tensors of those shapes and add the first 'batch' dimension. If the required shape is (a, b) and you want to - have g grids than the tensor that replaces the shape should have the + have g grids then the tensor that replaces the shape should have the shape (g, a, b). """ raise NotImplementedError() + @staticmethod + def get_output_dim(args: DictConfig) -> int: + """ + Given all the arguments of the grid's __init__, returns output's last dimension length. + + In particular, if self.evaluate_world or self.evaluate_local + are called with `points` of shape (n_grids, n_points, 3), + their output will be of shape + (n_grids, n_points, grid.get_output_dim()). + + Args: + args: DictConfig which would be used to initialize the object + Returns: + output's last dimension length + """ + return args["n_features"] + @dataclass class FullResolutionVoxelGridValues(VoxelGridValuesBase): @@ -149,19 +167,23 @@ class FullResolutionVoxelGrid(VoxelGridBase): Arguments: points (torch.Tensor): tensor of points that you want to query - of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1]) + of a form (..., 3), in a normalized form (coordinates are in [-1, 1]) grid_values: an object of type values_type which has tensors as members which have shapes derived from the get_shapes() method Returns: - torch.Tensor: shape (n_grids, n_points, n_features) + torch.Tensor: shape (n_grids, ..., n_features) """ - return interpolate_volume( + # (n_grids, n_points_total, n_features) from (n_grids, ..., n_features) + recorded_shape = points.shape + points = points.view(points.shape[0], -1, points.shape[-1]) + interpolated = interpolate_volume( points, grid_values.voxel_grid, align_corners=self.align_corners, padding_mode=self.padding, mode=self.mode, ) + return interpolated.view(*recorded_shape[:-1], -1) def get_shapes(self) -> Dict[str, Tuple]: return {"voxel_grid": (self.n_features, *self.resolution)} @@ -202,10 +224,11 @@ class CPFactorizedVoxelGrid(VoxelGridBase): Members: n_components: number of vector triplets, higher number gives better approximation. matrix_reduction: how to transform components. If matrix_reduction=True result - matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by the - basis_matrix of shape (n_grids, n_components, n_features). If - matrix_reduction=False, the result tensor of (n_grids, n_points, n_components) - is summed along the rows to get (n_grids, n_points, 1). + matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied + by the basis_matrix of shape (n_grids, n_components, n_features). If + matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components) + is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed + to return to starting shape (n_grids, ..., 1). """ # the type of grid_values argument needed to run evaluate_local() @@ -219,8 +242,8 @@ class CPFactorizedVoxelGrid(VoxelGridBase): def evaluate_local( self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues ) -> torch.Tensor: - def factor(i): - axis = ["x", "y", "z"][i] + def factor(axis): + i = {"x": 0, "y": 1, "z": 2}[axis] index = points[..., i, None] vector = getattr(grid_values, "vector_components_" + axis) return interpolate_line( @@ -231,17 +254,25 @@ class CPFactorizedVoxelGrid(VoxelGridBase): mode=self.mode, ) + # (n_grids, n_points_total, n_features) from (n_grids, ..., n_features) + recorded_shape = points.shape + points = points.view(points.shape[0], -1, points.shape[-1]) + # collect points from all the vectors and multipy them out - mult = factor(0) * factor(1) * factor(2) + mult = factor("x") * factor("y") * factor("z") # reduce the result from - # (n_grids, n_points, n_components) to (n_grids, n_points, n_features) + # (n_grids, n_points_total, n_components) to (n_grids, n_points_total, n_features) if grid_values.basis_matrix is not None: - # (n_grids, n_points, n_features) = - # (n_grids, n_points, total_n_components) x (total_n_components, n_features) - return torch.bmm(mult, grid_values.basis_matrix) - - return mult.sum(axis=-1, keepdim=True) + # (n_grids, n_points_total, n_features) = + # (n_grids, n_points_total, total_n_components) @ + # (n_grids, total_n_components, n_features) + result = torch.bmm(mult, grid_values.basis_matrix) + else: + # (n_grids, n_points_total, 1) from (n_grids, n_points_total, n_features) + result = mult.sum(axis=-1, keepdim=True) + # (n_grids, ..., n_features) + return result.view(*recorded_shape[:-1], -1) def get_shapes(self) -> Dict[str, Tuple[int, int]]: if self.matrix_reduction is False and self.n_features != 1: @@ -308,10 +339,11 @@ class VMFactorizedVoxelGrid(VoxelGridBase): coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify either n_components or distribution_of_components, you cannot specify both. matrix_reduction: how to transform components. If matrix_reduction=True result - matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by - the basis_matrix of shape (n_grids, n_components, n_features). If - matrix_reduction=False, the result tensor of (n_grids, n_points, n_components) - is summed along the rows to get (n_grids, n_points, 1). + matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied + by the basis_matrix of shape (n_grids, n_components, n_features). If + matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components) + is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed + to return to starting shape (n_grids, ..., 1). """ # the type of grid_values argument needed to run evaluate_local() @@ -326,6 +358,10 @@ class VMFactorizedVoxelGrid(VoxelGridBase): def evaluate_local( self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues ) -> torch.Tensor: + # (n_grids, n_points_total, n_features) from (n_grids, ..., n_features) + recorded_shape = points.shape + points = points.view(points.shape[0], -1, points.shape[-1]) + # collect points from matrices and vectors and multiply them a = interpolate_plane( points[..., :2], @@ -375,9 +411,13 @@ class VMFactorizedVoxelGrid(VoxelGridBase): # (n_grids, n_points, n_features) = # (n_grids, n_points, total_n_components) x # (n_grids, total_n_components, n_features) - return torch.bmm(feats, grid_values.basis_matrix) - # pyre-ignore[28] - return feats.sum(axis=-1, keepdim=True) + result = torch.bmm(feats, grid_values.basis_matrix) + else: + # pyre-ignore[28] + # (n_grids, n_points, 1) from (n_grids, n_points, n_features) + result = feats.sum(axis=-1, keepdim=True) + # (n_grids, ..., n_features) + return result.view(*recorded_shape[:-1], -1) def get_shapes(self) -> Dict[str, Tuple]: if self.matrix_reduction is False and self.n_features != 1: @@ -494,9 +534,9 @@ class VoxelGridModule(Configurable, torch.nn.Module): Args: points (torch.Tensor): tensor of points that you want to query - of a form (n_points, 3) + of a form (..., 3) Returns: - torch.Tensor of shape (n_points, n_features) + torch.Tensor of shape (..., n_features) """ locator = VolumeLocator( batch_size=1, diff --git a/tests/implicitron/test_voxel_grids.py b/tests/implicitron/test_voxel_grids.py index 75433936..110521cc 100644 --- a/tests/implicitron/test_voxel_grids.py +++ b/tests/implicitron/test_voxel_grids.py @@ -38,10 +38,17 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): return def get_random_normalized_points( - self, n_grids, n_points, dimension=3 + self, n_grids, n_points=None, dimension=3 ) -> torch.Tensor: + middle_shape = torch.randint(1, 4, tuple(torch.randint(1, 5, size=(1,)))) # create random query points - return torch.rand(n_grids, n_points, dimension) * 2 - 1 + return ( + torch.rand( + n_grids, *(middle_shape if n_points is None else [n_points]), dimension + ) + * 2 + - 1 + ) def _test_query_with_constant_init_cp( self, @@ -50,7 +57,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): n_components: int, resolution: Tuple[int], value: float = 1, - n_points: int = 1, ) -> None: # set everything to 'value' and do query for elementsthe result should # be of shape (n_grids, n_points, n_features) and be filled with n_components @@ -65,12 +71,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): params = grid.values_type( **{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes} ) - + points = self.get_random_normalized_points(n_grids) assert torch.allclose( - grid.evaluate_local( - self.get_random_normalized_points(n_grids, n_points), params - ), - torch.ones(n_grids, n_points, n_features) * n_components * value, + grid.evaluate_local(points, params), + torch.ones(n_grids, *points.shape[1:-1], n_features) * n_components * value, + rtol=0.0001, ) def _test_query_with_constant_init_vm( @@ -98,11 +103,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): expected_element = ( n_components * value if distribution is None else sum(distribution) * value ) + points = self.get_random_normalized_points(n_grids) assert torch.allclose( - grid.evaluate_local( - self.get_random_normalized_points(n_grids, n_points), params - ), - torch.ones(n_grids, n_points, n_features) * expected_element, + grid.evaluate_local(points, params), + torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element, ) def _test_query_with_constant_init_full( @@ -121,21 +125,20 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): ) expected_element = value + points = self.get_random_normalized_points(n_grids) assert torch.allclose( - grid.evaluate_local( - self.get_random_normalized_points(n_grids, n_points), params - ), - torch.ones(n_grids, n_points, n_features) * expected_element, + grid.evaluate_local(points, params), + torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element, ) def test_query_with_constant_init(self): with self.subTest("Full"): self._test_query_with_constant_init_full( - n_grids=5, n_features=6, resolution=(3, 4, 5), n_points=3 + n_grids=5, n_features=6, resolution=(3, 4, 5) ) with self.subTest("Full with 1 in dimensions"): self._test_query_with_constant_init_full( - n_grids=5, n_features=1, resolution=(33, 41, 1), n_points=4 + n_grids=5, n_features=1, resolution=(33, 41, 1) ) with self.subTest("CP"): self._test_query_with_constant_init_cp( @@ -143,7 +146,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): n_features=6, n_components=7, resolution=(3, 4, 5), - n_points=2, ) with self.subTest("CP with 1 in dimensions"): self._test_query_with_constant_init_cp( @@ -151,7 +153,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): n_features=1, n_components=3, resolution=(3, 1, 1), - n_points=4, ) with self.subTest("VM with symetric distribution"): self._test_query_with_constant_init_vm( @@ -159,7 +160,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): n_features=9, resolution=(2, 12, 2), n_components=12, - n_points=3, ) with self.subTest("VM with distribution"): self._test_query_with_constant_init_vm( @@ -167,7 +167,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): n_features=1, resolution=(5, 9, 7), distribution=(33, 41, 1), - n_points=7, ) def test_query_with_zero_init(self): @@ -177,7 +176,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): n_features=6, n_components=7, resolution=(3, 2, 5), - n_points=3, value=0, ) with self.subTest("Query testing with zero init VMFactorizedVoxelGrid"): @@ -186,12 +184,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): n_features=9, resolution=(2, 11, 3), n_components=3, - n_points=3, value=0, ) with self.subTest("Query testing with zero init FullResolutionVoxelGrid"): self._test_query_with_constant_init_full( - n_grids=4, n_features=2, resolution=(3, 3, 5), n_points=3, value=0 + n_grids=4, n_features=2, resolution=(3, 3, 5), value=0 ) def setUp(self): @@ -324,6 +321,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): padding_mode="zeros", mode="bilinear", ), + rtol=0.0001, ) def test_floating_point_query(self):