diff --git a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py index 5c560f0c..63272569 100644 --- a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py +++ b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py @@ -15,8 +15,9 @@ these classes. """ +from collections.abc import Mapping from dataclasses import dataclass, field -from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Type +from typing import Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type import torch from omegaconf import DictConfig @@ -164,8 +165,9 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): def change_resolution( self, - epoch: int, grid_values: VoxelGridValuesBase, + epoch: int, + *, mode: str = "linear", align_corners: bool = True, antialias: bool = False, @@ -177,8 +179,8 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): epoch: current training epoch, used to see if the grid needs regridding grid_values: instance of self.values_type which contains the voxel grid which will be interpolated to create the new grid - wanted_resolution: tuple of (x, y, z) resolutions which determine - new grid's resolution + epoch: epoch which is used to get the resolution of the new + `grid_values` using `self.resolution_changes`. align_corners: as for torch.nn.functional.interpolate mode: as for torch.nn.functional.interpolate 'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'. @@ -225,11 +227,17 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): # pyre-ignore[29] return self.values_type(**params), True - def get_resolution_change_epochs(self) -> List[int]: + def get_resolution_change_epochs(self) -> Tuple[int, ...]: """ Returns epochs at which this grid should change epochs. 
""" - return list(self.resolution_changes.keys()) + return tuple(self.resolution_changes.keys()) + + def get_align_corners(self) -> bool: + """ + Returns True if voxel grid uses align_corners=True + """ + return self.align_corners @dataclass @@ -583,6 +591,8 @@ class VoxelGridModule(Configurable, torch.nn.Module): """ A wrapper torch.nn.Module for the VoxelGrid classes, which contains parameters that are needed to train the VoxelGrid classes. + Can contain the parameters for the voxel grid as pytorch parameters + or as registered buffers. Members: voxel_grid_class_type: The name of the class to use for voxel_grid, @@ -596,17 +606,21 @@ class VoxelGridModule(Configurable, torch.nn.Module): with mean=init_mean and std=init_std. Default 0.1 init_mean: Parameters are initialized using the gaussian distribution with mean=init_mean and std=init_std. Default 0. + hold_voxel_grid_as_parameters: if True components of the underlying voxel grids + will be saved as parameters and therefore be trainable. Default True. 
""" voxel_grid_class_type: str = "FullResolutionVoxelGrid" voxel_grid: VoxelGridBase - extents: Tuple[float, float, float] = (1.0, 1.0, 1.0) + extents: Tuple[float, float, float] = (2.0, 2.0, 2.0) translation: Tuple[float, float, float] = (0.0, 0.0, 0.0) init_std: float = 0.1 init_mean: float = 0 + hold_voxel_grid_as_parameters: bool = True + def __post_init__(self): super().__init__() run_auto_creation(self) @@ -619,7 +633,8 @@ class VoxelGridModule(Configurable, torch.nn.Module): ) for name, shape in shapes.items() } - self.params = torch.nn.ParameterDict(params) + + self.set_voxel_grid_parameters(self.voxel_grid.values_type(**params)) self._register_load_state_dict_pre_hook(self._create_parameters_with_new_size) def forward(self, points: torch.Tensor) -> torch.Tensor: @@ -632,31 +647,29 @@ class VoxelGridModule(Configurable, torch.nn.Module): Returns: torch.Tensor of shape (..., n_features) """ - locator = VolumeLocator( - batch_size=1, - # The resolution of the voxel grid does not need to be known - # to the locator object. It is easiest to fix the resolution of the locator. - # In particular we fix it to (2,2,2) so that there is exactly one voxel of the - # desired size. The locator object uses (z, y, x) convention for the grid_size, - # and this module uses (x, y, z) convention so the order has to be reversed - # (irrelevant in this case since they are all equal). - # It is (2, 2, 2) because the VolumeLocator object behaves like - # align_corners=True, which means that the points are in the corners of - # the volume. So in the grid of (2, 2, 2) there is only one voxel. - grid_sizes=(2, 2, 2), - # The locator object uses (x, y, z) convention for the - # voxel size and translation. 
- voxel_size=tuple(self.extents), - volume_translation=tuple(self.translation), - # pyre-ignore[29] - device=next(val for val in self.params.values() if val is not None).device, - ) + locator = self._get_volume_locator() # pyre-fixme[29]: `Union[torch._tensor.Tensor, # torch.nn.modules.module.Module]` is not a function. grid_values = self.voxel_grid.values_type(**self.params) # voxel grids operate with extra n_grids dimension, which we fix to one return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0] + def set_voxel_grid_parameters(self, params: VoxelGridValuesBase) -> None: + """ + Sets the parameters of the underlying voxel grid. + + Args: + params: parameters of type `self.voxel_grid.values_type` which will + replace current parameters + """ + if self.hold_voxel_grid_as_parameters: + # pyre-ignore [16] + self.params = torch.nn.ParameterDict(vars(params)) + else: + # Torch Module to hold parameters since they can only be registered + # at object level. + self.params = _RegistratedBufferDict(vars(params)) + @staticmethod def get_output_dim(args: DictConfig) -> int: """ @@ -672,12 +685,12 @@ class VoxelGridModule(Configurable, torch.nn.Module): args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"] ) - def subscribe_to_epochs(self) -> Tuple[List[int], Callable[[int], bool]]: + def subscribe_to_epochs(self) -> Tuple[Tuple[int, ...], Callable[[int], bool]]: """ Method which expresses interest in subscribing to optimization epoch updates. Returns: - list of epochs on which to call a callable and callable to be called on + tuple of epochs on which to call a callable and callable to be called on particular epoch. The callable returns True if parameter change has happened else False and it must be supplied with one argument, epoch. 
""" @@ -697,13 +710,12 @@ class VoxelGridModule(Configurable, torch.nn.Module): """ # pyre-ignore[29] grid_values = self.voxel_grid.values_type(**self.params) - grid_values, change = self.voxel_grid.change_resolution(epoch, grid_values) + grid_values, change = self.voxel_grid.change_resolution( + grid_values, epoch=epoch + ) if change: - # pyre-ignore[16] - self.params = torch.nn.ParameterDict( - {name: tensor for name, tensor in vars(grid_values).items()} - ) - return change + self.set_voxel_grid_parameters(grid_values) + return change and self.hold_voxel_grid_as_parameters def _create_parameters_with_new_size( self, @@ -749,5 +761,113 @@ class VoxelGridModule(Configurable, torch.nn.Module): key = prefix + "params." + name if key in state_dict: new_params[name] = torch.zeros_like(state_dict[key]) - # pyre-ignore[16] - self.params = torch.nn.ParameterDict(new_params) + # pyre-ignore[29] + self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params)) + + def get_device(self) -> torch.device: + """ + Returns torch.device on which module parameters are located + """ + # pyre-ignore[29] + return next(val for val in self.params.values() if val is not None).device + + def _get_volume_locator(self) -> VolumeLocator: + """ + Returns VolumeLocator calculated from `extents` and `translation` members. + """ + return VolumeLocator( + batch_size=1, + # The resolution of the voxel grid does not need to be known + # to the locator object. It is easiest to fix the resolution of the locator. + # In particular we fix it to (2,2,2) so that there is exactly one voxel of the + # desired size. The locator object uses (z, y, x) convention for the grid_size, + # and this module uses (x, y, z) convention so the order has to be reversed + # (irrelevant in this case since they are all equal). + # It is (2, 2, 2) because the VolumeLocator object behaves like + # align_corners=True, which means that the points are in the corners of + # the volume. 
So in the grid of (2, 2, 2) there is only one voxel. + grid_sizes=(2, 2, 2), + # The locator object uses (x, y, z) convention for the + # voxel size and translation. + voxel_size=tuple(self.extents), + # volume_translation is defined in `VolumeLocator` as a vector from the origin + # of local coordinate frame to origin of world coordinate frame, that is: + # x_world = x_local * extents/2 - translation. + # To get the reverse we need to negate it. + volume_translation=tuple(-t for t in self.translation), + device=self.get_device(), + ) + + def get_grid_points(self, epoch: int) -> torch.Tensor: + """ + Returns a grid of points that represent centers of voxels of the + underlying voxel grid in world coordinates at specific epoch. + + Args: + epoch: underlying voxel grids change resolution depending on the + epoch, this argument is used to determine the resolution + of the voxel grid at that epoch. + Returns: + tensor of shape [xresolution, yresolution, zresolution, 3] where + xresolution, yresolution, zresolution are resolutions of the + underlying voxel grid + """ + xresolution, yresolution, zresolution = self.voxel_grid.get_resolution(epoch) + width, height, depth = self.extents + if not self.voxel_grid.get_align_corners(): + width = ( + width * (xresolution - 1) / xresolution if xresolution > 1 else width + ) + height = ( + height * (yresolution - 1) / yresolution if yresolution > 1 else height + ) + depth = ( + depth * (zresolution - 1) / zresolution if zresolution > 1 else depth + ) + xs = torch.linspace( + -width / 2, width / 2, xresolution, device=self.get_device() + ) + ys = torch.linspace( + -height / 2, height / 2, yresolution, device=self.get_device() + ) + zs = torch.linspace( + -depth / 2, depth / 2, zresolution, device=self.get_device() + ) + xmesh, ymesh, zmesh = torch.meshgrid(xs, ys, zs, indexing="ij") + return torch.stack((xmesh, ymesh, zmesh), dim=3) + + +class _RegistratedBufferDict(torch.nn.Module, Mapping): + """ + Mapping class and a 
torch.nn.Module that registers its values + with `self.register_buffer`. Can be indexed like a regular Python + dictionary, but torch.Tensors it contains are properly registered, and will be visible + by all Module methods. Supports only `torch.Tensor` as value and str as key. + """ + + def __init__(self, init_dict: Optional[Dict[str, torch.Tensor]] = None) -> None: + """ + Args: + init_dict: dictionary which will be used to populate the object + """ + super().__init__() + self._keys = set() + if init_dict is not None: + for k, v in init_dict.items(): + self[k] = v + + def __iter__(self) -> Iterator[str]: + return iter({k: self[k] for k in self._keys}) + + def __len__(self) -> int: + return len(self._keys) + + def __getitem__(self, key: str) -> torch.Tensor: + return getattr(self, key) + + def __setitem__(self, key, value) -> None: + self._keys.add(key) + self.register_buffer(key, value) + + def __hash__(self) -> int: + return hash(repr(self)) diff --git a/pytorch3d/structures/volumes.py b/pytorch3d/structures/volumes.py index 166c383a..f63bf0da 100644 --- a/pytorch3d/structures/volumes.py +++ b/pytorch3d/structures/volumes.py @@ -653,7 +653,7 @@ class VolumeLocator: volume_translation: _Translation = (0.0, 0.0, 0.0), ): """ - **batch_size** : Batch size of the underlaying grids + **batch_size** : Batch size of the underlying grids **grid_sizes** : Represents the resolutions of different grids in the batch. Can be a) tuple of form (H, W, D) b) list/tuple of length batch_size of lists/tuples of form (H, W, D) diff --git a/tests/implicitron/test_voxel_grids.py b/tests/implicitron/test_voxel_grids.py index 6ab93a21..8e138411 100644 --- a/tests/implicitron/test_voxel_grids.py +++ b/tests/implicitron/test_voxel_grids.py @@ -35,9 +35,6 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): one by one sample and comparing with the batched implementation. 
""" - def test_my_code(self): - return - def get_random_normalized_points( self, n_grids, n_points=None, dimension=3 ) -> torch.Tensor: @@ -293,6 +290,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): padding_mode="zeros", mode="bilinear", ), + rtol=0.0001, + atol=0.0001, ) with self.subTest("2D interpolation"): points = self.get_random_normalized_points( @@ -308,6 +307,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): padding_mode="zeros", mode="bilinear", ), + rtol=0.0001, + atol=0.0001, ) with self.subTest("3D interpolation"): @@ -325,6 +326,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): mode="bilinear", ), rtol=0.0001, + atol=0.0001, ) def test_floating_point_query(self): @@ -378,7 +380,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): assert torch.allclose( grid.evaluate_local(points, params), expected_result, - rtol=0.00001, + rtol=0.0001, + atol=0.0001, ), grid.evaluate_local(points, params) with self.subTest("CP"): grid = CPFactorizedVoxelGrid( @@ -446,14 +449,16 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): assert torch.allclose( grid.evaluate_local(points, params), expected_result_matrix, - rtol=0.00001, + rtol=0.0001, + atol=0.0001, ) del params.basis_matrix with self.subTest("CP with sum reduction"): assert torch.allclose( grid.evaluate_local(points, params), expected_result_sum, - rtol=0.00001, + rtol=0.0001, + atol=0.0001, ) with self.subTest("VM"): @@ -540,7 +545,8 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): assert torch.allclose( grid.evaluate_local(points, params), expected_result_matrix, - rtol=0.00001, + rtol=0.0001, + atol=0.0001, ) del params.basis_matrix with self.subTest("VM with sum reduction"): @@ -548,6 +554,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): grid.evaluate_local(points, params), expected_result_sum, rtol=0.0001, + atol=0.0001, ), grid.evaluate_local(points, params) def test_forward_with_small_init_std(self): @@ -613,6 +620,7 @@ class 
TestVoxelGrids(TestCaseMixin, unittest.TestCase): grid(world_point)[0, 0], grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0], rtol=0.0001, + atol=0.0001, ) def test_resolution_change(self, n_times=10):