Summary: The forward method is sped up using the scaffold, a low-resolution voxel grid that is used to filter out points in empty space. These points are predicted as having 0 density and (0, 0, 0) color. Points that were not evaluated as empty space are passed through the steps outlined above.

Reviewed By: kjchalup

Differential Revision: D39579671

fbshipit-source-id: 8eab8bb43ef77c2a73557efdb725e99a6c60d415
This commit is contained in:
Darijan Gudelj 2022-10-10 11:01:00 -07:00 committed by Facebook GitHub Bot
parent 95a2acf763
commit 56d3465b09
3 changed files with 173 additions and 45 deletions

View File

@ -15,8 +15,9 @@ these classes.
"""
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Type
from typing import Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
import torch
from omegaconf import DictConfig
@ -164,8 +165,9 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
def change_resolution(
self,
epoch: int,
grid_values: VoxelGridValuesBase,
epoch: int,
*,
mode: str = "linear",
align_corners: bool = True,
antialias: bool = False,
@ -177,8 +179,8 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
epoch: current training epoch, used to see if the grid needs regridding
grid_values: instance of self.values_type which contains
the voxel grid which will be interpolated to create the new grid
wanted_resolution: tuple of (x, y, z) resolutions which determine
new grid's resolution
epoch: epoch which is used to get the resolution of the new
`grid_values` using `self.resolution_changes`.
align_corners: as for torch.nn.functional.interpolate
mode: as for torch.nn.functional.interpolate
'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'.
@ -225,11 +227,17 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
# pyre-ignore[29]
return self.values_type(**params), True
def get_resolution_change_epochs(self) -> List[int]:
def get_resolution_change_epochs(self) -> Tuple[int, ...]:
"""
Returns the epochs at which this grid should change resolution,
i.e. the keys of `self.resolution_changes`.
"""
return list(self.resolution_changes.keys())
return tuple(self.resolution_changes.keys())
def get_align_corners(self) -> bool:
"""
Returns the `align_corners` setting of this voxel grid, i.e. True if
the voxel grid uses align_corners=True.
"""
return self.align_corners
@dataclass
@ -583,6 +591,8 @@ class VoxelGridModule(Configurable, torch.nn.Module):
"""
A wrapper torch.nn.Module for the VoxelGrid classes, which
contains parameters that are needed to train the VoxelGrid classes.
Can contain the parameters for the voxel grid as pytorch parameters
or as registered buffers.
Members:
voxel_grid_class_type: The name of the class to use for voxel_grid,
@ -596,17 +606,21 @@ class VoxelGridModule(Configurable, torch.nn.Module):
with mean=init_mean and std=init_std. Default 0.1
init_mean: Parameters are initialized using the gaussian distribution
with mean=init_mean and std=init_std. Default 0.
hold_voxel_grid_as_parameters: if True components of the underlying voxel grids
will be saved as parameters and therefore be trainable. Default True.
"""
voxel_grid_class_type: str = "FullResolutionVoxelGrid"
voxel_grid: VoxelGridBase
extents: Tuple[float, float, float] = (1.0, 1.0, 1.0)
extents: Tuple[float, float, float] = (2.0, 2.0, 2.0)
translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
init_std: float = 0.1
init_mean: float = 0
hold_voxel_grid_as_parameters: bool = True
def __post_init__(self):
super().__init__()
run_auto_creation(self)
@ -619,7 +633,8 @@ class VoxelGridModule(Configurable, torch.nn.Module):
)
for name, shape in shapes.items()
}
self.params = torch.nn.ParameterDict(params)
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**params))
self._register_load_state_dict_pre_hook(self._create_parameters_with_new_size)
def forward(self, points: torch.Tensor) -> torch.Tensor:
@ -632,31 +647,29 @@ class VoxelGridModule(Configurable, torch.nn.Module):
Returns:
torch.Tensor of shape (..., n_features)
"""
locator = VolumeLocator(
batch_size=1,
# The resolution of the voxel grid does not need to be known
# to the locator object. It is easiest to fix the resolution of the locator.
# In particular we fix it to (2,2,2) so that there is exactly one voxel of the
# desired size. The locator object uses (z, y, x) convention for the grid_size,
# and this module uses (x, y, z) convention so the order has to be reversed
# (irrelevant in this case since they are all equal).
# It is (2, 2, 2) because the VolumeLocator object behaves like
# align_corners=True, which means that the points are in the corners of
# the volume. So in the grid of (2, 2, 2) there is only one voxel.
grid_sizes=(2, 2, 2),
# The locator object uses (x, y, z) convention for the
# voxel size and translation.
voxel_size=tuple(self.extents),
volume_translation=tuple(self.translation),
# pyre-ignore[29]
device=next(val for val in self.params.values() if val is not None).device,
)
locator = self._get_volume_locator()
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
# torch.nn.modules.module.Module]` is not a function.
grid_values = self.voxel_grid.values_type(**self.params)
# voxel grids operate with extra n_grids dimension, which we fix to one
return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
def set_voxel_grid_parameters(self, params: VoxelGridValuesBase) -> None:
"""
Sets the parameters of the underlying voxel grid by replacing `self.params`.
The values are held in a `torch.nn.ParameterDict` when
`hold_voxel_grid_as_parameters` is True and in a `_RegistratedBufferDict`
otherwise.

Args:
params: parameters of type `self.voxel_grid.values_type` which will
replace current parameters

Returns:
None
"""
if self.hold_voxel_grid_as_parameters:
# `vars(params)` exposes the attributes of `params` as a name -> value dict.
# pyre-ignore [16]
self.params = torch.nn.ParameterDict(vars(params))
else:
# Buffers can only be registered on a torch.nn.Module instance, so a
# dedicated module is used to hold them.
self.params = _RegistratedBufferDict(vars(params))
@staticmethod
def get_output_dim(args: DictConfig) -> int:
"""
@ -672,12 +685,12 @@ class VoxelGridModule(Configurable, torch.nn.Module):
args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"]
)
def subscribe_to_epochs(self) -> Tuple[List[int], Callable[[int], bool]]:
def subscribe_to_epochs(self) -> Tuple[Tuple[int, ...], Callable[[int], bool]]:
"""
Method which expresses interest in subscribing to optimization epoch updates.
Returns:
list of epochs on which to call a callable and callable to be called on
tuple of epochs on which to call a callable and callable to be called on
particular epoch. The callable returns True if parameter change has
happened else False and it must be supplied with one argument, epoch.
"""
@ -697,13 +710,12 @@ class VoxelGridModule(Configurable, torch.nn.Module):
"""
# pyre-ignore[29]
grid_values = self.voxel_grid.values_type(**self.params)
grid_values, change = self.voxel_grid.change_resolution(epoch, grid_values)
grid_values, change = self.voxel_grid.change_resolution(
grid_values, epoch=epoch
)
if change:
# pyre-ignore[16]
self.params = torch.nn.ParameterDict(
{name: tensor for name, tensor in vars(grid_values).items()}
)
return change
self.set_voxel_grid_parameters(grid_values)
return change and self.hold_voxel_grid_as_parameters
def _create_parameters_with_new_size(
self,
@ -749,5 +761,113 @@ class VoxelGridModule(Configurable, torch.nn.Module):
key = prefix + "params." + name
if key in state_dict:
new_params[name] = torch.zeros_like(state_dict[key])
# pyre-ignore[16]
self.params = torch.nn.ParameterDict(new_params)
# pyre-ignore[29]
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params))
def get_device(self) -> torch.device:
"""
Returns torch.device on which module parameters are located.

The device is read from the first value in `self.params` that is not None.
Because `next` is called without a default, StopIteration is raised if
every value is None (or there are no values).
"""
# pyre-ignore[29]
return next(val for val in self.params.values() if val is not None).device
def _get_volume_locator(self) -> VolumeLocator:
"""
Returns a VolumeLocator built from the `extents` and `translation` members,
placed on the device returned by `self.get_device()`.
"""
return VolumeLocator(
batch_size=1,
# The resolution of the voxel grid does not need to be known
# to the locator object. It is easiest to fix the resolution of the locator.
# In particular we fix it to (2, 2, 2) so that there is exactly one voxel of
# the desired size. The locator object uses (z, y, x) convention for the
# grid_size, and this module uses (x, y, z) convention so the order has to
# be reversed (irrelevant in this case since they are all equal).
# It is (2, 2, 2) because the VolumeLocator object behaves like
# align_corners=True, which means that the points are in the corners of
# the volume. So in the grid of (2, 2, 2) there is only one voxel.
grid_sizes=(2, 2, 2),
# The locator object uses (x, y, z) convention for the
# voxel size and translation.
voxel_size=tuple(self.extents),
# volume_translation is defined in `VolumeLocator` as a vector from the origin
# of the local coordinate frame to the origin of the world coordinate frame,
# that is:
# x_world = x_local * extents/2 - translation.
# `self.translation` points the other way, so it is negated here.
volume_translation=tuple(-t for t in self.translation),
device=self.get_device(),
)
def get_grid_points(self, epoch: int) -> torch.Tensor:
    """
    Returns a grid of points that represent centers of voxels of the
    underlying voxel grid at specific epoch.

    The points are laid out symmetrically around the origin along each axis,
    using `self.extents` for the side lengths.
    NOTE(review): `self.translation` is not applied here, so the points are
    centered at the origin rather than at the translated volume center;
    confirm that callers expect this.

    Args:
        epoch: underlying voxel grids change resolution depending on the
            epoch, this argument is used to determine the resolution
            of the voxel grid at that epoch.

    Returns:
        tensor of shape [xresolution, yresolution, zresolution, 3] where
        xresolution, yresolution, zresolution are resolutions of the
        underlying voxel grid
    """
    resolutions = tuple(self.voxel_grid.get_resolution(epoch))
    spans = tuple(self.extents)
    if not self.voxel_grid.get_align_corners():
        # Without align_corners the first and last sample along an axis are
        # `extent * (res - 1) / res` apart, i.e. one voxel width less than
        # the full extent. Each axis must be shrunk by its OWN resolution;
        # using the x resolution for every axis misplaces the points of
        # non-cubic grids.
        spans = tuple(
            span * (res - 1) / res if res > 1 else span
            for span, res in zip(spans, resolutions)
        )

    device = self.get_device()
    xs, ys, zs = (
        torch.linspace(-span / 2, span / 2, res, device=device)
        for span, res in zip(spans, resolutions)
    )
    # indexing="ij" keeps the (x, y, z) axis order in the output grid.
    xmesh, ymesh, zmesh = torch.meshgrid(xs, ys, zs, indexing="ij")
    return torch.stack((xmesh, ymesh, zmesh), dim=3)
