mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
arbitrary shape input to voxel_grids
Summary: Add the ability to process arbitrary point shapes `[n_grids, ..., 3]` instead of only `[n_grids, n_points, 3]`. Reviewed By: bottler Differential Revision: D39574373 fbshipit-source-id: 0a9ecafe9ea58cd8f909644de43a1185ecf934f4
This commit is contained in:
parent
6ae6ff9cf7
commit
db3c12abfb
@ -19,6 +19,7 @@ from dataclasses import dataclass
|
|||||||
from typing import ClassVar, Dict, Optional, Tuple, Type
|
from typing import ClassVar, Dict, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from omegaconf import DictConfig
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
Configurable,
|
Configurable,
|
||||||
registry,
|
registry,
|
||||||
@ -81,12 +82,12 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
points (torch.Tensor): tensor of points that you want to query
|
points (torch.Tensor): tensor of points that you want to query
|
||||||
of a form (n_grids, n_points, 3)
|
of a form (n_grids, ..., 3)
|
||||||
grid_values: an object of type Class.values_type which has tensors as
|
grid_values: an object of type Class.values_type which has tensors as
|
||||||
members which have shapes derived from the get_shapes() method
|
members which have shapes derived from the get_shapes() method
|
||||||
locator: a VolumeLocator object
|
locator: a VolumeLocator object
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: shape (n_grids, n_points, n_features)
|
torch.Tensor: shape (n_grids, ..., n_features)
|
||||||
"""
|
"""
|
||||||
points_local = locator.world_to_local_coords(points)
|
points_local = locator.world_to_local_coords(points)
|
||||||
return self.evaluate_local(points_local, grid_values)
|
return self.evaluate_local(points_local, grid_values)
|
||||||
@ -100,11 +101,11 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
points (torch.Tensor): tensor of points that you want to query
|
points (torch.Tensor): tensor of points that you want to query
|
||||||
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
|
of a form (n_grids, ..., 3), in a normalized form (coordinates are in [-1, 1])
|
||||||
grid_values: an object of type VMFactorizedVoxelGrid.values_type which has tensors
|
grid_values: an object of type VMFactorizedVoxelGrid.values_type which has tensors
|
||||||
as members which have shapes derived from the get_shapes() method
|
as members which have shapes derived from the get_shapes() method
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: shape (n_grids, n_points, n_features)
|
torch.Tensor: shape (n_grids, ..., n_features)
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -117,11 +118,28 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
|||||||
a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods
|
a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods
|
||||||
replace the shapes in the dictionary with tensors of those shapes and add the
|
replace the shapes in the dictionary with tensors of those shapes and add the
|
||||||
first 'batch' dimension. If the required shape is (a, b) and you want to
|
first 'batch' dimension. If the required shape is (a, b) and you want to
|
||||||
have g grids than the tensor that replaces the shape should have the
|
have g grids then the tensor that replaces the shape should have the
|
||||||
shape (g, a, b).
|
shape (g, a, b).
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_output_dim(args: DictConfig) -> int:
|
||||||
|
"""
|
||||||
|
Given all the arguments of the grid's __init__, returns output's last dimension length.
|
||||||
|
|
||||||
|
In particular, if self.evaluate_world or self.evaluate_local
|
||||||
|
are called with `points` of shape (n_grids, n_points, 3),
|
||||||
|
their output will be of shape
|
||||||
|
(n_grids, n_points, grid.get_output_dim()).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args: DictConfig which would be used to initialize the object
|
||||||
|
Returns:
|
||||||
|
output's last dimension length
|
||||||
|
"""
|
||||||
|
return args["n_features"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
|
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
|
||||||
@ -149,19 +167,23 @@ class FullResolutionVoxelGrid(VoxelGridBase):
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
points (torch.Tensor): tensor of points that you want to query
|
points (torch.Tensor): tensor of points that you want to query
|
||||||
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
|
of a form (..., 3), in a normalized form (coordinates are in [-1, 1])
|
||||||
grid_values: an object of type values_type which has tensors as
|
grid_values: an object of type values_type which has tensors as
|
||||||
members which have shapes derived from the get_shapes() method
|
members which have shapes derived from the get_shapes() method
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: shape (n_grids, n_points, n_features)
|
torch.Tensor: shape (n_grids, ..., n_features)
|
||||||
"""
|
"""
|
||||||
return interpolate_volume(
|
# (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
|
||||||
|
recorded_shape = points.shape
|
||||||
|
points = points.view(points.shape[0], -1, points.shape[-1])
|
||||||
|
interpolated = interpolate_volume(
|
||||||
points,
|
points,
|
||||||
grid_values.voxel_grid,
|
grid_values.voxel_grid,
|
||||||
align_corners=self.align_corners,
|
align_corners=self.align_corners,
|
||||||
padding_mode=self.padding,
|
padding_mode=self.padding,
|
||||||
mode=self.mode,
|
mode=self.mode,
|
||||||
)
|
)
|
||||||
|
return interpolated.view(*recorded_shape[:-1], -1)
|
||||||
|
|
||||||
def get_shapes(self) -> Dict[str, Tuple]:
|
def get_shapes(self) -> Dict[str, Tuple]:
|
||||||
return {"voxel_grid": (self.n_features, *self.resolution)}
|
return {"voxel_grid": (self.n_features, *self.resolution)}
|
||||||
@ -202,10 +224,11 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
|
|||||||
Members:
|
Members:
|
||||||
n_components: number of vector triplets, higher number gives better approximation.
|
n_components: number of vector triplets, higher number gives better approximation.
|
||||||
matrix_reduction: how to transform components. If matrix_reduction=True result
|
matrix_reduction: how to transform components. If matrix_reduction=True result
|
||||||
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by the
|
matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
|
||||||
basis_matrix of shape (n_grids, n_components, n_features). If
|
by the basis_matrix of shape (n_grids, n_components, n_features). If
|
||||||
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
|
matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
|
||||||
is summed along the rows to get (n_grids, n_points, 1).
|
is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed
|
||||||
|
to return to starting shape (n_grids, ..., 1).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# the type of grid_values argument needed to run evaluate_local()
|
# the type of grid_values argument needed to run evaluate_local()
|
||||||
@ -219,8 +242,8 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
|
|||||||
def evaluate_local(
|
def evaluate_local(
|
||||||
self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues
|
self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
def factor(i):
|
def factor(axis):
|
||||||
axis = ["x", "y", "z"][i]
|
i = {"x": 0, "y": 1, "z": 2}[axis]
|
||||||
index = points[..., i, None]
|
index = points[..., i, None]
|
||||||
vector = getattr(grid_values, "vector_components_" + axis)
|
vector = getattr(grid_values, "vector_components_" + axis)
|
||||||
return interpolate_line(
|
return interpolate_line(
|
||||||
@ -231,17 +254,25 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
|
|||||||
mode=self.mode,
|
mode=self.mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
|
||||||
|
recorded_shape = points.shape
|
||||||
|
points = points.view(points.shape[0], -1, points.shape[-1])
|
||||||
|
|
||||||
# collect points from all the vectors and multipy them out
|
# collect points from all the vectors and multipy them out
|
||||||
mult = factor(0) * factor(1) * factor(2)
|
mult = factor("x") * factor("y") * factor("z")
|
||||||
|
|
||||||
# reduce the result from
|
# reduce the result from
|
||||||
# (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
|
# (n_grids, n_points_total, n_components) to (n_grids, n_points_total, n_features)
|
||||||
if grid_values.basis_matrix is not None:
|
if grid_values.basis_matrix is not None:
|
||||||
# (n_grids, n_points, n_features) =
|
# (n_grids, n_points_total, n_features) =
|
||||||
# (n_grids, n_points, total_n_components) x (total_n_components, n_features)
|
# (n_grids, n_points_total, total_n_components) @
|
||||||
return torch.bmm(mult, grid_values.basis_matrix)
|
# (n_grids, total_n_components, n_features)
|
||||||
|
result = torch.bmm(mult, grid_values.basis_matrix)
|
||||||
return mult.sum(axis=-1, keepdim=True)
|
else:
|
||||||
|
# (n_grids, n_points_total, 1) from (n_grids, n_points_total, n_features)
|
||||||
|
result = mult.sum(axis=-1, keepdim=True)
|
||||||
|
# (n_grids, ..., n_features)
|
||||||
|
return result.view(*recorded_shape[:-1], -1)
|
||||||
|
|
||||||
def get_shapes(self) -> Dict[str, Tuple[int, int]]:
|
def get_shapes(self) -> Dict[str, Tuple[int, int]]:
|
||||||
if self.matrix_reduction is False and self.n_features != 1:
|
if self.matrix_reduction is False and self.n_features != 1:
|
||||||
@ -308,10 +339,11 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
|||||||
coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
|
coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
|
||||||
either n_components or distribution_of_components, you cannot specify both.
|
either n_components or distribution_of_components, you cannot specify both.
|
||||||
matrix_reduction: how to transform components. If matrix_reduction=True result
|
matrix_reduction: how to transform components. If matrix_reduction=True result
|
||||||
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by
|
matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
|
||||||
the basis_matrix of shape (n_grids, n_components, n_features). If
|
by the basis_matrix of shape (n_grids, n_components, n_features). If
|
||||||
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
|
matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
|
||||||
is summed along the rows to get (n_grids, n_points, 1).
|
is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed
|
||||||
|
to return to starting shape (n_grids, ..., 1).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# the type of grid_values argument needed to run evaluate_local()
|
# the type of grid_values argument needed to run evaluate_local()
|
||||||
@ -326,6 +358,10 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
|||||||
def evaluate_local(
|
def evaluate_local(
|
||||||
self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues
|
self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
# (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
|
||||||
|
recorded_shape = points.shape
|
||||||
|
points = points.view(points.shape[0], -1, points.shape[-1])
|
||||||
|
|
||||||
# collect points from matrices and vectors and multiply them
|
# collect points from matrices and vectors and multiply them
|
||||||
a = interpolate_plane(
|
a = interpolate_plane(
|
||||||
points[..., :2],
|
points[..., :2],
|
||||||
@ -375,9 +411,13 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
|||||||
# (n_grids, n_points, n_features) =
|
# (n_grids, n_points, n_features) =
|
||||||
# (n_grids, n_points, total_n_components) x
|
# (n_grids, n_points, total_n_components) x
|
||||||
# (n_grids, total_n_components, n_features)
|
# (n_grids, total_n_components, n_features)
|
||||||
return torch.bmm(feats, grid_values.basis_matrix)
|
result = torch.bmm(feats, grid_values.basis_matrix)
|
||||||
# pyre-ignore[28]
|
else:
|
||||||
return feats.sum(axis=-1, keepdim=True)
|
# pyre-ignore[28]
|
||||||
|
# (n_grids, n_points, 1) from (n_grids, n_points, n_features)
|
||||||
|
result = feats.sum(axis=-1, keepdim=True)
|
||||||
|
# (n_grids, ..., n_features)
|
||||||
|
return result.view(*recorded_shape[:-1], -1)
|
||||||
|
|
||||||
def get_shapes(self) -> Dict[str, Tuple]:
|
def get_shapes(self) -> Dict[str, Tuple]:
|
||||||
if self.matrix_reduction is False and self.n_features != 1:
|
if self.matrix_reduction is False and self.n_features != 1:
|
||||||
@ -494,9 +534,9 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
points (torch.Tensor): tensor of points that you want to query
|
points (torch.Tensor): tensor of points that you want to query
|
||||||
of a form (n_points, 3)
|
of a form (..., 3)
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor of shape (n_points, n_features)
|
torch.Tensor of shape (..., n_features)
|
||||||
"""
|
"""
|
||||||
locator = VolumeLocator(
|
locator = VolumeLocator(
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
|
@ -38,10 +38,17 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
return
|
return
|
||||||
|
|
||||||
def get_random_normalized_points(
|
def get_random_normalized_points(
|
||||||
self, n_grids, n_points, dimension=3
|
self, n_grids, n_points=None, dimension=3
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
middle_shape = torch.randint(1, 4, tuple(torch.randint(1, 5, size=(1,))))
|
||||||
# create random query points
|
# create random query points
|
||||||
return torch.rand(n_grids, n_points, dimension) * 2 - 1
|
return (
|
||||||
|
torch.rand(
|
||||||
|
n_grids, *(middle_shape if n_points is None else [n_points]), dimension
|
||||||
|
)
|
||||||
|
* 2
|
||||||
|
- 1
|
||||||
|
)
|
||||||
|
|
||||||
def _test_query_with_constant_init_cp(
|
def _test_query_with_constant_init_cp(
|
||||||
self,
|
self,
|
||||||
@ -50,7 +57,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
n_components: int,
|
n_components: int,
|
||||||
resolution: Tuple[int],
|
resolution: Tuple[int],
|
||||||
value: float = 1,
|
value: float = 1,
|
||||||
n_points: int = 1,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
# set everything to 'value' and do query for elementsthe result should
|
# set everything to 'value' and do query for elementsthe result should
|
||||||
# be of shape (n_grids, n_points, n_features) and be filled with n_components
|
# be of shape (n_grids, n_points, n_features) and be filled with n_components
|
||||||
@ -65,12 +71,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
params = grid.values_type(
|
params = grid.values_type(
|
||||||
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
|
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
|
||||||
)
|
)
|
||||||
|
points = self.get_random_normalized_points(n_grids)
|
||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
grid.evaluate_local(
|
grid.evaluate_local(points, params),
|
||||||
self.get_random_normalized_points(n_grids, n_points), params
|
torch.ones(n_grids, *points.shape[1:-1], n_features) * n_components * value,
|
||||||
),
|
rtol=0.0001,
|
||||||
torch.ones(n_grids, n_points, n_features) * n_components * value,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _test_query_with_constant_init_vm(
|
def _test_query_with_constant_init_vm(
|
||||||
@ -98,11 +103,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
expected_element = (
|
expected_element = (
|
||||||
n_components * value if distribution is None else sum(distribution) * value
|
n_components * value if distribution is None else sum(distribution) * value
|
||||||
)
|
)
|
||||||
|
points = self.get_random_normalized_points(n_grids)
|
||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
grid.evaluate_local(
|
grid.evaluate_local(points, params),
|
||||||
self.get_random_normalized_points(n_grids, n_points), params
|
torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element,
|
||||||
),
|
|
||||||
torch.ones(n_grids, n_points, n_features) * expected_element,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _test_query_with_constant_init_full(
|
def _test_query_with_constant_init_full(
|
||||||
@ -121,21 +125,20 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
expected_element = value
|
expected_element = value
|
||||||
|
points = self.get_random_normalized_points(n_grids)
|
||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
grid.evaluate_local(
|
grid.evaluate_local(points, params),
|
||||||
self.get_random_normalized_points(n_grids, n_points), params
|
torch.ones(n_grids, *points.shape[1:-1], n_features) * expected_element,
|
||||||
),
|
|
||||||
torch.ones(n_grids, n_points, n_features) * expected_element,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_query_with_constant_init(self):
|
def test_query_with_constant_init(self):
|
||||||
with self.subTest("Full"):
|
with self.subTest("Full"):
|
||||||
self._test_query_with_constant_init_full(
|
self._test_query_with_constant_init_full(
|
||||||
n_grids=5, n_features=6, resolution=(3, 4, 5), n_points=3
|
n_grids=5, n_features=6, resolution=(3, 4, 5)
|
||||||
)
|
)
|
||||||
with self.subTest("Full with 1 in dimensions"):
|
with self.subTest("Full with 1 in dimensions"):
|
||||||
self._test_query_with_constant_init_full(
|
self._test_query_with_constant_init_full(
|
||||||
n_grids=5, n_features=1, resolution=(33, 41, 1), n_points=4
|
n_grids=5, n_features=1, resolution=(33, 41, 1)
|
||||||
)
|
)
|
||||||
with self.subTest("CP"):
|
with self.subTest("CP"):
|
||||||
self._test_query_with_constant_init_cp(
|
self._test_query_with_constant_init_cp(
|
||||||
@ -143,7 +146,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
n_features=6,
|
n_features=6,
|
||||||
n_components=7,
|
n_components=7,
|
||||||
resolution=(3, 4, 5),
|
resolution=(3, 4, 5),
|
||||||
n_points=2,
|
|
||||||
)
|
)
|
||||||
with self.subTest("CP with 1 in dimensions"):
|
with self.subTest("CP with 1 in dimensions"):
|
||||||
self._test_query_with_constant_init_cp(
|
self._test_query_with_constant_init_cp(
|
||||||
@ -151,7 +153,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
n_features=1,
|
n_features=1,
|
||||||
n_components=3,
|
n_components=3,
|
||||||
resolution=(3, 1, 1),
|
resolution=(3, 1, 1),
|
||||||
n_points=4,
|
|
||||||
)
|
)
|
||||||
with self.subTest("VM with symetric distribution"):
|
with self.subTest("VM with symetric distribution"):
|
||||||
self._test_query_with_constant_init_vm(
|
self._test_query_with_constant_init_vm(
|
||||||
@ -159,7 +160,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
n_features=9,
|
n_features=9,
|
||||||
resolution=(2, 12, 2),
|
resolution=(2, 12, 2),
|
||||||
n_components=12,
|
n_components=12,
|
||||||
n_points=3,
|
|
||||||
)
|
)
|
||||||
with self.subTest("VM with distribution"):
|
with self.subTest("VM with distribution"):
|
||||||
self._test_query_with_constant_init_vm(
|
self._test_query_with_constant_init_vm(
|
||||||
@ -167,7 +167,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
n_features=1,
|
n_features=1,
|
||||||
resolution=(5, 9, 7),
|
resolution=(5, 9, 7),
|
||||||
distribution=(33, 41, 1),
|
distribution=(33, 41, 1),
|
||||||
n_points=7,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_query_with_zero_init(self):
|
def test_query_with_zero_init(self):
|
||||||
@ -177,7 +176,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
n_features=6,
|
n_features=6,
|
||||||
n_components=7,
|
n_components=7,
|
||||||
resolution=(3, 2, 5),
|
resolution=(3, 2, 5),
|
||||||
n_points=3,
|
|
||||||
value=0,
|
value=0,
|
||||||
)
|
)
|
||||||
with self.subTest("Query testing with zero init VMFactorizedVoxelGrid"):
|
with self.subTest("Query testing with zero init VMFactorizedVoxelGrid"):
|
||||||
@ -186,12 +184,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
n_features=9,
|
n_features=9,
|
||||||
resolution=(2, 11, 3),
|
resolution=(2, 11, 3),
|
||||||
n_components=3,
|
n_components=3,
|
||||||
n_points=3,
|
|
||||||
value=0,
|
value=0,
|
||||||
)
|
)
|
||||||
with self.subTest("Query testing with zero init FullResolutionVoxelGrid"):
|
with self.subTest("Query testing with zero init FullResolutionVoxelGrid"):
|
||||||
self._test_query_with_constant_init_full(
|
self._test_query_with_constant_init_full(
|
||||||
n_grids=4, n_features=2, resolution=(3, 3, 5), n_points=3, value=0
|
n_grids=4, n_features=2, resolution=(3, 3, 5), value=0
|
||||||
)
|
)
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -324,6 +321,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
padding_mode="zeros",
|
padding_mode="zeros",
|
||||||
mode="bilinear",
|
mode="bilinear",
|
||||||
),
|
),
|
||||||
|
rtol=0.0001,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_floating_point_query(self):
|
def test_floating_point_query(self):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user