Add rescaling to voxel grids

Summary: Any module can subscribe to step updates from the training loop. When the training loop publishes a step at which a resolution change is scheduled, the voxel grid changes its dimensions. When VoxelGridModule and its parameters are constructed, the module does not yet know which resolution will be loaded from a checkpoint, so a hook runs before checkpoint loading and resizes the VoxelGridModule's parameters to match the shapes in the loaded checkpoint.

Reviewed By: bottler

Differential Revision: D39026775

fbshipit-source-id: 0d359ea5c8d2eda11d773d79c7513c83585d5f17
This commit is contained in:
Darijan Gudelj 2022-09-28 05:23:22 -07:00 committed by Facebook GitHub Bot
parent efea540bbc
commit 5005f09118
2 changed files with 333 additions and 51 deletions

View File

@ -15,8 +15,8 @@ these classes.
"""
from dataclasses import dataclass
from typing import ClassVar, Dict, Optional, Tuple, Type
from dataclasses import dataclass, field
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Type
import torch
from omegaconf import DictConfig
@ -58,17 +58,22 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
will actually be trilinear.
n_features: number of dimensions of base feature vector. Determines how many features
the grid returns.
resolution: 3-tuple containing x, y, z grid sizes corresponding to each axis.
resolution_changes: a dictionary whose keys are the epochs at which the resolution changes and
whose values are 3-tuples with the x, y, z grid sizes to use from that epoch on. Key `0` is required.
"""
align_corners: bool = True
padding: str = "zeros"
mode: str = "bilinear"
n_features: int = 1
resolution: Tuple[int, int, int] = (128, 128, 128)
resolution_changes: Dict[int, List[int]] = field(
default_factory=lambda: {0: [128, 128, 128]}
)
def __post_init__(self):
    """
    Runs the parent initializer and validates `resolution_changes`.

    Raises:
        ValueError: if `resolution_changes` has no entry for epoch `0`.
    """
    super().__init__()
    has_initial_resolution = 0 in self.resolution_changes
    if not has_initial_resolution:
        raise ValueError("There has to be key `0` in `resolution_changes`.")