class _RegistratedBufferDict(torch.nn.Module, Mapping):
    """
    Mapping class and a torch.nn.Module that registers its values
    with `self.register_buffer`. Can be indexed like a regular Python
    dictionary, but torch.Tensors it contains are properly registered, and will be
    visible by all Module methods. Supports only `torch.Tensor` (or None, which
    `register_buffer` also accepts) as value and str as key.

    Keys are iterated in insertion order. Instances compare and hash by
    identity, like any other torch.nn.Module, so that two distinct instances
    are never merged in the sets torch uses while traversing modules.
    """

    def __init__(self, init_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
        """
        Args:
            init_dict: dictionary which will be used to populate the object
        """
        super().__init__()
        # A dict is used as an insertion-ordered set of the registered keys.
        self._keys: Dict[str, None] = {}
        if init_dict is not None:
            for k, v in init_dict.items():
                self[k] = v

    def __iter__(self) -> Iterator[str]:
        # Iterate over a snapshot so that registering a key while iterating
        # does not invalidate the iterator.
        return iter(tuple(self._keys))

    def __len__(self) -> int:
        return len(self._keys)

    def __getitem__(self, key: str) -> torch.Tensor:
        # The Mapping mixins (`in`, `.get`, ...) rely on KeyError for missing
        # keys; a bare getattr would raise AttributeError or return an
        # unrelated module attribute.
        if key not in self._keys:
            raise KeyError(key)
        return getattr(self, key)

    def __setitem__(self, key: str, value: Optional[torch.Tensor]) -> None:
        # Register first: if `register_buffer` rejects the key or value, the
        # key must not be recorded.
        self.register_buffer(key, value)
        self._keys[key] = None

    def __eq__(self, other: object) -> bool:
        # Identity semantics, overriding the value comparison inherited from
        # Mapping (which would compare tensors element-wise).
        return self is other

    def __hash__(self) -> int:
        # Mapping sets __hash__ to None; restore identity hashing, which
        # torch.nn.Module needs for its module-traversal bookkeeping.
        return object.__hash__(self)

View File

@ -653,7 +653,7 @@ class VolumeLocator:
volume_translation: _Translation = (0.0, 0.0, 0.0),
):
"""
**batch_size** : Batch size of the underlaying grids
**batch_size** : Batch size of the underlying grids
**grid_sizes** : Represents the resolutions of different grids in the batch. Can be
a) tuple of form (H, W, D)
b) list/tuple of length batch_size of lists/tuples of form (H, W, D)

View File

@ -35,9 +35,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
one by one sample and comparing with the batched implementation.
"""
def test_my_code(self):
# Placeholder test: returns immediately without asserting anything.
return
def get_random_normalized_points(
self, n_grids, n_points=None, dimension=3
) -> torch.Tensor:
@ -293,6 +290,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
padding_mode="zeros",
mode="bilinear",
),
rtol=0.0001,
atol=0.0001,
)
with self.subTest("2D interpolation"):
points = self.get_random_normalized_points(
@ -308,6 +307,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
padding_mode="zeros",
mode="bilinear",
),
rtol=0.0001,
atol=0.0001,
)
with self.subTest("3D interpolation"):
@ -325,6 +326,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
mode="bilinear",
),
rtol=0.0001,
atol=0.0001,
)
def test_floating_point_query(self):
@ -378,7 +380,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
assert torch.allclose(
grid.evaluate_local(points, params),
expected_result,
rtol=0.00001,
rtol=0.0001,
atol=0.0001,
), grid.evaluate_local(points, params)
with self.subTest("CP"):
grid = CPFactorizedVoxelGrid(
@ -446,14 +449,16 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
assert torch.allclose(
grid.evaluate_local(points, params),
expected_result_matrix,
rtol=0.00001,
rtol=0.0001,
atol=0.0001,
)
del params.basis_matrix
with self.subTest("CP with sum reduction"):
assert torch.allclose(
grid.evaluate_local(points, params),
expected_result_sum,
rtol=0.00001,
rtol=0.0001,
atol=0.0001,
)
with self.subTest("VM"):
@ -540,7 +545,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
assert torch.allclose(
grid.evaluate_local(points, params),
expected_result_matrix,
rtol=0.00001,
rtol=0.0001,
atol=0.0001,
)
del params.basis_matrix
with self.subTest("VM with sum reduction"):
@ -548,6 +554,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
grid.evaluate_local(points, params),
expected_result_sum,
rtol=0.0001,
atol=0.0001,
), grid.evaluate_local(points, params)
def test_forward_with_small_init_std(self):
@ -613,6 +620,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
grid(world_point)[0, 0],
grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0],
rtol=0.0001,
atol=0.0001,
)
def test_resolution_change(self, n_times=10):