mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
scaffold
Summary: Forward method is sped up using the scaffold, a low resolution voxel grid which is used to filter out the points in empty space. These points will be predicted as having 0 density and (0, 0, 0) color. The points which were not evaluated as empty space will be passed through the steps outlined above. Reviewed By: kjchalup Differential Revision: D39579671 fbshipit-source-id: 8eab8bb43ef77c2a73557efdb725e99a6c60d415
This commit is contained in:
parent
95a2acf763
commit
56d3465b09
@ -15,8 +15,9 @@ these classes.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Type
|
from typing import Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
@ -164,8 +165,9 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
|||||||
|
|
||||||
def change_resolution(
|
def change_resolution(
|
||||||
self,
|
self,
|
||||||
epoch: int,
|
|
||||||
grid_values: VoxelGridValuesBase,
|
grid_values: VoxelGridValuesBase,
|
||||||
|
epoch: int,
|
||||||
|
*,
|
||||||
mode: str = "linear",
|
mode: str = "linear",
|
||||||
align_corners: bool = True,
|
align_corners: bool = True,
|
||||||
antialias: bool = False,
|
antialias: bool = False,
|
||||||
@ -177,8 +179,8 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
|||||||
epoch: current training epoch, used to see if the grid needs regridding
|
epoch: current training epoch, used to see if the grid needs regridding
|
||||||
grid_values: instance of self.values_type which contains
|
grid_values: instance of self.values_type which contains
|
||||||
the voxel grid which will be interpolated to create the new grid
|
the voxel grid which will be interpolated to create the new grid
|
||||||
wanted_resolution: tuple of (x, y, z) resolutions which determine
|
epoch: epoch which is used to get the resolution of the new
|
||||||
new grid's resolution
|
`grid_values` using `self.resolution_changes`.
|
||||||
align_corners: as for torch.nn.functional.interpolate
|
align_corners: as for torch.nn.functional.interpolate
|
||||||
mode: as for torch.nn.functional.interpolate
|
mode: as for torch.nn.functional.interpolate
|
||||||
'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'.
|
'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'.
|
||||||
@ -225,11 +227,17 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
|||||||
# pyre-ignore[29]
|
# pyre-ignore[29]
|
||||||
return self.values_type(**params), True
|
return self.values_type(**params), True
|
||||||
|
|
||||||
def get_resolution_change_epochs(self) -> List[int]:
|
def get_resolution_change_epochs(self) -> Tuple[int, ...]:
|
||||||
"""
|
"""
|
||||||
Returns epochs at which this grid should change epochs.
|
Returns epochs at which this grid should change epochs.
|
||||||
"""
|
"""
|
||||||
return list(self.resolution_changes.keys())
|
return tuple(self.resolution_changes.keys())
|
||||||
|
|
||||||
|
def get_align_corners(self) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if voxel grid uses align_corners=True
|
||||||
|
"""
|
||||||
|
return self.align_corners
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -583,6 +591,8 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
A wrapper torch.nn.Module for the VoxelGrid classes, which
|
A wrapper torch.nn.Module for the VoxelGrid classes, which
|
||||||
contains parameters that are needed to train the VoxelGrid classes.
|
contains parameters that are needed to train the VoxelGrid classes.
|
||||||
|
Can contain the parameters for the voxel grid as pytorch parameters
|
||||||
|
or as registered buffers.
|
||||||
|
|
||||||
Members:
|
Members:
|
||||||
voxel_grid_class_type: The name of the class to use for voxel_grid,
|
voxel_grid_class_type: The name of the class to use for voxel_grid,
|
||||||
@ -596,17 +606,21 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
with mean=init_mean and std=init_std. Default 0.1
|
with mean=init_mean and std=init_std. Default 0.1
|
||||||
init_mean: Parameters are initialized using the gaussian distribution
|
init_mean: Parameters are initialized using the gaussian distribution
|
||||||
with mean=init_mean and std=init_std. Default 0.
|
with mean=init_mean and std=init_std. Default 0.
|
||||||
|
hold_voxel_grid_as_parameters: if True components of the underlying voxel grids
|
||||||
|
will be saved as parameters and therefore be trainable. Default True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
voxel_grid_class_type: str = "FullResolutionVoxelGrid"
|
voxel_grid_class_type: str = "FullResolutionVoxelGrid"
|
||||||
voxel_grid: VoxelGridBase
|
voxel_grid: VoxelGridBase
|
||||||
|
|
||||||
extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
|
extents: Tuple[float, float, float] = (2.0, 2.0, 2.0)
|
||||||
translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||||
|
|
||||||
init_std: float = 0.1
|
init_std: float = 0.1
|
||||||
init_mean: float = 0
|
init_mean: float = 0
|
||||||
|
|
||||||
|
hold_voxel_grid_as_parameters: bool = True
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
@ -619,7 +633,8 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
)
|
)
|
||||||
for name, shape in shapes.items()
|
for name, shape in shapes.items()
|
||||||
}
|
}
|
||||||
self.params = torch.nn.ParameterDict(params)
|
|
||||||
|
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**params))
|
||||||
self._register_load_state_dict_pre_hook(self._create_parameters_with_new_size)
|
self._register_load_state_dict_pre_hook(self._create_parameters_with_new_size)
|
||||||
|
|
||||||
def forward(self, points: torch.Tensor) -> torch.Tensor:
|
def forward(self, points: torch.Tensor) -> torch.Tensor:
|
||||||
@ -632,31 +647,29 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor of shape (..., n_features)
|
torch.Tensor of shape (..., n_features)
|
||||||
"""
|
"""
|
||||||
locator = VolumeLocator(
|
locator = self._get_volume_locator()
|
||||||
batch_size=1,
|
|
||||||
# The resolution of the voxel grid does not need to be known
|
|
||||||
# to the locator object. It is easiest to fix the resolution of the locator.
|
|
||||||
# In particular we fix it to (2,2,2) so that there is exactly one voxel of the
|
|
||||||
# desired size. The locator object uses (z, y, x) convention for the grid_size,
|
|
||||||
# and this module uses (x, y, z) convention so the order has to be reversed
|
|
||||||
# (irrelevant in this case since they are all equal).
|
|
||||||
# It is (2, 2, 2) because the VolumeLocator object behaves like
|
|
||||||
# align_corners=True, which means that the points are in the corners of
|
|
||||||
# the volume. So in the grid of (2, 2, 2) there is only one voxel.
|
|
||||||
grid_sizes=(2, 2, 2),
|
|
||||||
# The locator object uses (x, y, z) convention for the
|
|
||||||
# voxel size and translation.
|
|
||||||
voxel_size=tuple(self.extents),
|
|
||||||
volume_translation=tuple(self.translation),
|
|
||||||
# pyre-ignore[29]
|
|
||||||
device=next(val for val in self.params.values() if val is not None).device,
|
|
||||||
)
|
|
||||||
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
|
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
|
||||||
# torch.nn.modules.module.Module]` is not a function.
|
# torch.nn.modules.module.Module]` is not a function.
|
||||||
grid_values = self.voxel_grid.values_type(**self.params)
|
grid_values = self.voxel_grid.values_type(**self.params)
|
||||||
# voxel grids operate with extra n_grids dimension, which we fix to one
|
# voxel grids operate with extra n_grids dimension, which we fix to one
|
||||||
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
|
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
|
||||||
|
|
||||||
|
def set_voxel_grid_parameters(self, params: VoxelGridValuesBase) -> None:
|
||||||
|
"""
|
||||||
|
Sets the parameters of the underlying voxel grid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: parameters of type `self.voxel_grid.values_type` which will
|
||||||
|
replace current parameters
|
||||||
|
"""
|
||||||
|
if self.hold_voxel_grid_as_parameters:
|
||||||
|
# pyre-ignore [16]
|
||||||
|
self.params = torch.nn.ParameterDict(vars(params))
|
||||||
|
else:
|
||||||
|
# Torch Module to hold parameters since they can only be registered
|
||||||
|
# at object level.
|
||||||
|
self.params = _RegistratedBufferDict(vars(params))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_output_dim(args: DictConfig) -> int:
|
def get_output_dim(args: DictConfig) -> int:
|
||||||
"""
|
"""
|
||||||
@ -672,12 +685,12 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"]
|
args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def subscribe_to_epochs(self) -> Tuple[List[int], Callable[[int], bool]]:
|
def subscribe_to_epochs(self) -> Tuple[Tuple[int, ...], Callable[[int], bool]]:
|
||||||
"""
|
"""
|
||||||
Method which expresses interest in subscribing to optimization epoch updates.
|
Method which expresses interest in subscribing to optimization epoch updates.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list of epochs on which to call a callable and callable to be called on
|
tuple of epochs on which to call a callable and callable to be called on
|
||||||
particular epoch. The callable returns True if parameter change has
|
particular epoch. The callable returns True if parameter change has
|
||||||
happened else False and it must be supplied with one argument, epoch.
|
happened else False and it must be supplied with one argument, epoch.
|
||||||
"""
|
"""
|
||||||
@ -697,13 +710,12 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
# pyre-ignore[29]
|
# pyre-ignore[29]
|
||||||
grid_values = self.voxel_grid.values_type(**self.params)
|
grid_values = self.voxel_grid.values_type(**self.params)
|
||||||
grid_values, change = self.voxel_grid.change_resolution(epoch, grid_values)
|
grid_values, change = self.voxel_grid.change_resolution(
|
||||||
|
grid_values, epoch=epoch
|
||||||
|
)
|
||||||
if change:
|
if change:
|
||||||
# pyre-ignore[16]
|
self.set_voxel_grid_parameters(grid_values)
|
||||||
self.params = torch.nn.ParameterDict(
|
return change and self.hold_voxel_grid_as_parameters
|
||||||
{name: tensor for name, tensor in vars(grid_values).items()}
|
|
||||||
)
|
|
||||||
return change
|
|
||||||
|
|
||||||
def _create_parameters_with_new_size(
|
def _create_parameters_with_new_size(
|
||||||
self,
|
self,
|
||||||
@ -749,5 +761,113 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
|||||||
key = prefix + "params." + name
|
key = prefix + "params." + name
|
||||||
if key in state_dict:
|
if key in state_dict:
|
||||||
new_params[name] = torch.zeros_like(state_dict[key])
|
new_params[name] = torch.zeros_like(state_dict[key])
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[29]
|
||||||
self.params = torch.nn.ParameterDict(new_params)
|
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params))
|
||||||
|
|
||||||
|
def get_device(self) -> torch.device:
|
||||||
|
"""
|
||||||
|
Returns torch.device on which module parameters are located
|
||||||
|
"""
|
||||||
|
# pyre-ignore[29]
|
||||||
|
return next(val for val in self.params.values() if val is not None).device
|
||||||
|
|
||||||
|
def _get_volume_locator(self) -> VolumeLocator:
|
||||||
|
"""
|
||||||
|
Returns VolumeLocator calculated from `extents` and `translation` members.
|
||||||
|
"""
|
||||||
|
return VolumeLocator(
|
||||||
|
batch_size=1,
|
||||||
|
# The resolution of the voxel grid does not need to be known
|
||||||
|
# to the locator object. It is easiest to fix the resolution of the locator.
|
||||||
|
# In particular we fix it to (2,2,2) so that there is exactly one voxel of the
|
||||||
|
# desired size. The locator object uses (z, y, x) convention for the grid_size,
|
||||||
|
# and this module uses (x, y, z) convention so the order has to be reversed
|
||||||
|
# (irrelevant in this case since they are all equal).
|
||||||
|
# It is (2, 2, 2) because the VolumeLocator object behaves like
|
||||||
|
# align_corners=True, which means that the points are in the corners of
|
||||||
|
# the volume. So in the grid of (2, 2, 2) there is only one voxel.
|
||||||
|
grid_sizes=(2, 2, 2),
|
||||||
|
# The locator object uses (x, y, z) convention for the
|
||||||
|
# voxel size and translation.
|
||||||
|
voxel_size=tuple(self.extents),
|
||||||
|
# volume_translation is defined in `VolumeLocator` as a vector from the origin
|
||||||
|
# of local coordinate frame to origin of world coordinate frame, that is:
|
||||||
|
# x_world = x_local * extents/2 - translation.
|
||||||
|
# To get the reverse we need to negate it.
|
||||||
|
volume_translation=tuple(-t for t in self.translation),
|
||||||
|
device=self.get_device(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_grid_points(self, epoch: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Returns a grid of points that represent centers of voxels of the
|
||||||
|
underlying voxel grid in world coordinates at specific epoch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
epoch: underlying voxel grids change resolution depending on the
|
||||||
|
epoch, this argument is used to determine the resolution
|
||||||
|
of the voxel grid at that epoch.
|
||||||
|
Returns:
|
||||||
|
tensor of shape [xresolution, yresolution, zresolution, 3] where
|
||||||
|
xresolution, yresolution, zresolution are resolutions of the
|
||||||
|
underlying voxel grid
|
||||||
|
"""
|
||||||
|
xresolution, yresolution, zresolution = self.voxel_grid.get_resolution(epoch)
|
||||||
|
width, height, depth = self.extents
|
||||||
|
if not self.voxel_grid.get_align_corners():
|
||||||
|
width = (
|
||||||
|
width * (xresolution - 1) / xresolution if xresolution > 1 else width
|
||||||
|
)
|
||||||
|
height = (
|
||||||
|
height * (xresolution - 1) / xresolution if xresolution > 1 else height
|
||||||
|
)
|
||||||
|
depth = (
|
||||||
|
depth * (xresolution - 1) / xresolution if xresolution > 1 else depth
|
||||||
|
)
|
||||||
|
xs = torch.linspace(
|
||||||
|
-width / 2, width / 2, xresolution, device=self.get_device()
|
||||||
|
)
|
||||||
|
ys = torch.linspace(
|
||||||
|
-height / 2, height / 2, yresolution, device=self.get_device()
|
||||||
|
)
|
||||||
|
zs = torch.linspace(
|
||||||
|
-depth / 2, depth / 2, zresolution, device=self.get_device()
|
||||||
|
)
|
||||||
|
xmesh, ymesh, zmesh = torch.meshgrid(xs, ys, zs, indexing="ij")
|
||||||
|
return torch.stack((xmesh, ymesh, zmesh), dim=3)
|
||||||
|
|
||||||
|
|
||||||
|
class _RegistratedBufferDict(torch.nn.Module, Mapping):
|
||||||
|
"""
|
||||||
|
Mapping class and a torch.nn.Module that registeres its values
|
||||||
|
with `self.register_buffer`. Can be indexed like a regular Python
|
||||||
|
dictionary, but torch.Tensors it contains are properly registered, and will be visible
|
||||||
|
by all Module methods. Supports only `torch.Tensor` as value and str as key.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, init_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
init_dict: dictionary which will be used to populate the object
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._keys = set()
|
||||||
|
if init_dict is not None:
|
||||||
|
for k, v in init_dict.items():
|
||||||
|
self[k] = v
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
|
||||||
|
return iter({k: self[k] for k in self._keys})
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._keys)
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> torch.Tensor:
|
||||||
|
return getattr(self, key)
|
||||||
|
|
||||||
|
def __setitem__(self, key, value) -> None:
|
||||||
|
self._keys.add(key)
|
||||||
|
self.register_buffer(key, value)
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash(repr(self))
|
||||||
|
@ -653,7 +653,7 @@ class VolumeLocator:
|
|||||||
volume_translation: _Translation = (0.0, 0.0, 0.0),
|
volume_translation: _Translation = (0.0, 0.0, 0.0),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
**batch_size** : Batch size of the underlaying grids
|
**batch_size** : Batch size of the underlying grids
|
||||||
**grid_sizes** : Represents the resolutions of different grids in the batch. Can be
|
**grid_sizes** : Represents the resolutions of different grids in the batch. Can be
|
||||||
a) tuple of form (H, W, D)
|
a) tuple of form (H, W, D)
|
||||||
b) list/tuple of length batch_size of lists/tuples of form (H, W, D)
|
b) list/tuple of length batch_size of lists/tuples of form (H, W, D)
|
||||||
|
@ -35,9 +35,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
one by one sample and comparing with the batched implementation.
|
one by one sample and comparing with the batched implementation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_my_code(self):
|
|
||||||
return
|
|
||||||
|
|
||||||
def get_random_normalized_points(
|
def get_random_normalized_points(
|
||||||
self, n_grids, n_points=None, dimension=3
|
self, n_grids, n_points=None, dimension=3
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -293,6 +290,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
padding_mode="zeros",
|
padding_mode="zeros",
|
||||||
mode="bilinear",
|
mode="bilinear",
|
||||||
),
|
),
|
||||||
|
rtol=0.0001,
|
||||||
|
atol=0.0001,
|
||||||
)
|
)
|
||||||
with self.subTest("2D interpolation"):
|
with self.subTest("2D interpolation"):
|
||||||
points = self.get_random_normalized_points(
|
points = self.get_random_normalized_points(
|
||||||
@ -308,6 +307,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
padding_mode="zeros",
|
padding_mode="zeros",
|
||||||
mode="bilinear",
|
mode="bilinear",
|
||||||
),
|
),
|
||||||
|
rtol=0.0001,
|
||||||
|
atol=0.0001,
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.subTest("3D interpolation"):
|
with self.subTest("3D interpolation"):
|
||||||
@ -325,6 +326,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
mode="bilinear",
|
mode="bilinear",
|
||||||
),
|
),
|
||||||
rtol=0.0001,
|
rtol=0.0001,
|
||||||
|
atol=0.0001,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_floating_point_query(self):
|
def test_floating_point_query(self):
|
||||||
@ -378,7 +380,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
grid.evaluate_local(points, params),
|
grid.evaluate_local(points, params),
|
||||||
expected_result,
|
expected_result,
|
||||||
rtol=0.00001,
|
rtol=0.0001,
|
||||||
|
atol=0.0001,
|
||||||
), grid.evaluate_local(points, params)
|
), grid.evaluate_local(points, params)
|
||||||
with self.subTest("CP"):
|
with self.subTest("CP"):
|
||||||
grid = CPFactorizedVoxelGrid(
|
grid = CPFactorizedVoxelGrid(
|
||||||
@ -446,14 +449,16 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
grid.evaluate_local(points, params),
|
grid.evaluate_local(points, params),
|
||||||
expected_result_matrix,
|
expected_result_matrix,
|
||||||
rtol=0.00001,
|
rtol=0.0001,
|
||||||
|
atol=0.0001,
|
||||||
)
|
)
|
||||||
del params.basis_matrix
|
del params.basis_matrix
|
||||||
with self.subTest("CP with sum reduction"):
|
with self.subTest("CP with sum reduction"):
|
||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
grid.evaluate_local(points, params),
|
grid.evaluate_local(points, params),
|
||||||
expected_result_sum,
|
expected_result_sum,
|
||||||
rtol=0.00001,
|
rtol=0.0001,
|
||||||
|
atol=0.0001,
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.subTest("VM"):
|
with self.subTest("VM"):
|
||||||
@ -540,7 +545,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
grid.evaluate_local(points, params),
|
grid.evaluate_local(points, params),
|
||||||
expected_result_matrix,
|
expected_result_matrix,
|
||||||
rtol=0.00001,
|
rtol=0.0001,
|
||||||
|
atol=0.0001,
|
||||||
)
|
)
|
||||||
del params.basis_matrix
|
del params.basis_matrix
|
||||||
with self.subTest("VM with sum reduction"):
|
with self.subTest("VM with sum reduction"):
|
||||||
@ -548,6 +554,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
grid.evaluate_local(points, params),
|
grid.evaluate_local(points, params),
|
||||||
expected_result_sum,
|
expected_result_sum,
|
||||||
rtol=0.0001,
|
rtol=0.0001,
|
||||||
|
atol=0.0001,
|
||||||
), grid.evaluate_local(points, params)
|
), grid.evaluate_local(points, params)
|
||||||
|
|
||||||
def test_forward_with_small_init_std(self):
|
def test_forward_with_small_init_std(self):
|
||||||
@ -613,6 +620,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|||||||
grid(world_point)[0, 0],
|
grid(world_point)[0, 0],
|
||||||
grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0],
|
grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0],
|
||||||
rtol=0.0001,
|
rtol=0.0001,
|
||||||
|
atol=0.0001,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_resolution_change(self, n_times=10):
|
def test_resolution_change(self, n_times=10):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user