def evaluate_world(
self,
@ -109,11 +114,13 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
"""
raise NotImplementedError()
def get_shapes(self) -> Dict[str, Tuple]:
def get_shapes(self, epoch: int) -> Dict[str, Tuple]:
"""
Using parameters from the __init__ method, this method returns the
shapes of individual tensors needed to run the evaluate method.
Args:
epoch: If the shape varies during training, which training epoch's shape to return.
Returns:
a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods
replace the shapes in the dictionary with tensors of those shapes and add the
@ -123,6 +130,21 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
"""
raise NotImplementedError()
def get_resolution(self, epoch: int) -> List[int]:
    """
    Looks up the resolution the grid should have at a given epoch.

    The resolution in effect is the one registered at the latest change epoch
    that is not after `epoch`. The entry for epoch `0` is used when no positive
    change epoch qualifies.

    Args:
        epoch: epoch for which the resolution is requested.

    Returns:
        resolution registered for the selected change epoch.
    """
    reached = [start for start in self.resolution_changes if start <= epoch]
    # 0 is always a candidate, so the lookup falls back to the `0` entry.
    active_epoch = max([0, *reached])
    return self.resolution_changes[active_epoch]
@staticmethod
def get_output_dim(args: DictConfig) -> int:
"""
@ -140,6 +162,75 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
"""
return args["n_features"]
def change_resolution(
    self,
    epoch: int,
    grid_values: VoxelGridValuesBase,
    mode: str = "linear",
    align_corners: bool = True,
    antialias: bool = False,
) -> Tuple[VoxelGridValuesBase, bool]:
    """
    Interpolates the tensors in `grid_values` to the shapes returned by
    `get_shapes(epoch)`, provided `epoch` is one of the keys of
    `resolution_changes`. Otherwise `grid_values` is returned untouched.

    Args:
        epoch: current training epoch, used to see if the grid needs regridding
        grid_values: instance of self.values_type which contains
            the voxel grid which will be interpolated to create the new grid
        mode: as for torch.nn.functional.interpolate
            'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'.
            'linear' is translated to 'linear', 'bilinear' or 'trilinear'
            according to the number of interpolated dimensions.
            Default: 'linear'
        align_corners: as for torch.nn.functional.interpolate. It is only
            forwarded for the 'linear' and 'bicubic' modes, because
            interpolate raises when it is set for the other modes.
        antialias: as for torch.nn.functional.interpolate.
            Using anti-alias option
            together with align_corners=False and mode='bicubic', interpolation
            result would match Pillow result for downsampling operation.
            Supported mode: 'bicubic'

    Returns:
        tuple of
            - new voxel grid_values of desired resolution, of type self.values_type
            - True if regridding has happened.

    Raises:
        ValueError: if `epoch` is a change epoch and `mode` is not one of the
            supported modes.
    """
    if epoch not in self.resolution_changes:
        return grid_values, False

    if mode not in ("nearest", "bicubic", "linear", "area", "nearest-exact"):
        raise ValueError(
            "`mode` should be one of the following 'nearest'"
            + "| 'bicubic' | 'linear' | 'area' | 'nearest-exact'"
        )

    # interpolate only accepts `align_corners` for the interpolating modes;
    # passing it with 'nearest', 'area' or 'nearest-exact' raises a ValueError.
    effective_align_corners = (
        align_corners if mode in ("linear", "bicubic") else None
    )

    def change_individual_resolution(tensor, wanted_resolution):
        if mode == "linear":
            # pick the linear variant matching the number of resized dimensions
            n_dim = len(wanted_resolution)
            new_mode = ("linear", "bilinear", "trilinear")[n_dim - 1]
        else:
            new_mode = mode
        return torch.nn.functional.interpolate(
            input=tensor,
            size=wanted_resolution,
            mode=new_mode,
            align_corners=effective_align_corners,
            antialias=antialias,
            recompute_scale_factor=False,
        )

    wanted_shapes = self.get_shapes(epoch=epoch)
    # shape[0] is the channel-like dimension, which is kept; the remaining
    # entries are the sizes the tensor is interpolated to.
    params = {
        name: change_individual_resolution(getattr(grid_values, name), shape[1:])
        for name, shape in wanted_shapes.items()
    }
    # pyre-ignore[29]
    return self.values_type(**params), True
def get_resolution_change_epochs(self) -> List[int]:
    """
    Returns the epochs at which this grid should change its resolution,
    i.e. the keys of `resolution_changes`.
    """
    change_epochs = list(self.resolution_changes)
    return change_epochs
@dataclass
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
@ -185,8 +276,9 @@ class FullResolutionVoxelGrid(VoxelGridBase):
)
return interpolated.view(*recorded_shape[:-1], -1)
def get_shapes(self) -> Dict[str, Tuple]:
return {"voxel_grid": (self.n_features, *self.resolution)}
def get_shapes(self, epoch: int) -> Dict[str, Tuple]:
    """
    Returns the shape of the single `voxel_grid` tensor at the given epoch:
    `n_features` followed by the three grid sizes from `get_resolution`.
    """
    size_x, size_y, size_z = self.get_resolution(epoch)
    voxel_grid_shape = (self.n_features, size_x, size_y, size_z)
    return {"voxel_grid": voxel_grid_shape}
@dataclass
@ -212,7 +304,7 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
Each element of this sum has an extra dimension, which gets matrix-multiplied by an
appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication
brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements
brings us to the desired "n_features" dimensionality. If basis_matrix=False the elements
of different components are summed together to create (n_grids, n_components, 1) tensor.
With some notation abuse, ignoring the interpolation operation, simplifying and denoting
n_features as F, n_components as C and n_grids as G:
@ -223,7 +315,7 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
Members:
n_components: number of vector triplets, higher number gives better approximation.
matrix_reduction: how to transform components. If matrix_reduction=True result
basis_matrix: how to transform components. If basis_matrix=True the result
matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
by the basis_matrix of shape (n_grids, n_components, n_features). If
basis_matrix=False, the result tensor of (n_grids, n_points_total, n_components)
@ -235,7 +327,7 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
values_type: ClassVar[Type[VoxelGridValuesBase]] = CPFactorizedVoxelGridValues
n_components: int = 24
matrix_reduction: bool = True
basis_matrix: bool = True
# pyre-fixme[14]: `evaluate_local` overrides method defined in `VoxelGridBase`
# inconsistently.
@ -274,16 +366,17 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
# (n_grids, ..., n_features)
return result.view(*recorded_shape[:-1], -1)
def get_shapes(self) -> Dict[str, Tuple[int, int]]:
if self.matrix_reduction is False and self.n_features != 1:
raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")
def get_shapes(self, epoch: int) -> Dict[str, Tuple[int, int]]:
if self.basis_matrix is False and self.n_features != 1:
raise ValueError("Cannot set basis_matrix=False and n_features to != 1")
width, height, depth = self.get_resolution(epoch=epoch)
shape_dict = {
"vector_components_x": (self.n_components, self.resolution[0]),
"vector_components_y": (self.n_components, self.resolution[1]),
"vector_components_z": (self.n_components, self.resolution[2]),
"vector_components_x": (self.n_components, width),
"vector_components_y": (self.n_components, height),
"vector_components_z": (self.n_components, depth),
}
if self.matrix_reduction:
if self.basis_matrix:
shape_dict["basis_matrix"] = (self.n_components, self.n_features)
return shape_dict
@ -321,7 +414,7 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
Each element of this sum has an extra dimension, which gets matrix-multiplied by an
appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication
brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements
brings us to the desired "n_features" dimensionality. If basis_matrix=False the elements
of different components are summed together to create (n_grids, n_components, 1) tensor.
With some notation abuse, ignoring the interpolation operation, simplifying and denoting
n_features as F, n_components as C (which can differ for each dimension) and n_grids as G:
@ -338,7 +431,7 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
all 3 directions specify a tuple of numbers of matrix_vector pairs for each
coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
either n_components or distribution_of_components, you cannot specify both.
matrix_reduction: how to transform components. If matrix_reduction=True result
basis_matrix: how to transform components. If basis_matrix=True the result
matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
by the basis_matrix of shape (n_grids, n_components, n_features). If
basis_matrix=False, the result tensor of (n_grids, n_points_total, n_components)
@ -351,7 +444,7 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
n_components: Optional[int] = None
distribution_of_components: Optional[Tuple[int, int, int]] = None
matrix_reduction: bool = True
basis_matrix: bool = True
# pyre-fixme[14]: `evaluate_local` overrides method defined in `VoxelGridBase`
# inconsistently.
@ -419,9 +512,9 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
# (n_grids, ..., n_features)
return result.view(*recorded_shape[:-1], -1)
def get_shapes(self) -> Dict[str, Tuple]:
if self.matrix_reduction is False and self.n_features != 1:
raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")
def get_shapes(self, epoch: int) -> Dict[str, Tuple]:
if self.basis_matrix is False and self.n_features != 1:
raise ValueError("Cannot set basis_matrix=False and n_features to != 1")
if self.distribution_of_components is None and self.n_components is None:
raise ValueError(
"You need to provide n_components or distribution_of_components"
@ -446,36 +539,37 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
else:
calculated_distribution_of_components = self.distribution_of_components
width, height, depth = self.get_resolution(epoch=epoch)
shape_dict = {
"vector_components_x": (
calculated_distribution_of_components[1],
self.resolution[0],
width,
),
"vector_components_y": (
calculated_distribution_of_components[2],
self.resolution[1],
height,
),
"vector_components_z": (
calculated_distribution_of_components[0],
self.resolution[2],
depth,
),
"matrix_components_xy": (
calculated_distribution_of_components[0],
self.resolution[0],
self.resolution[1],
width,
height,
),
"matrix_components_yz": (
calculated_distribution_of_components[1],
self.resolution[1],
self.resolution[2],
height,
depth,
),
"matrix_components_xz": (
calculated_distribution_of_components[2],
self.resolution[0],
self.resolution[2],
width,
depth,
),
}
if self.matrix_reduction:
if self.basis_matrix:
shape_dict["basis_matrix"] = (
sum(calculated_distribution_of_components),
self.n_features,
@ -517,7 +611,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
super().__init__()
run_auto_creation(self)
n_grids = 1 # Voxel grid objects are batched. We need only a single grid.
shapes = self.voxel_grid.get_shapes()
shapes = self.voxel_grid.get_shapes(epoch=0)
params = {
name: torch.normal(
mean=torch.zeros((n_grids, *shape)) + self.init_mean,
@ -526,6 +620,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
for name, shape in shapes.items()
}
self.params = torch.nn.ParameterDict(params)
self._register_load_state_dict_pre_hook(self._create_parameters_with_new_size)
def forward(self, points: torch.Tensor) -> torch.Tensor:
"""
@ -554,7 +649,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
voxel_size=tuple(self.extents),
volume_translation=tuple(self.translation),
# pyre-ignore[29]
device=next(self.params.values()).device,
device=next(val for val in self.params.values() if val is not None).device,
)
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
# torch.nn.modules.module.Module]` is not a function.
@ -576,3 +671,83 @@ class VoxelGridModule(Configurable, torch.nn.Module):
return grid.get_output_dim(
args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"]
)
def subscribe_to_epochs(self) -> Tuple[List[int], Callable[[int], bool]]:
    """
    Expresses interest in receiving optimization epoch updates.

    Returns:
        a pair of
            - the epochs reported by the voxel grid's
              `get_resolution_change_epochs`, on which the callable should
              be called
            - the callable itself; it takes one argument, the epoch, and
              returns True if a parameter change has happened, else False.
    """
    epochs_of_interest = self.voxel_grid.get_resolution_change_epochs()
    on_epoch = self._apply_epochs
    return epochs_of_interest, on_epoch
