diff --git a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py index 8ea21859..5c560f0c 100644 --- a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py +++ b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py @@ -15,8 +15,8 @@ these classes. """ -from dataclasses import dataclass -from typing import ClassVar, Dict, Optional, Tuple, Type +from dataclasses import dataclass, field +from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Type import torch from omegaconf import DictConfig @@ -58,17 +58,22 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): will actually be trilinear. n_features: number of dimensions of base feature vector. Determines how many features the grid returns. - resolution: 3-tuple containing x, y, z grid sizes corresponding to each axis. + resolution_changes: a dictionary, where keys are change epochs and values are + 3-tuples containing x, y, z grid sizes corresponding to each axis to each epoch """ align_corners: bool = True padding: str = "zeros" mode: str = "bilinear" n_features: int = 1 - resolution: Tuple[int, int, int] = (128, 128, 128) + resolution_changes: Dict[int, List[int]] = field( + default_factory=lambda: {0: [128, 128, 128]} + ) def __post_init__(self): super().__init__() + if 0 not in self.resolution_changes: + raise ValueError("There has to be key `0` in `resolution_changes`.") def evaluate_world( self, @@ -109,11 +114,13 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): """ raise NotImplementedError() - def get_shapes(self) -> Dict[str, Tuple]: + def get_shapes(self, epoch: int) -> Dict[str, Tuple]: """ Using parameters from the __init__ method, this method returns the shapes of individual tensors needed to run the evaluate method. + Args: + epoch: If the shape varies during training, which training epoch's shape to return. Returns: a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods replace the shapes in the dictionary with tensors of those shapes and add the @@ -123,6 +130,21 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): """ raise NotImplementedError() + def get_resolution(self, epoch: int) -> List[int]: + """ + Returns the resolution which the grid should have at specific epoch + + Args: + epoch which to use in the resolution calculation + Returns: + resolution at specific epoch + """ + last_change = 0 + for change_epoch in self.resolution_changes: + if change_epoch <= epoch: + last_change = max(last_change, change_epoch) + return self.resolution_changes[last_change] + @staticmethod def get_output_dim(args: DictConfig) -> int: """ @@ -140,6 +162,75 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module): """ return args["n_features"] + def change_resolution( + self, + epoch: int, + grid_values: VoxelGridValuesBase, + mode: str = "linear", + align_corners: bool = True, + antialias: bool = False, + ) -> Tuple[VoxelGridValuesBase, bool]: + """ + Changes resolution of tensors in `grid_values` to match the `wanted_resolution`. + + Args: + epoch: current training epoch, used to see if the grid needs regridding + grid_values: instance of self.values_type which contains + the voxel grid which will be interpolated to create the new grid + wanted_resolution: tuple of (x, y, z) resolutions which determine + new grid's resolution + align_corners: as for torch.nn.functional.interpolate + mode: as for torch.nn.functional.interpolate + 'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'. + Default: 'linear' + antialias: as for torch.nn.functional.interpolate. + Using anti-alias option + together with align_corners=False and mode='bicubic', interpolation + result would match Pillow result for downsampling operation. + Supported mode: 'bicubic' + Returns: + tuple of + - new voxel grid_values of desired resolution, of type self.values_type + - True if regridding has happened. + """ + if epoch not in self.resolution_changes: + return grid_values, False + + if mode not in ("nearest", "bicubic", "linear", "area", "nearest-exact"): + raise ValueError( + "`mode` should be one of the following 'nearest'" + + "| 'bicubic' | 'linear' | 'area' | 'nearest-exact'" + ) + + def change_individual_resolution(tensor, wanted_resolution): + if mode == "linear": + n_dim = len(wanted_resolution) + new_mode = ("linear", "bilinear", "trilinear")[n_dim - 1] + else: + new_mode = mode + return torch.nn.functional.interpolate( + input=tensor, + size=wanted_resolution, + mode=new_mode, + align_corners=align_corners, + antialias=antialias, + recompute_scale_factor=False, + ) + + wanted_shapes = self.get_shapes(epoch=epoch) + params = { + name: change_individual_resolution(getattr(grid_values, name), shape[1:]) + for name, shape in wanted_shapes.items() + } + # pyre-ignore[29] + return self.values_type(**params), True + + def get_resolution_change_epochs(self) -> List[int]: + """ + Returns epochs at which this grid should change epochs. + """ + return list(self.resolution_changes.keys()) + @dataclass class FullResolutionVoxelGridValues(VoxelGridValuesBase): @@ -185,8 +276,9 @@ class FullResolutionVoxelGrid(VoxelGridBase): ) return interpolated.view(*recorded_shape[:-1], -1) - def get_shapes(self) -> Dict[str, Tuple]: - return {"voxel_grid": (self.n_features, *self.resolution)} + def get_shapes(self, epoch: int) -> Dict[str, Tuple]: + width, height, depth = self.get_resolution(epoch) + return {"voxel_grid": (self.n_features, width, height, depth)} @dataclass @@ -212,7 +304,7 @@ class CPFactorizedVoxelGrid(VoxelGridBase): Each element of this sum has an extra dimension, which gets matrix-multiplied by an appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication - brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements + brings us to the desired "n_features" dimensionality. If basis_matrix=False the elements of different components are summed together to create (n_grids, n_components, 1) tensor. With some notation abuse, ignoring the interpolation operation, simplifying and denoting n_features as F, n_components as C and n_grids as G: @@ -223,7 +315,7 @@ class CPFactorizedVoxelGrid(VoxelGridBase): Members: n_components: number of vector triplets, higher number gives better approximation. - matrix_reduction: how to transform components. If matrix_reduction=True result + basis_matrix: how to transform components. If matrix_reduction=True result matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied by the basis_matrix of shape (n_grids, n_components, n_features). If matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components) @@ -235,7 +327,7 @@ class CPFactorizedVoxelGrid(VoxelGridBase): values_type: ClassVar[Type[VoxelGridValuesBase]] = CPFactorizedVoxelGridValues n_components: int = 24 - matrix_reduction: bool = True + basis_matrix: bool = True # pyre-fixme[14]: `evaluate_local` overrides method defined in `VoxelGridBase` # inconsistently. @@ -274,16 +366,17 @@ class CPFactorizedVoxelGrid(VoxelGridBase): # (n_grids, ..., n_features) return result.view(*recorded_shape[:-1], -1) - def get_shapes(self) -> Dict[str, Tuple[int, int]]: - if self.matrix_reduction is False and self.n_features != 1: - raise ValueError("Cannot set matrix_reduction=False and n_features to != 1") + def get_shapes(self, epoch: int) -> Dict[str, Tuple[int, int]]: + if self.basis_matrix is False and self.n_features != 1: + raise ValueError("Cannot set basis_matrix=False and n_features to != 1") + width, height, depth = self.get_resolution(epoch=epoch) shape_dict = { - "vector_components_x": (self.n_components, self.resolution[0]), - "vector_components_y": (self.n_components, self.resolution[1]), - "vector_components_z": (self.n_components, self.resolution[2]), + "vector_components_x": (self.n_components, width), + "vector_components_y": (self.n_components, height), + "vector_components_z": (self.n_components, depth), } - if self.matrix_reduction: + if self.basis_matrix: shape_dict["basis_matrix"] = (self.n_components, self.n_features) return shape_dict @@ -321,7 +414,7 @@ class VMFactorizedVoxelGrid(VoxelGridBase): Each element of this sum has an extra dimension, which gets matrix-multiplied by an appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication - brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements + brings us to the desired "n_features" dimensionality. If basis_matrix=False the elements of different components are summed together to create (n_grids, n_components, 1) tensor. With some notation abuse, ignoring the interpolation operation, simplifying and denoting n_features as F, n_components as C (which can differ for each dimension) and n_grids as G: @@ -338,7 +431,7 @@ class VMFactorizedVoxelGrid(VoxelGridBase): all 3 directions specify a tuple of numbers of matrix_vector pairs for each coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify either n_components or distribution_of_components, you cannot specify both. - matrix_reduction: how to transform components. If matrix_reduction=True result + basis_matrix: how to transform components. If matrix_reduction=True result matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied by the basis_matrix of shape (n_grids, n_components, n_features). If matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components) @@ -351,7 +444,7 @@ class VMFactorizedVoxelGrid(VoxelGridBase): n_components: Optional[int] = None distribution_of_components: Optional[Tuple[int, int, int]] = None - matrix_reduction: bool = True + basis_matrix: bool = True # pyre-fixme[14]: `evaluate_local` overrides method defined in `VoxelGridBase` # inconsistently. @@ -419,9 +512,9 @@ class VMFactorizedVoxelGrid(VoxelGridBase): # (n_grids, ..., n_features) return result.view(*recorded_shape[:-1], -1) - def get_shapes(self) -> Dict[str, Tuple]: - if self.matrix_reduction is False and self.n_features != 1: - raise ValueError("Cannot set matrix_reduction=False and n_features to != 1") + def get_shapes(self, epoch: int) -> Dict[str, Tuple]: + if self.basis_matrix is False and self.n_features != 1: + raise ValueError("Cannot set basis_matrix=False and n_features to != 1") if self.distribution_of_components is None and self.n_components is None: raise ValueError( "You need to provide n_components or distribution_of_components" @@ -446,36 +539,37 @@ class VMFactorizedVoxelGrid(VoxelGridBase): else: calculated_distribution_of_components = self.distribution_of_components + width, height, depth = self.get_resolution(epoch=epoch) shape_dict = { "vector_components_x": ( calculated_distribution_of_components[1], - self.resolution[0], + width, ), "vector_components_y": ( calculated_distribution_of_components[2], - self.resolution[1], + height, ), "vector_components_z": ( calculated_distribution_of_components[0], - self.resolution[2], + depth, ), "matrix_components_xy": ( calculated_distribution_of_components[0], - self.resolution[0], - self.resolution[1], + width, + height, ), "matrix_components_yz": ( calculated_distribution_of_components[1], - self.resolution[1], - self.resolution[2], + height, + depth, ), "matrix_components_xz": ( calculated_distribution_of_components[2], - self.resolution[0], - self.resolution[2], + width, + depth, ), } - if self.matrix_reduction: + if self.basis_matrix: shape_dict["basis_matrix"] = ( sum(calculated_distribution_of_components), self.n_features, @@ -517,7 +611,7 @@ class VoxelGridModule(Configurable, torch.nn.Module): super().__init__() run_auto_creation(self) n_grids = 1 # Voxel grid objects are batched. We need only a single grid. - shapes = self.voxel_grid.get_shapes() + shapes = self.voxel_grid.get_shapes(epoch=0) params = { name: torch.normal( mean=torch.zeros((n_grids, *shape)) + self.init_mean, @@ -526,6 +620,7 @@ class VoxelGridModule(Configurable, torch.nn.Module): for name, shape in shapes.items() } self.params = torch.nn.ParameterDict(params) + self._register_load_state_dict_pre_hook(self._create_parameters_with_new_size) def forward(self, points: torch.Tensor) -> torch.Tensor: """ @@ -554,7 +649,7 @@ class VoxelGridModule(Configurable, torch.nn.Module): voxel_size=tuple(self.extents), volume_translation=tuple(self.translation), # pyre-ignore[29] - device=next(self.params.values()).device, + device=next(val for val in self.params.values() if val is not None).device, ) # pyre-fixme[29]: `Union[torch._tensor.Tensor, # torch.nn.modules.module.Module]` is not a function. @@ -576,3 +671,83 @@ class VoxelGridModule(Configurable, torch.nn.Module): return grid.get_output_dim( args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"] ) + + def subscribe_to_epochs(self) -> Tuple[List[int], Callable[[int], bool]]: + """ + Method which expresses interest in subscribing to optimization epoch updates. + + Returns: + list of epochs on which to call a callable and callable to be called on + particular epoch. The callable returns True if parameter change has + happened else False and it must be supplied with one argument, epoch. + """ + return self.voxel_grid.get_resolution_change_epochs(), self._apply_epochs + + def _apply_epochs(self, epoch: int) -> bool: + """ + Asks voxel_grid to change the resolution. + This method is returned with subscribe_to_epochs and is the method that collects + updates on training epochs, it is run on the training epochs that are requested. + + Args: + epoch: current training epoch used for voxel grids to know to which + resolution to change + Returns: + True if parameter change has happened else False. + """ + # pyre-ignore[29] + grid_values = self.voxel_grid.values_type(**self.params) + grid_values, change = self.voxel_grid.change_resolution(epoch, grid_values) + if change: + # pyre-ignore[16] + self.params = torch.nn.ParameterDict( + {name: tensor for name, tensor in vars(grid_values).items()} + ) + return change + + def _create_parameters_with_new_size( + self, + state_dict: dict, + prefix: str, + local_metadata: dict, + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + ''' + Automatically ran before loading the parameters with `load_state_dict()`. + Creates new parameters with the sizes of the ones in the loaded state dict. + This is necessary because the parameters are changing throughout training and + at the time of construction `VoxelGridModule` does not know the size of + parameters which will be loaded. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + Returns: + nothing + """ + ''' + new_params = {} + # pyre-ignore[29] + for name in self.params: + key = prefix + "params." + name + if key in state_dict: + new_params[name] = torch.zeros_like(state_dict[key]) + # pyre-ignore[16] + self.params = torch.nn.ParameterDict(new_params) diff --git a/tests/implicitron/test_voxel_grids.py b/tests/implicitron/test_voxel_grids.py index 110521cc..6ab93a21 100644 --- a/tests/implicitron/test_voxel_grids.py +++ b/tests/implicitron/test_voxel_grids.py @@ -9,6 +9,7 @@ import unittest from typing import Optional, Tuple import torch +from omegaconf import DictConfig from pytorch3d.implicitron.models.implicit_function.utils import ( interpolate_line, @@ -62,11 +63,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): # be of shape (n_grids, n_points, n_features) and be filled with n_components # * value grid = CPFactorizedVoxelGrid( - resolution=resolution, + resolution_changes={0: resolution}, n_components=n_components, n_features=n_features, ) - shapes = grid.get_shapes() + shapes = grid.get_shapes(epoch=0) params = grid.values_type( **{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes} @@ -91,11 +92,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): # set everything to 'value' and do query for elements grid = VMFactorizedVoxelGrid( n_features=n_features, - resolution=resolution, + resolution_changes={0: resolution}, n_components=n_components, distribution_of_components=distribution, ) - shapes = grid.get_shapes() + shapes = grid.get_shapes(epoch=0) params = grid.values_type( **{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes} ) @@ -118,8 +119,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): n_points: int = 1, ) -> None: # set everything to 'value' and do query for elements - grid = FullResolutionVoxelGrid(n_features=n_features, resolution=resolution) - shapes = grid.get_shapes() + grid = FullResolutionVoxelGrid( + n_features=n_features, resolution_changes={0: resolution} + ) + shapes = grid.get_shapes(epoch=0) params = grid.values_type( **{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes} ) @@ -329,8 +332,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): test querying the voxel grids on some float positions """ with self.subTest("FullResolution"): - grid = FullResolutionVoxelGrid(n_features=1, resolution=(1, 1, 1)) - params = grid.values_type(**grid.get_shapes()) + grid = FullResolutionVoxelGrid( + n_features=1, resolution_changes={0: (1, 1, 1)} + ) + params = grid.values_type(**grid.get_shapes(epoch=0)) params.voxel_grid = torch.tensor( [ [ @@ -377,9 +382,9 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): ), grid.evaluate_local(points, params) with self.subTest("CP"): grid = CPFactorizedVoxelGrid( - n_features=1, resolution=(1, 1, 1), n_components=3 + n_features=1, resolution_changes={0: (1, 1, 1)}, n_components=3 ) - params = grid.values_type(**grid.get_shapes()) + params = grid.values_type(**grid.get_shapes(epoch=0)) params.vector_components_x = torch.tensor( [ [[1, 2], [10.5, 20.5]], @@ -453,9 +458,9 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): with self.subTest("VM"): grid = VMFactorizedVoxelGrid( - n_features=1, resolution=(1, 1, 1), n_components=3 + n_features=1, resolution_changes={0: (1, 1, 1)}, n_components=3 ) - params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes()) + params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes(epoch=0)) params.matrix_components_xy = torch.tensor( [ [[[1, 2], [3, 4]], [[19, 20], [21, 22.0]]], @@ -555,7 +560,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): with self.subTest(cls.__name__): n_grids = 3 grid = cls(**kwargs) - shapes = grid.get_shapes() + shapes = grid.get_shapes(epoch=0) params = cls.values_type( **{ k: torch.normal(mean=torch.zeros(n_grids, *shape), std=0.0001) @@ -570,18 +575,18 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): test( FullResolutionVoxelGrid, - resolution=(4, 6, 9), + resolution_changes={0: (4, 6, 9)}, n_features=10, ) test( CPFactorizedVoxelGrid, - resolution=(4, 6, 9), + resolution_changes={0: (4, 6, 9)}, n_features=10, n_components=3, ) test( VMFactorizedVoxelGrid, - resolution=(4, 6, 9), + resolution_changes={0: (4, 6, 9)}, n_features=10, n_components=3, ) @@ -609,3 +614,105 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase): grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0], rtol=0.0001, ) + + def test_resolution_change(self, n_times=10): + for _ in range(n_times): + n_grids, n_features, n_components = torch.randint(1, 3, (3,)).tolist() + resolution = torch.randint(3, 10, (3,)).tolist() + resolution2 = torch.randint(3, 10, (3,)).tolist() + resolution_changes = {0: resolution, 1: resolution2} + n_components *= 3 + for cls, kwargs in ( + ( + FullResolutionVoxelGrid, + { + "n_features": n_features, + "resolution_changes": resolution_changes, + }, + ), + ( + CPFactorizedVoxelGrid, + { + "n_features": n_features, + "resolution_changes": resolution_changes, + "n_components": n_components, + }, + ), + ( + VMFactorizedVoxelGrid, + { + "n_features": n_features, + "resolution_changes": resolution_changes, + "n_components": n_components, + }, + ), + ): + with self.subTest(cls.__name__): + grid = cls(**kwargs) + self.assertEqual(grid.get_resolution(epoch=0), resolution) + shapes = grid.get_shapes(epoch=0) + params = { + name: torch.randn((n_grids, *shape)) + for name, shape in shapes.items() + } + grid_values = grid.values_type(**params) + grid_values_changed_resolution, change = grid.change_resolution( + epoch=1, + grid_values=grid_values, + mode="linear", + ) + assert change + self.assertEqual(grid.get_resolution(epoch=1), resolution2) + shapes_changed_resolution = grid.get_shapes(epoch=1) + for name, expected_shape in shapes_changed_resolution.items(): + shape = getattr(grid_values_changed_resolution, name).shape + self.assertEqual(expected_shape, shape[1:]) + + with self.subTest("VoxelGridModule"): + n_changes = 10 + grid = VoxelGridModule() + resolution_changes = {i: (i + 2, i + 2, i + 2) for i in range(n_changes)} + grid.voxel_grid = FullResolutionVoxelGrid( + resolution_changes=resolution_changes + ) + epochs, apply_func = grid.subscribe_to_epochs() + self.assertEqual(list(range(n_changes)), list(epochs)) + for epoch in epochs: + change = apply_func(epoch) + assert change + self.assertEqual( + resolution_changes[epoch], + grid.voxel_grid.get_resolution(epoch=epoch), + ) + + def test_loading_state_dict(self): + """ + Test loading state dict after rescaling. + + Create a voxel grid, rescale it and get the state_dict. + Create a new voxel grid with the same args as the first one and load + the state_dict and check if everything is ok. + """ + n_changes = 10 + + resolution_changes = {i: (i + 2, i + 2, i + 2) for i in range(n_changes)} + cfg = DictConfig( + { + "voxel_grid_class_type": "VMFactorizedVoxelGrid", + "voxel_grid_VMFactorizedVoxelGrid_args": { + "resolution_changes": resolution_changes, + "n_components": 48, + }, + } + ) + grid = VoxelGridModule(**cfg) + epochs, apply_func = grid.subscribe_to_epochs() + for epoch in epochs: + apply_func(epoch) + + loaded_grid = VoxelGridModule(**cfg) + loaded_grid.load_state_dict(grid.state_dict()) + for name_loaded, param_loaded in loaded_grid.named_parameters(): + for name, param in grid.named_parameters(): + if name_loaded == name: + torch.allclose(param_loaded, param)