arbitrary shape input to voxel_grids

Summary: Add the ability to process arbitrary point shapes `[n_grids, ..., 3]` instead of  only `[n_grids, n_points, 3]`.

Reviewed By: bottler

Differential Revision: D39574373

fbshipit-source-id: 0a9ecafe9ea58cd8f909644de43a1185ecf934f4
This commit is contained in:
Darijan Gudelj 2022-09-22 03:35:11 -07:00 committed by Facebook GitHub Bot
parent 6ae6ff9cf7
commit db3c12abfb
2 changed files with 93 additions and 55 deletions

View File

@ -19,6 +19,7 @@ from dataclasses import dataclass
from typing import ClassVar, Dict, Optional, Tuple, Type from typing import ClassVar, Dict, Optional, Tuple, Type
import torch import torch
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
Configurable, Configurable,
registry, registry,
@ -81,12 +82,12 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
Arguments: Arguments:
points (torch.Tensor): tensor of points that you want to query points (torch.Tensor): tensor of points that you want to query
of a form (n_grids, n_points, 3) of a form (n_grids, ..., 3)
grid_values: an object of type Class.values_type which has tensors as grid_values: an object of type Class.values_type which has tensors as
members which have shapes derived from the get_shapes() method members which have shapes derived from the get_shapes() method
locator: a VolumeLocator object locator: a VolumeLocator object
Returns: Returns:
torch.Tensor: shape (n_grids, n_points, n_features) torch.Tensor: shape (n_grids, ..., n_features)
""" """
points_local = locator.world_to_local_coords(points) points_local = locator.world_to_local_coords(points)
return self.evaluate_local(points_local, grid_values) return self.evaluate_local(points_local, grid_values)
@ -100,11 +101,11 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
Arguments: Arguments:
points (torch.Tensor): tensor of points that you want to query points (torch.Tensor): tensor of points that you want to query
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1]) of a form (n_grids, ..., 3), in a normalized form (coordinates are in [-1, 1])
grid_values: an object of type VMFactorizedVoxelGrid.values_type which has tensors grid_values: an object of type VMFactorizedVoxelGrid.values_type which has tensors
as members which have shapes derived from the get_shapes() method as members which have shapes derived from the get_shapes() method
Returns: Returns:
torch.Tensor: shape (n_grids, n_points, n_features) torch.Tensor: shape (n_grids, ..., n_features)
""" """
raise NotImplementedError() raise NotImplementedError()
@ -117,11 +118,28 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods
replace the shapes in the dictionary with tensors of those shapes and add the replace the shapes in the dictionary with tensors of those shapes and add the
first 'batch' dimension. If the required shape is (a, b) and you want to first 'batch' dimension. If the required shape is (a, b) and you want to
have g grids than the tensor that replaces the shape should have the have g grids then the tensor that replaces the shape should have the
shape (g, a, b). shape (g, a, b).
""" """
raise NotImplementedError() raise NotImplementedError()
@staticmethod
def get_output_dim(args: DictConfig) -> int:
"""
Given all the arguments of the grid's __init__, returns output's last dimension length.
In particular, if self.evaluate_world or self.evaluate_local
are called with `points` of shape (n_grids, n_points, 3),
their output will be of shape
(n_grids, n_points, grid.get_output_dim()).
Args:
args: DictConfig which would be used to initialize the object
Returns:
output's last dimension length
"""
return args["n_features"]
@dataclass @dataclass
class FullResolutionVoxelGridValues(VoxelGridValuesBase): class FullResolutionVoxelGridValues(VoxelGridValuesBase):
@ -149,19 +167,23 @@ class FullResolutionVoxelGrid(VoxelGridBase):
Arguments: Arguments:
points (torch.Tensor): tensor of points that you want to query points (torch.Tensor): tensor of points that you want to query
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1]) of a form (..., 3), in a normalized form (coordinates are in [-1, 1])
grid_values: an object of type values_type which has tensors as grid_values: an object of type values_type which has tensors as
members which have shapes derived from the get_shapes() method members which have shapes derived from the get_shapes() method
Returns: Returns:
torch.Tensor: shape (n_grids, n_points, n_features) torch.Tensor: shape (n_grids, ..., n_features)
""" """
return interpolate_volume( # (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
recorded_shape = points.shape
points = points.view(points.shape[0], -1, points.shape[-1])
interpolated = interpolate_volume(
points, points,
grid_values.voxel_grid, grid_values.voxel_grid,
align_corners=self.align_corners, align_corners=self.align_corners,
padding_mode=self.padding, padding_mode=self.padding,
mode=self.mode, mode=self.mode,
) )
return interpolated.view(*recorded_shape[:-1], -1)
def get_shapes(self) -> Dict[str, Tuple]: def get_shapes(self) -> Dict[str, Tuple]:
return {"voxel_grid": (self.n_features, *self.resolution)} return {"voxel_grid": (self.n_features, *self.resolution)}
@ -202,10 +224,11 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
Members: Members:
n_components: number of vector triplets, higher number gives better approximation. n_components: number of vector triplets, higher number gives better approximation.
matrix_reduction: how to transform components. If matrix_reduction=True result matrix_reduction: how to transform components. If matrix_reduction=True result
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by the matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
basis_matrix of shape (n_grids, n_components, n_features). If by the basis_matrix of shape (n_grids, n_components, n_features). If
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components) matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
is summed along the rows to get (n_grids, n_points, 1). is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed
to return to starting shape (n_grids, ..., 1).
""" """
# the type of grid_values argument needed to run evaluate_local() # the type of grid_values argument needed to run evaluate_local()
@ -219,8 +242,8 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
def evaluate_local( def evaluate_local(
self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues
) -> torch.Tensor: ) -> torch.Tensor:
def factor(i): def factor(axis):
axis = ["x", "y", "z"][i] i = {"x": 0, "y": 1, "z": 2}[axis]
index = points[..., i, None] index = points[..., i, None]
vector = getattr(grid_values, "vector_components_" + axis) vector = getattr(grid_values, "vector_components_" + axis)
return interpolate_line( return interpolate_line(
@ -231,17 +254,25 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
mode=self.mode, mode=self.mode,
) )
# (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
recorded_shape = points.shape
points = points.view(points.shape[0], -1, points.shape[-1])
# collect points from all the vectors and multipy them out # collect points from all the vectors and multipy them out
mult = factor(0) * factor(1) * factor(2) mult = factor("x") * factor("y") * factor("z")
# reduce the result from # reduce the result from
# (n_grids, n_points, n_components) to (n_grids, n_points, n_features) # (n_grids, n_points_total, n_components) to (n_grids, n_points_total, n_features)
if grid_values.basis_matrix is not None: if grid_values.basis_matrix is not None:
# (n_grids, n_points, n_features) = # (n_grids, n_points_total, n_features) =
# (n_grids, n_points, total_n_components) x (total_n_components, n_features) # (n_grids, n_points_total, total_n_components) @
return torch.bmm(mult, grid_values.basis_matrix) # (n_grids, total_n_components, n_features)
result = torch.bmm(mult, grid_values.basis_matrix)
return mult.sum(axis=-1, keepdim=True) else:
# (n_grids, n_points_total, 1) from (n_grids, n_points_total, n_features)
result = mult.sum(axis=-1, keepdim=True)
# (n_grids, ..., n_features)
return result.view(*recorded_shape[:-1], -1)
def get_shapes(self) -> Dict[str, Tuple[int, int]]: def get_shapes(self) -> Dict[str, Tuple[int, int]]:
if self.matrix_reduction is False and self.n_features != 1: if self.matrix_reduction is False and self.n_features != 1:
@ -308,10 +339,11 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
either n_components or distribution_of_components, you cannot specify both. either n_components or distribution_of_components, you cannot specify both.
matrix_reduction: how to transform components. If matrix_reduction=True result matrix_reduction: how to transform components. If matrix_reduction=True result
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
the basis_matrix of shape (n_grids, n_components, n_features). If by the basis_matrix of shape (n_grids, n_components, n_features). If
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components) matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
is summed along the rows to get (n_grids, n_points, 1). is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed
to return to starting shape (n_grids, ..., 1).
""" """
# the type of grid_values argument needed to run evaluate_local() # the type of grid_values argument needed to run evaluate_local()
@ -326,6 +358,10 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
def evaluate_local( def evaluate_local(
self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues
) -> torch.Tensor: ) -> torch.Tensor:
# (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
recorded_shape = points.shape
points = points.view(points.shape[0], -1, points.shape[-1])
# collect points from matrices and vectors and multiply them # collect points from matrices and vectors and multiply them
a = interpolate_plane( a = interpolate_plane(
points[..., :2], points[..., :2],
@ -375,9 +411,13 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
# (n_grids, n_points, n_features) = # (n_grids, n_points, n_features) =
# (n_grids, n_points, total_n_components) x # (n_grids, n_points, total_n_components) x
# (n_grids, total_n_components, n_features) # (n_grids, total_n_components, n_features)
return torch.bmm(feats, grid_values.basis_matrix) result = torch.bmm(feats, grid_values.basis_matrix)
else:
# pyre-ignore[28] # pyre-ignore[28]
return feats.sum(axis=-1, keepdim=True) # (n_grids, n_points, 1) from (n_grids, n_points, n_features)
result = feats.sum(axis=-1, keepdim=True)
# (n_grids, ..., n_features)
return result.view(*recorded_shape[:-1], -1)
def get_shapes(self) -> Dict[str, Tuple]: def get_shapes(self) -> Dict[str, Tuple]:
if self.matrix_reduction is False and self.n_features != 1: if self.matrix_reduction is False and self.n_features != 1:
@ -494,9 +534,9 @@ class VoxelGridModule(Configurable, torch.nn.Module):
Args: Args:
points (torch.Tensor): tensor of points that you want to query points (torch.Tensor): tensor of points that you want to query
of a form (n_points, 3) of a form (..., 3)
Returns: Returns:
torch.Tensor of shape (n_points, n_features) torch.Tensor of shape (..., n_features)
""" """
locator = VolumeLocator( locator = VolumeLocator(
batch_size=1, batch_size=1,

View File

@ -38,10 +38,17 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
return return
def get_random_normalized_points( def get_random_normalized_points(
self, n_grids, n_points, dimension=3 self, n_grids, n_points=None, dimension=3
) -> torch.Tensor: ) -> torch.Tensor:
middle_shape = torch.randint(1, 4, tuple(torch.randint(1, 5, size=(1,))))
# create random query points # create random query points
return torch.rand(n_grids, n_points, dimension) * 2 - 1 return (
torch.rand(
n_grids, *(middle_shape if n_points is None else [n_points]), dimension
)
* 2
- 1
)
def _test_query_with_constant_init_cp( def _test_query_with_constant_init_cp(
self, self,
@ -50,7 +57,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_components: int, n_components: int,
resolution: Tuple[int], resolution: Tuple[int],
value: float = 1, value: float = 1,
n_points: int = 1,
) -> None: ) -> None:
# set everything to 'value' and do query for elementsthe result should # set everything to 'value' and do query for elementsthe result should
# be of shape (n_grids, n_points, n_features) and be filled with n_components # be of shape (n_grids, n_points, n_features) and be filled with n_components
@ -65,12 +71,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
params = grid.values_type( params = grid.values_type(
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes} **{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
) )
points = self.get_random_normalized_points(n_grids)
assert torch.allclose( assert torch.allclose(
grid.evaluate_local( grid.evaluate_local(points, params),
self.get_random_normalized_points(n_grids, n_points), params torch.ones(n_grids, *points.shape[1:-1], n_features) * n_components * value,
), rtol=0.0001,
torch.ones(n_grids, n_points, n_features) * n_components * value,
) )
def _test_query_with_constant_init_vm( def _test_query_with_constant_init_vm(
@ -98,11 +103,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
expected_element = ( expected_element = (
n_components * value if distribution is None else sum(distribution) * value n_components * value if distribution is None else sum(distribution) * value
) )
points = self.get_random_normalized_points(n_grids)
assert torch.allclose( assert torch.allclose(
grid.evaluate_local( grid.evaluate_local(points, params),
self.get_random_normalized_points(n_grids, n_points), params torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element,
),
torch.ones(n_grids, n_points, n_features) * expected_element,
) )
def _test_query_with_constant_init_full( def _test_query_with_constant_init_full(
@ -121,21 +125,20 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
) )
expected_element = value expected_element = value
points = self.get_random_normalized_points(n_grids)
assert torch.allclose( assert torch.allclose(
grid.evaluate_local( grid.evaluate_local(points, params),
self.get_random_normalized_points(n_grids, n_points), params torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element,
),
torch.ones(n_grids, n_points, n_features) * expected_element,
) )
def test_query_with_constant_init(self): def test_query_with_constant_init(self):
with self.subTest("Full"): with self.subTest("Full"):
self._test_query_with_constant_init_full( self._test_query_with_constant_init_full(
n_grids=5, n_features=6, resolution=(3, 4, 5), n_points=3 n_grids=5, n_features=6, resolution=(3, 4, 5)
) )
with self.subTest("Full with 1 in dimensions"): with self.subTest("Full with 1 in dimensions"):
self._test_query_with_constant_init_full( self._test_query_with_constant_init_full(
n_grids=5, n_features=1, resolution=(33, 41, 1), n_points=4 n_grids=5, n_features=1, resolution=(33, 41, 1)
) )
with self.subTest("CP"): with self.subTest("CP"):
self._test_query_with_constant_init_cp( self._test_query_with_constant_init_cp(
@ -143,7 +146,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=6, n_features=6,
n_components=7, n_components=7,
resolution=(3, 4, 5), resolution=(3, 4, 5),
n_points=2,
) )
with self.subTest("CP with 1 in dimensions"): with self.subTest("CP with 1 in dimensions"):
self._test_query_with_constant_init_cp( self._test_query_with_constant_init_cp(
@ -151,7 +153,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=1, n_features=1,
n_components=3, n_components=3,
resolution=(3, 1, 1), resolution=(3, 1, 1),
n_points=4,
) )
with self.subTest("VM with symetric distribution"): with self.subTest("VM with symetric distribution"):
self._test_query_with_constant_init_vm( self._test_query_with_constant_init_vm(
@ -159,7 +160,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=9, n_features=9,
resolution=(2, 12, 2), resolution=(2, 12, 2),
n_components=12, n_components=12,
n_points=3,
) )
with self.subTest("VM with distribution"): with self.subTest("VM with distribution"):
self._test_query_with_constant_init_vm( self._test_query_with_constant_init_vm(
@ -167,7 +167,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=1, n_features=1,
resolution=(5, 9, 7), resolution=(5, 9, 7),
distribution=(33, 41, 1), distribution=(33, 41, 1),
n_points=7,
) )
def test_query_with_zero_init(self): def test_query_with_zero_init(self):
@ -177,7 +176,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=6, n_features=6,
n_components=7, n_components=7,
resolution=(3, 2, 5), resolution=(3, 2, 5),
n_points=3,
value=0, value=0,
) )
with self.subTest("Query testing with zero init VMFactorizedVoxelGrid"): with self.subTest("Query testing with zero init VMFactorizedVoxelGrid"):
@ -186,12 +184,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=9, n_features=9,
resolution=(2, 11, 3), resolution=(2, 11, 3),
n_components=3, n_components=3,
n_points=3,
value=0, value=0,
) )
with self.subTest("Query testing with zero init FullResolutionVoxelGrid"): with self.subTest("Query testing with zero init FullResolutionVoxelGrid"):
self._test_query_with_constant_init_full( self._test_query_with_constant_init_full(
n_grids=4, n_features=2, resolution=(3, 3, 5), n_points=3, value=0 n_grids=4, n_features=2, resolution=(3, 3, 5), value=0
) )
def setUp(self): def setUp(self):
@ -324,6 +321,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
padding_mode="zeros", padding_mode="zeros",
mode="bilinear", mode="bilinear",
), ),
rtol=0.0001,
) )
def test_floating_point_query(self): def test_floating_point_query(self):