def _apply_epochs(self, epoch: int) -> bool:
    """
    Asks voxel_grid to change the resolution of the current parameters.

    This is the callable handed out by `subscribe_to_epochs`; it receives the
    training epochs and replaces `self.params` when the voxel grid reports
    that a change has happened.

    Args:
        epoch: current training epoch, passed on to the voxel grid so it
            knows to which resolution to change.

    Returns:
        True if parameter change has happened else False.
    """
    # pyre-ignore[29]
    current_values = self.voxel_grid.values_type(**self.params)
    new_values, was_changed = self.voxel_grid.change_resolution(
        epoch, current_values
    )
    if not was_changed:
        return was_changed
    # Every attribute of the returned values object becomes a parameter.
    # pyre-ignore[16]
    self.params = torch.nn.ParameterDict(dict(vars(new_values)))
    return was_changed
def _create_parameters_with_new_size(
    self,
    state_dict: dict,
    prefix: str,
    local_metadata: dict,
    strict: bool,
    missing_keys: List[str],
    unexpected_keys: List[str],
    error_msgs: List[str],
) -> None:
    """
    Automatically run before loading the parameters with `load_state_dict()`.
    Creates new parameters with the sizes of the ones in the loaded state dict.
    This is necessary because the parameters are changing throughout training and
    at the time of construction `VoxelGridModule` does not know the size of
    parameters which will be loaded.

    Parameters which have no entry in `state_dict` are kept as they are, so a
    partial state dict does not remove them from the module.

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        prefix (str): the prefix for parameters and buffers used in this
            module
        local_metadata (dict): a dict containing the metadata for this module.
        strict (bool): whether to strictly enforce that the keys in
            :attr:`state_dict` with :attr:`prefix` match the names of
            parameters and buffers in this module
        missing_keys (list of str): if ``strict=True``, add missing keys to
            this list
        unexpected_keys (list of str): if ``strict=True``, add unexpected
            keys to this list
        error_msgs (list of str): error messages should be added to this
            list, and will be reported together in
            :meth:`~torch.nn.Module.load_state_dict`

    Returns:
        nothing
    """
    new_params = {}
    # pyre-ignore[29]
    for name in self.params:
        key = prefix + "params." + name
        if key in state_dict:
            # Placeholder of the checkpoint's shape; the values themselves are
            # copied in by the regular loading that follows this hook.
            new_params[name] = torch.zeros_like(state_dict[key])
        else:
            # Keep the existing parameter instead of silently dropping it.
            new_params[name] = self.params[name]
    # pyre-ignore[16]
    self.params = torch.nn.ParameterDict(new_params)

