mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
arbitrary shape input to voxel_grids
Summary: Add the ability to process arbitrary point shapes `[n_grids, ..., 3]` instead of only `[n_grids, n_points, 3]`. Reviewed By: bottler Differential Revision: D39574373 fbshipit-source-id: 0a9ecafe9ea58cd8f909644de43a1185ecf934f4
This commit is contained in:
parent
6ae6ff9cf7
commit
db3c12abfb
@ -19,6 +19,7 @@ from dataclasses import dataclass
|
||||
from typing import ClassVar, Dict, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
Configurable,
|
||||
registry,
|
||||
@ -81,12 +82,12 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
|
||||
Arguments:
|
||||
points (torch.Tensor): tensor of points that you want to query
|
||||
of a form (n_grids, n_points, 3)
|
||||
of a form (n_grids, ..., 3)
|
||||
grid_values: an object of type Class.values_type which has tensors as
|
||||
members which have shapes derived from the get_shapes() method
|
||||
locator: a VolumeLocator object
|
||||
Returns:
|
||||
torch.Tensor: shape (n_grids, n_points, n_features)
|
||||
torch.Tensor: shape (n_grids, ..., n_features)
|
||||
"""
|
||||
points_local = locator.world_to_local_coords(points)
|
||||
return self.evaluate_local(points_local, grid_values)
|
||||
@ -100,11 +101,11 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
|
||||
Arguments:
|
||||
points (torch.Tensor): tensor of points that you want to query
|
||||
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
|
||||
of a form (n_grids, ..., 3), in a normalized form (coordinates are in [-1, 1])
|
||||
grid_values: an object of type VMFactorizedVoxelGrid.values_type which has tensors
|
||||
as members which have shapes derived from the get_shapes() method
|
||||
Returns:
|
||||
torch.Tensor: shape (n_grids, n_points, n_features)
|
||||
torch.Tensor: shape (n_grids, ..., n_features)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -117,11 +118,28 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods
|
||||
replace the shapes in the dictionary with tensors of those shapes and add the
|
||||
first 'batch' dimension. If the required shape is (a, b) and you want to
|
||||
have g grids than the tensor that replaces the shape should have the
|
||||
have g grids then the tensor that replaces the shape should have the
|
||||
shape (g, a, b).
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def get_output_dim(args: DictConfig) -> int:
|
||||
"""
|
||||
Given all the arguments of the grid's __init__, returns output's last dimension length.
|
||||
|
||||
In particular, if self.evaluate_world or self.evaluate_local
|
||||
are called with `points` of shape (n_grids, n_points, 3),
|
||||
their output will be of shape
|
||||
(n_grids, n_points, grid.get_output_dim()).
|
||||
|
||||
Args:
|
||||
args: DictConfig which would be used to initialize the object
|
||||
Returns:
|
||||
output's last dimension length
|
||||
"""
|
||||
return args["n_features"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
|
||||
@ -149,19 +167,23 @@ class FullResolutionVoxelGrid(VoxelGridBase):
|
||||
|
||||
Arguments:
|
||||
points (torch.Tensor): tensor of points that you want to query
|
||||
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
|
||||
of a form (..., 3), in a normalized form (coordinates are in [-1, 1])
|
||||
grid_values: an object of type values_type which has tensors as
|
||||
members which have shapes derived from the get_shapes() method
|
||||
Returns:
|
||||
torch.Tensor: shape (n_grids, n_points, n_features)
|
||||
torch.Tensor: shape (n_grids, ..., n_features)
|
||||
"""
|
||||
return interpolate_volume(
|
||||
# (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
|
||||
recorded_shape = points.shape
|
||||
points = points.view(points.shape[0], -1, points.shape[-1])
|
||||
interpolated = interpolate_volume(
|
||||
points,
|
||||
grid_values.voxel_grid,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding,
|
||||
mode=self.mode,
|
||||
)
|
||||
return interpolated.view(*recorded_shape[:-1], -1)
|
||||
|
||||
def get_shapes(self) -> Dict[str, Tuple]:
|
||||
return {"voxel_grid": (self.n_features, *self.resolution)}
|
||||
@ -202,10 +224,11 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
|
||||
Members:
|
||||
n_components: number of vector triplets, higher number gives better approximation.
|
||||
matrix_reduction: how to transform components. If matrix_reduction=True result
|
||||
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by the
|
||||
basis_matrix of shape (n_grids, n_components, n_features). If
|
||||
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
|
||||
is summed along the rows to get (n_grids, n_points, 1).
|
||||
matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
|
||||
by the basis_matrix of shape (n_grids, n_components, n_features). If
|
||||
matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
|
||||
is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed
|
||||
to return to starting shape (n_grids, ..., 1).
|
||||
"""
|
||||
|
||||
# the type of grid_values argument needed to run evaluate_local()
|
||||
@ -219,8 +242,8 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
|
||||
def evaluate_local(
|
||||
self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues
|
||||
) -> torch.Tensor:
|
||||
def factor(i):
|
||||
axis = ["x", "y", "z"][i]
|
||||
def factor(axis):
|
||||
i = {"x": 0, "y": 1, "z": 2}[axis]
|
||||
index = points[..., i, None]
|
||||
vector = getattr(grid_values, "vector_components_" + axis)
|
||||
return interpolate_line(
|
||||
@ -231,17 +254,25 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
|
||||
mode=self.mode,
|
||||
)
|
||||
|
||||
# (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
|
||||
recorded_shape = points.shape
|
||||
points = points.view(points.shape[0], -1, points.shape[-1])
|
||||
|
||||
# collect points from all the vectors and multipy them out
|
||||
mult = factor(0) * factor(1) * factor(2)
|
||||
mult = factor("x") * factor("y") * factor("z")
|
||||
|
||||
# reduce the result from
|
||||
# (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
|
||||
# (n_grids, n_points_total, n_components) to (n_grids, n_points_total, n_features)
|
||||
if grid_values.basis_matrix is not None:
|
||||
# (n_grids, n_points, n_features) =
|
||||
# (n_grids, n_points, total_n_components) x (total_n_components, n_features)
|
||||
return torch.bmm(mult, grid_values.basis_matrix)
|
||||
|
||||
return mult.sum(axis=-1, keepdim=True)
|
||||
# (n_grids, n_points_total, n_features) =
|
||||
# (n_grids, n_points_total, total_n_components) @
|
||||
# (n_grids, total_n_components, n_features)
|
||||
result = torch.bmm(mult, grid_values.basis_matrix)
|
||||
else:
|
||||
# (n_grids, n_points_total, 1) from (n_grids, n_points_total, n_features)
|
||||
result = mult.sum(axis=-1, keepdim=True)
|
||||
# (n_grids, ..., n_features)
|
||||
return result.view(*recorded_shape[:-1], -1)
|
||||
|
||||
def get_shapes(self) -> Dict[str, Tuple[int, int]]:
|
||||
if self.matrix_reduction is False and self.n_features != 1:
|
||||
@ -308,10 +339,11 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
||||
coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
|
||||
either n_components or distribution_of_components, you cannot specify both.
|
||||
matrix_reduction: how to transform components. If matrix_reduction=True result
|
||||
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by
|
||||
the basis_matrix of shape (n_grids, n_components, n_features). If
|
||||
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
|
||||
is summed along the rows to get (n_grids, n_points, 1).
|
||||
matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
|
||||
by the basis_matrix of shape (n_grids, n_components, n_features). If
|
||||
matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
|
||||
is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed
|
||||
to return to starting shape (n_grids, ..., 1).
|
||||
"""
|
||||
|
||||
# the type of grid_values argument needed to run evaluate_local()
|
||||
@ -326,6 +358,10 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
||||
def evaluate_local(
|
||||
self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues
|
||||
) -> torch.Tensor:
|
||||
# (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
|
||||
recorded_shape = points.shape
|
||||
points = points.view(points.shape[0], -1, points.shape[-1])
|
||||
|
||||
# collect points from matrices and vectors and multiply them
|
||||
a = interpolate_plane(
|
||||
points[..., :2],
|
||||
@ -375,9 +411,13 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
||||
# (n_grids, n_points, n_features) =
|
||||
# (n_grids, n_points, total_n_components) x
|
||||
# (n_grids, total_n_components, n_features)
|
||||
return torch.bmm(feats, grid_values.basis_matrix)
|
||||
# pyre-ignore[28]
|
||||
return feats.sum(axis=-1, keepdim=True)
|
||||
result = torch.bmm(feats, grid_values.basis_matrix)
|
||||
else:
|
||||
# pyre-ignore[28]
|
||||
# (n_grids, n_points, 1) from (n_grids, n_points, n_features)
|
||||
result = feats.sum(axis=-1, keepdim=True)
|
||||
# (n_grids, ..., n_features)
|
||||
return result.view(*recorded_shape[:-1], -1)
|
||||
|
||||
def get_shapes(self) -> Dict[str, Tuple]:
|
||||
if self.matrix_reduction is False and self.n_features != 1:
|
||||
@ -494,9 +534,9 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
|
||||
Args:
|
||||
points (torch.Tensor): tensor of points that you want to query
|
||||
of a form (n_points, 3)
|
||||
of a form (..., 3)
|
||||
Returns:
|
||||
torch.Tensor of shape (n_points, n_features)
|
||||
torch.Tensor of shape (..., n_features)
|
||||
"""
|
||||
locator = VolumeLocator(
|
||||
batch_size=1,
|
||||
|
@ -38,10 +38,17 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
return
|
||||
|
||||
def get_random_normalized_points(
|
||||
self, n_grids, n_points, dimension=3
|
||||
self, n_grids, n_points=None, dimension=3
|
||||
) -> torch.Tensor:
|
||||
middle_shape = torch.randint(1, 4, tuple(torch.randint(1, 5, size=(1,))))
|
||||
# create random query points
|
||||
return torch.rand(n_grids, n_points, dimension) * 2 - 1
|
||||
return (
|
||||
torch.rand(
|
||||
n_grids, *(middle_shape if n_points is None else [n_points]), dimension
|
||||
)
|
||||
* 2
|
||||
- 1
|
||||
)
|
||||
|
||||
def _test_query_with_constant_init_cp(
|
||||
self,
|
||||
@ -50,7 +57,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_components: int,
|
||||
resolution: Tuple[int],
|
||||
value: float = 1,
|
||||
n_points: int = 1,
|
||||
) -> None:
|
||||
# set everything to 'value' and do query for elementsthe result should
|
||||
# be of shape (n_grids, n_points, n_features) and be filled with n_components
|
||||
@ -65,12 +71,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
params = grid.values_type(
|
||||
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
|
||||
)
|
||||
|
||||
points = self.get_random_normalized_points(n_grids)
|
||||
assert torch.allclose(
|
||||
grid.evaluate_local(
|
||||
self.get_random_normalized_points(n_grids, n_points), params
|
||||
),
|
||||
torch.ones(n_grids, n_points, n_features) * n_components * value,
|
||||
grid.evaluate_local(points, params),
|
||||
torch.ones(n_grids, *points.shape[1:-1], n_features) * n_components * value,
|
||||
rtol=0.0001,
|
||||
)
|
||||
|
||||
def _test_query_with_constant_init_vm(
|
||||
@ -98,11 +103,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
expected_element = (
|
||||
n_components * value if distribution is None else sum(distribution) * value
|
||||
)
|
||||
points = self.get_random_normalized_points(n_grids)
|
||||
assert torch.allclose(
|
||||
grid.evaluate_local(
|
||||
self.get_random_normalized_points(n_grids, n_points), params
|
||||
),
|
||||
torch.ones(n_grids, n_points, n_features) * expected_element,
|
||||
grid.evaluate_local(points, params),
|
||||
torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element,
|
||||
)
|
||||
|
||||
def _test_query_with_constant_init_full(
|
||||
@ -121,21 +125,20 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
expected_element = value
|
||||
points = self.get_random_normalized_points(n_grids)
|
||||
assert torch.allclose(
|
||||
grid.evaluate_local(
|
||||
self.get_random_normalized_points(n_grids, n_points), params
|
||||
),
|
||||
torch.ones(n_grids, n_points, n_features) * expected_element,
|
||||
grid.evaluate_local(points, params),
|
||||
torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element,
|
||||
)
|
||||
|
||||
def test_query_with_constant_init(self):
|
||||
with self.subTest("Full"):
|
||||
self._test_query_with_constant_init_full(
|
||||
n_grids=5, n_features=6, resolution=(3, 4, 5), n_points=3
|
||||
n_grids=5, n_features=6, resolution=(3, 4, 5)
|
||||
)
|
||||
with self.subTest("Full with 1 in dimensions"):
|
||||
self._test_query_with_constant_init_full(
|
||||
n_grids=5, n_features=1, resolution=(33, 41, 1), n_points=4
|
||||
n_grids=5, n_features=1, resolution=(33, 41, 1)
|
||||
)
|
||||
with self.subTest("CP"):
|
||||
self._test_query_with_constant_init_cp(
|
||||
@ -143,7 +146,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_features=6,
|
||||
n_components=7,
|
||||
resolution=(3, 4, 5),
|
||||
n_points=2,
|
||||
)
|
||||
with self.subTest("CP with 1 in dimensions"):
|
||||
self._test_query_with_constant_init_cp(
|
||||
@ -151,7 +153,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_features=1,
|
||||
n_components=3,
|
||||
resolution=(3, 1, 1),
|
||||
n_points=4,
|
||||
)
|
||||
with self.subTest("VM with symetric distribution"):
|
||||
self._test_query_with_constant_init_vm(
|
||||
@ -159,7 +160,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_features=9,
|
||||
resolution=(2, 12, 2),
|
||||
n_components=12,
|
||||
n_points=3,
|
||||
)
|
||||
with self.subTest("VM with distribution"):
|
||||
self._test_query_with_constant_init_vm(
|
||||
@ -167,7 +167,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_features=1,
|
||||
resolution=(5, 9, 7),
|
||||
distribution=(33, 41, 1),
|
||||
n_points=7,
|
||||
)
|
||||
|
||||
def test_query_with_zero_init(self):
|
||||
@ -177,7 +176,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_features=6,
|
||||
n_components=7,
|
||||
resolution=(3, 2, 5),
|
||||
n_points=3,
|
||||
value=0,
|
||||
)
|
||||
with self.subTest("Query testing with zero init VMFactorizedVoxelGrid"):
|
||||
@ -186,12 +184,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_features=9,
|
||||
resolution=(2, 11, 3),
|
||||
n_components=3,
|
||||
n_points=3,
|
||||
value=0,
|
||||
)
|
||||
with self.subTest("Query testing with zero init FullResolutionVoxelGrid"):
|
||||
self._test_query_with_constant_init_full(
|
||||
n_grids=4, n_features=2, resolution=(3, 3, 5), n_points=3, value=0
|
||||
n_grids=4, n_features=2, resolution=(3, 3, 5), value=0
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
@ -324,6 +321,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
padding_mode="zeros",
|
||||
mode="bilinear",
|
||||
),
|
||||
rtol=0.0001,
|
||||
)
|
||||
|
||||
def test_floating_point_query(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user