arbitrary shape input to voxel_grids

Summary: Add the ability to process arbitrary point shapes `[n_grids, ..., 3]` instead of only `[n_grids, n_points, 3]`.

Reviewed By: bottler

Differential Revision: D39574373

fbshipit-source-id: 0a9ecafe9ea58cd8f909644de43a1185ecf934f4
This commit is contained in:
Darijan Gudelj 2022-09-22 03:35:11 -07:00 committed by Facebook GitHub Bot
parent 6ae6ff9cf7
commit db3c12abfb
2 changed files with 93 additions and 55 deletions

View File

@ -19,6 +19,7 @@ from dataclasses import dataclass
from typing import ClassVar, Dict, Optional, Tuple, Type
import torch
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import (
Configurable,
registry,
@ -81,12 +82,12 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
Arguments:
points (torch.Tensor): tensor of points that you want to query
of a form (n_grids, n_points, 3)
of a form (n_grids, ..., 3)
grid_values: an object of type Class.values_type which has tensors as
members which have shapes derived from the get_shapes() method
locator: a VolumeLocator object
Returns:
torch.Tensor: shape (n_grids, n_points, n_features)
torch.Tensor: shape (n_grids, ..., n_features)
"""
points_local = locator.world_to_local_coords(points)
return self.evaluate_local(points_local, grid_values)
@ -100,11 +101,11 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
Arguments:
points (torch.Tensor): tensor of points that you want to query
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
of a form (n_grids, ..., 3), in a normalized form (coordinates are in [-1, 1])
grid_values: an object of type VMFactorizedVoxelGrid.values_type which has tensors
as members which have shapes derived from the get_shapes() method
Returns:
torch.Tensor: shape (n_grids, n_points, n_features)
torch.Tensor: shape (n_grids, ..., n_features)
"""
raise NotImplementedError()
@ -117,11 +118,28 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods
replace the shapes in the dictionary with tensors of those shapes and add the
first 'batch' dimension. If the required shape is (a, b) and you want to
have g grids than the tensor that replaces the shape should have the
have g grids then the tensor that replaces the shape should have the
shape (g, a, b).
"""
raise NotImplementedError()
@staticmethod
def get_output_dim(args: DictConfig) -> int:
"""
Given all the arguments of the grid's __init__, returns output's last dimension length.
In particular, if self.evaluate_world or self.evaluate_local
are called with `points` of shape (n_grids, ..., 3),
their output will be of shape
(n_grids, ..., get_output_dim(args)).
Args:
args: DictConfig which would be used to initialize the object
Returns:
output's last dimension length
"""
return args["n_features"]
@dataclass
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
@ -149,19 +167,23 @@ class FullResolutionVoxelGrid(VoxelGridBase):
Arguments:
points (torch.Tensor): tensor of points that you want to query
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
of a form (n_grids, ..., 3), in a normalized form (coordinates are in [-1, 1])
grid_values: an object of type values_type which has tensors as
members which have shapes derived from the get_shapes() method
Returns:
torch.Tensor: shape (n_grids, n_points, n_features)
torch.Tensor: shape (n_grids, ..., n_features)
"""
return interpolate_volume(
# flatten points to (n_grids, n_points_total, 3) from (n_grids, ..., 3)
recorded_shape = points.shape
points = points.view(points.shape[0], -1, points.shape[-1])
interpolated = interpolate_volume(
points,
grid_values.voxel_grid,
align_corners=self.align_corners,
padding_mode=self.padding,
mode=self.mode,
)
return interpolated.view(*recorded_shape[:-1], -1)
def get_shapes(self) -> Dict[str, Tuple]:
return {"voxel_grid": (self.n_features, *self.resolution)}
@ -202,10 +224,11 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
Members:
n_components: number of vector triplets, higher number gives better approximation.
matrix_reduction: how to transform components. If matrix_reduction=True result
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by the
basis_matrix of shape (n_grids, n_components, n_features). If
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
is summed along the rows to get (n_grids, n_points, 1).
matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
by the basis_matrix of shape (n_grids, n_components, n_features). If
matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed
to return to starting shape (n_grids, ..., 1).
"""
# the type of grid_values argument needed to run evaluate_local()
@ -219,8 +242,8 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
def evaluate_local(
self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues
) -> torch.Tensor:
def factor(i):
axis = ["x", "y", "z"][i]
def factor(axis):
i = {"x": 0, "y": 1, "z": 2}[axis]
index = points[..., i, None]
vector = getattr(grid_values, "vector_components_" + axis)
return interpolate_line(
@ -231,17 +254,25 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
mode=self.mode,
)
# flatten points to (n_grids, n_points_total, 3) from (n_grids, ..., 3)
recorded_shape = points.shape
points = points.view(points.shape[0], -1, points.shape[-1])
# collect points from all the vectors and multiply them out
mult = factor(0) * factor(1) * factor(2)
mult = factor("x") * factor("y") * factor("z")
# reduce the result from
# (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
# (n_grids, n_points_total, n_components) to (n_grids, n_points_total, n_features)
if grid_values.basis_matrix is not None:
# (n_grids, n_points, n_features) =
# (n_grids, n_points, total_n_components) x (total_n_components, n_features)
return torch.bmm(mult, grid_values.basis_matrix)
return mult.sum(axis=-1, keepdim=True)
# (n_grids, n_points_total, n_features) =
# (n_grids, n_points_total, total_n_components) @
# (n_grids, total_n_components, n_features)
result = torch.bmm(mult, grid_values.basis_matrix)
else:
# (n_grids, n_points_total, 1) from (n_grids, n_points_total, n_components)
result = mult.sum(axis=-1, keepdim=True)
# (n_grids, ..., n_features)
return result.view(*recorded_shape[:-1], -1)
def get_shapes(self) -> Dict[str, Tuple[int, int]]:
if self.matrix_reduction is False and self.n_features != 1:
@ -308,10 +339,11 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
either n_components or distribution_of_components, you cannot specify both.
matrix_reduction: how to transform components. If matrix_reduction=True result
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by
the basis_matrix of shape (n_grids, n_components, n_features). If
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
is summed along the rows to get (n_grids, n_points, 1).
matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
by the basis_matrix of shape (n_grids, n_components, n_features). If
matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed
to return to starting shape (n_grids, ..., 1).
"""
# the type of grid_values argument needed to run evaluate_local()
@ -326,6 +358,10 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
def evaluate_local(
self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues
) -> torch.Tensor:
# flatten points to (n_grids, n_points_total, 3) from (n_grids, ..., 3)
recorded_shape = points.shape
points = points.view(points.shape[0], -1, points.shape[-1])
# collect points from matrices and vectors and multiply them
a = interpolate_plane(
points[..., :2],
@ -375,9 +411,13 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
# (n_grids, n_points, n_features) =
# (n_grids, n_points, total_n_components) x
# (n_grids, total_n_components, n_features)
return torch.bmm(feats, grid_values.basis_matrix)
# pyre-ignore[28]
return feats.sum(axis=-1, keepdim=True)
result = torch.bmm(feats, grid_values.basis_matrix)
else:
# pyre-ignore[28]
# (n_grids, n_points, 1) from (n_grids, n_points, total_n_components)
result = feats.sum(axis=-1, keepdim=True)
# (n_grids, ..., n_features)
return result.view(*recorded_shape[:-1], -1)
def get_shapes(self) -> Dict[str, Tuple]:
if self.matrix_reduction is False and self.n_features != 1:
@ -494,9 +534,9 @@ class VoxelGridModule(Configurable, torch.nn.Module):
Args:
points (torch.Tensor): tensor of points that you want to query
of a form (n_points, 3)
of a form (..., 3)
Returns:
torch.Tensor of shape (n_points, n_features)
torch.Tensor of shape (..., n_features)
"""
locator = VolumeLocator(
batch_size=1,

View File

@ -38,10 +38,17 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
return
def get_random_normalized_points(
self, n_grids, n_points, dimension=3
self, n_grids, n_points=None, dimension=3
) -> torch.Tensor:
middle_shape = torch.randint(1, 4, tuple(torch.randint(1, 5, size=(1,))))
# create random query points
return torch.rand(n_grids, n_points, dimension) * 2 - 1
return (
torch.rand(
n_grids, *(middle_shape if n_points is None else [n_points]), dimension
)
* 2
- 1
)
def _test_query_with_constant_init_cp(
self,
@ -50,7 +57,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_components: int,
resolution: Tuple[int],
value: float = 1,
n_points: int = 1,
) -> None:
# set everything to 'value' and do query for elements; the result should
# be of shape (n_grids, n_points, n_features) and be filled with n_components
@ -65,12 +71,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
params = grid.values_type(
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
)
points = self.get_random_normalized_points(n_grids)
assert torch.allclose(
grid.evaluate_local(
self.get_random_normalized_points(n_grids, n_points), params
),
torch.ones(n_grids, n_points, n_features) * n_components * value,
grid.evaluate_local(points, params),
torch.ones(n_grids, *points.shape[1:-1], n_features) * n_components * value,
rtol=0.0001,
)
def _test_query_with_constant_init_vm(
@ -98,11 +103,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
expected_element = (
n_components * value if distribution is None else sum(distribution) * value
)
points = self.get_random_normalized_points(n_grids)
assert torch.allclose(
grid.evaluate_local(
self.get_random_normalized_points(n_grids, n_points), params
),
torch.ones(n_grids, n_points, n_features) * expected_element,
grid.evaluate_local(points, params),
torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element,
)
def _test_query_with_constant_init_full(
@ -121,21 +125,20 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
)
expected_element = value
points = self.get_random_normalized_points(n_grids)
assert torch.allclose(
grid.evaluate_local(
self.get_random_normalized_points(n_grids, n_points), params
),
torch.ones(n_grids, n_points, n_features) * expected_element,
grid.evaluate_local(points, params),
torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element,
)
def test_query_with_constant_init(self):
with self.subTest("Full"):
self._test_query_with_constant_init_full(
n_grids=5, n_features=6, resolution=(3, 4, 5), n_points=3
n_grids=5, n_features=6, resolution=(3, 4, 5)
)
with self.subTest("Full with 1 in dimensions"):
self._test_query_with_constant_init_full(
n_grids=5, n_features=1, resolution=(33, 41, 1), n_points=4
n_grids=5, n_features=1, resolution=(33, 41, 1)
)
with self.subTest("CP"):
self._test_query_with_constant_init_cp(
@ -143,7 +146,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=6,
n_components=7,
resolution=(3, 4, 5),
n_points=2,
)
with self.subTest("CP with 1 in dimensions"):
self._test_query_with_constant_init_cp(
@ -151,7 +153,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=1,
n_components=3,
resolution=(3, 1, 1),
n_points=4,
)
with self.subTest("VM with symetric distribution"):
self._test_query_with_constant_init_vm(
@ -159,7 +160,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=9,
resolution=(2, 12, 2),
n_components=12,
n_points=3,
)
with self.subTest("VM with distribution"):
self._test_query_with_constant_init_vm(
@ -167,7 +167,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=1,
resolution=(5, 9, 7),
distribution=(33, 41, 1),
n_points=7,
)
def test_query_with_zero_init(self):
@ -177,7 +176,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=6,
n_components=7,
resolution=(3, 2, 5),
n_points=3,
value=0,
)
with self.subTest("Query testing with zero init VMFactorizedVoxelGrid"):
@ -186,12 +184,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_features=9,
resolution=(2, 11, 3),
n_components=3,
n_points=3,
value=0,
)
with self.subTest("Query testing with zero init FullResolutionVoxelGrid"):
self._test_query_with_constant_init_full(
n_grids=4, n_features=2, resolution=(3, 3, 5), n_points=3, value=0
n_grids=4, n_features=2, resolution=(3, 3, 5), value=0
)
def setUp(self):
@ -324,6 +321,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
padding_mode="zeros",
mode="bilinear",
),
rtol=0.0001,
)
def test_floating_point_query(self):