View File

@ -9,6 +9,7 @@ import unittest
from typing import Optional, Tuple
import torch
from omegaconf import DictConfig
from pytorch3d.implicitron.models.implicit_function.utils import (
interpolate_line,
@ -62,11 +63,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
# be of shape (n_grids, n_points, n_features) and be filled with n_components
# * value
grid = CPFactorizedVoxelGrid(
resolution=resolution,
resolution_changes={0: resolution},
n_components=n_components,
n_features=n_features,
)
shapes = grid.get_shapes()
shapes = grid.get_shapes(epoch=0)
params = grid.values_type(
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
@ -91,11 +92,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
# set everything to 'value' and do query for elements
grid = VMFactorizedVoxelGrid(
n_features=n_features,
resolution=resolution,
resolution_changes={0: resolution},
n_components=n_components,
distribution_of_components=distribution,
)
shapes = grid.get_shapes()
shapes = grid.get_shapes(epoch=0)
params = grid.values_type(
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
)
@ -118,8 +119,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_points: int = 1,
) -> None:
# set everything to 'value' and do query for elements
grid = FullResolutionVoxelGrid(n_features=n_features, resolution=resolution)
shapes = grid.get_shapes()
grid = FullResolutionVoxelGrid(
n_features=n_features, resolution_changes={0: resolution}
)
shapes = grid.get_shapes(epoch=0)
params = grid.values_type(
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
)
@ -329,8 +332,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
test querying the voxel grids on some float positions
"""
with self.subTest("FullResolution"):
grid = FullResolutionVoxelGrid(n_features=1, resolution=(1, 1, 1))
params = grid.values_type(**grid.get_shapes())
grid = FullResolutionVoxelGrid(
n_features=1, resolution_changes={0: (1, 1, 1)}
)
params = grid.values_type(**grid.get_shapes(epoch=0))
params.voxel_grid = torch.tensor(
[
[
@ -377,9 +382,9 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
), grid.evaluate_local(points, params)
with self.subTest("CP"):
grid = CPFactorizedVoxelGrid(
n_features=1, resolution=(1, 1, 1), n_components=3
n_features=1, resolution_changes={0: (1, 1, 1)}, n_components=3
)
params = grid.values_type(**grid.get_shapes())
params = grid.values_type(**grid.get_shapes(epoch=0))
params.vector_components_x = torch.tensor(
[
[[1, 2], [10.5, 20.5]],
@ -453,9 +458,9 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
with self.subTest("VM"):
grid = VMFactorizedVoxelGrid(
n_features=1, resolution=(1, 1, 1), n_components=3
n_features=1, resolution_changes={0: (1, 1, 1)}, n_components=3
)
params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes())
params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes(epoch=0))
params.matrix_components_xy = torch.tensor(
[
[[[1, 2], [3, 4]], [[19, 20], [21, 22.0]]],
@ -555,7 +560,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
with self.subTest(cls.__name__):
n_grids = 3
grid = cls(**kwargs)
shapes = grid.get_shapes()
shapes = grid.get_shapes(epoch=0)
params = cls.values_type(
**{
k: torch.normal(mean=torch.zeros(n_grids, *shape), std=0.0001)
@ -570,18 +575,18 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
test(
FullResolutionVoxelGrid,
resolution=(4, 6, 9),
resolution_changes={0: (4, 6, 9)},
n_features=10,
)
test(
CPFactorizedVoxelGrid,
resolution=(4, 6, 9),
resolution_changes={0: (4, 6, 9)},
n_features=10,
n_components=3,
)
test(
VMFactorizedVoxelGrid,
resolution=(4, 6, 9),
resolution_changes={0: (4, 6, 9)},
n_features=10,
n_components=3,
)
@ -609,3 +614,105 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0],
rtol=0.0001,
)
def test_resolution_change(self, n_times=10):
    """
    Checks that `change_resolution` produces tensors of the shapes reported by
    `get_shapes` for the new epoch, for every voxel grid class, and that
    `VoxelGridModule` applies all subscribed resolution changes.
    """
    for _ in range(n_times):
        n_grids, n_features, n_components = torch.randint(1, 3, (3,)).tolist()
        resolution = torch.randint(3, 10, (3,)).tolist()
        resolution2 = torch.randint(3, 10, (3,)).tolist()
        resolution_changes = {0: resolution, 1: resolution2}
        n_components *= 3
        for cls, kwargs in (
            (
                FullResolutionVoxelGrid,
                {
                    "n_features": n_features,
                    "resolution_changes": resolution_changes,
                },
            ),
            (
                CPFactorizedVoxelGrid,
                {
                    "n_features": n_features,
                    "resolution_changes": resolution_changes,
                    "n_components": n_components,
                },
            ),
            (
                VMFactorizedVoxelGrid,
                {
                    "n_features": n_features,
                    "resolution_changes": resolution_changes,
                    "n_components": n_components,
                },
            ),
        ):
            with self.subTest(cls.__name__):
                grid = cls(**kwargs)
                self.assertEqual(grid.get_resolution(epoch=0), resolution)
                shapes = grid.get_shapes(epoch=0)
                params = {
                    name: torch.randn((n_grids, *shape))
                    for name, shape in shapes.items()
                }
                grid_values = grid.values_type(**params)
                grid_values_changed_resolution, change = grid.change_resolution(
                    epoch=1,
                    grid_values=grid_values,
                    mode="linear",
                )
                # assertTrue instead of a bare assert, which `python -O` strips
                self.assertTrue(change)
                self.assertEqual(grid.get_resolution(epoch=1), resolution2)
                shapes_changed_resolution = grid.get_shapes(epoch=1)
                for name, expected_shape in shapes_changed_resolution.items():
                    shape = getattr(grid_values_changed_resolution, name).shape
                    # shape[0] is the n_grids batch dimension added above
                    self.assertEqual(expected_shape, shape[1:])

    with self.subTest("VoxelGridModule"):
        n_changes = 10
        grid = VoxelGridModule()
        resolution_changes = {i: (i + 2, i + 2, i + 2) for i in range(n_changes)}
        grid.voxel_grid = FullResolutionVoxelGrid(
            resolution_changes=resolution_changes
        )
        epochs, apply_func = grid.subscribe_to_epochs()
        self.assertEqual(list(range(n_changes)), list(epochs))
        for epoch in epochs:
            change = apply_func(epoch)
            self.assertTrue(change)
            self.assertEqual(
                resolution_changes[epoch],
                grid.voxel_grid.get_resolution(epoch=epoch),
            )
def test_loading_state_dict(self):
    """
    Test loading state dict after rescaling.

    Create a voxel grid, rescale it and get the state_dict.
    Create a new voxel grid with the same args as the first one and load
    the state_dict and check that the loaded parameters match the originals.
    """
    n_changes = 10
    resolution_changes = {i: (i + 2, i + 2, i + 2) for i in range(n_changes)}
    cfg = DictConfig(
        {
            "voxel_grid_class_type": "VMFactorizedVoxelGrid",
            "voxel_grid_VMFactorizedVoxelGrid_args": {
                "resolution_changes": resolution_changes,
                "n_components": 48,
            },
        }
    )
    grid = VoxelGridModule(**cfg)
    epochs, apply_func = grid.subscribe_to_epochs()
    for epoch in epochs:
        apply_func(epoch)

    loaded_grid = VoxelGridModule(**cfg)
    loaded_grid.load_state_dict(grid.state_dict())

    original_params = dict(grid.named_parameters())
    loaded_params = dict(loaded_grid.named_parameters())
    # Both modules must expose the same parameters, otherwise the value
    # comparison below would pass vacuously.
    self.assertEqual(set(original_params), set(loaded_params))
    for name, param in original_params.items():
        # The comparison result has to be asserted; calling torch.allclose
        # on its own checks nothing.
        self.assertTrue(torch.allclose(loaded_params[name], param))