mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Add rescaling to voxel grids
Summary: Any module can be subscribed to step updates from the training loop. Once the training loop publishes a step the voxel grid changes its dimensions. During the construction of VoxelGridModule and its parameters it does not know which is the resolution that will be loaded from checkpoint, so before the checkpoint loading a hook runs which changes the VoxelGridModule's parameters to match shapes of the loaded checkpoint. Reviewed By: bottler Differential Revision: D39026775 fbshipit-source-id: 0d359ea5c8d2eda11d773d79c7513c83585d5f17
This commit is contained in:
parent
efea540bbc
commit
5005f09118
@ -15,8 +15,8 @@ these classes.
|
||||
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Dict, Optional, Tuple, Type
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
@ -58,17 +58,22 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
will actually be trilinear.
|
||||
n_features: number of dimensions of base feature vector. Determines how many features
|
||||
the grid returns.
|
||||
resolution: 3-tuple containing x, y, z grid sizes corresponding to each axis.
|
||||
resolution_changes: a dictionary, where keys are change epochs and values are
|
||||
3-tuples containing x, y, z grid sizes corresponding to each axis to each epoch
|
||||
"""
|
||||
|
||||
align_corners: bool = True
|
||||
padding: str = "zeros"
|
||||
mode: str = "bilinear"
|
||||
n_features: int = 1
|
||||
resolution: Tuple[int, int, int] = (128, 128, 128)
|
||||
resolution_changes: Dict[int, List[int]] = field(
|
||||
default_factory=lambda: {0: [128, 128, 128]}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
if 0 not in self.resolution_changes:
|
||||
raise ValueError("There has to be key `0` in `resolution_changes`.")
|
||||
|
||||
def evaluate_world(
|
||||
self,
|
||||
@ -109,11 +114,13 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_shapes(self) -> Dict[str, Tuple]:
|
||||
def get_shapes(self, epoch: int) -> Dict[str, Tuple]:
|
||||
"""
|
||||
Using parameters from the __init__ method, this method returns the
|
||||
shapes of individual tensors needed to run the evaluate method.
|
||||
|
||||
Args:
|
||||
epoch: If the shape varies during training, which training epoch's shape to return.
|
||||
Returns:
|
||||
a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods
|
||||
replace the shapes in the dictionary with tensors of those shapes and add the
|
||||
@ -123,6 +130,21 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_resolution(self, epoch: int) -> List[int]:
|
||||
"""
|
||||
Returns the resolution which the grid should have at specific epoch
|
||||
|
||||
Args:
|
||||
epoch which to use in the resolution calculation
|
||||
Returns:
|
||||
resolution at specific epoch
|
||||
"""
|
||||
last_change = 0
|
||||
for change_epoch in self.resolution_changes:
|
||||
if change_epoch <= epoch:
|
||||
last_change = max(last_change, change_epoch)
|
||||
return self.resolution_changes[last_change]
|
||||
|
||||
@staticmethod
|
||||
def get_output_dim(args: DictConfig) -> int:
|
||||
"""
|
||||
@ -140,6 +162,75 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
"""
|
||||
return args["n_features"]
|
||||
|
||||
def change_resolution(
|
||||
self,
|
||||
epoch: int,
|
||||
grid_values: VoxelGridValuesBase,
|
||||
mode: str = "linear",
|
||||
align_corners: bool = True,
|
||||
antialias: bool = False,
|
||||
) -> Tuple[VoxelGridValuesBase, bool]:
|
||||
"""
|
||||
Changes resolution of tensors in `grid_values` to match the `wanted_resolution`.
|
||||
|
||||
Args:
|
||||
epoch: current training epoch, used to see if the grid needs regridding
|
||||
grid_values: instance of self.values_type which contains
|
||||
the voxel grid which will be interpolated to create the new grid
|
||||
wanted_resolution: tuple of (x, y, z) resolutions which determine
|
||||
new grid's resolution
|
||||
align_corners: as for torch.nn.functional.interpolate
|
||||
mode: as for torch.nn.functional.interpolate
|
||||
'nearest' | 'bicubic' | 'linear' | 'area' | 'nearest-exact'.
|
||||
Default: 'linear'
|
||||
antialias: as for torch.nn.functional.interpolate.
|
||||
Using anti-alias option
|
||||
together with align_corners=False and mode='bicubic', interpolation
|
||||
result would match Pillow result for downsampling operation.
|
||||
Supported mode: 'bicubic'
|
||||
Returns:
|
||||
tuple of
|
||||
- new voxel grid_values of desired resolution, of type self.values_type
|
||||
- True if regridding has happened.
|
||||
"""
|
||||
if epoch not in self.resolution_changes:
|
||||
return grid_values, False
|
||||
|
||||
if mode not in ("nearest", "bicubic", "linear", "area", "nearest-exact"):
|
||||
raise ValueError(
|
||||
"`mode` should be one of the following 'nearest'"
|
||||
+ "| 'bicubic' | 'linear' | 'area' | 'nearest-exact'"
|
||||
)
|
||||
|
||||
def change_individual_resolution(tensor, wanted_resolution):
|
||||
if mode == "linear":
|
||||
n_dim = len(wanted_resolution)
|
||||
new_mode = ("linear", "bilinear", "trilinear")[n_dim - 1]
|
||||
else:
|
||||
new_mode = mode
|
||||
return torch.nn.functional.interpolate(
|
||||
input=tensor,
|
||||
size=wanted_resolution,
|
||||
mode=new_mode,
|
||||
align_corners=align_corners,
|
||||
antialias=antialias,
|
||||
recompute_scale_factor=False,
|
||||
)
|
||||
|
||||
wanted_shapes = self.get_shapes(epoch=epoch)
|
||||
params = {
|
||||
name: change_individual_resolution(getattr(grid_values, name), shape[1:])
|
||||
for name, shape in wanted_shapes.items()
|
||||
}
|
||||
# pyre-ignore[29]
|
||||
return self.values_type(**params), True
|
||||
|
||||
def get_resolution_change_epochs(self) -> List[int]:
|
||||
"""
|
||||
Returns epochs at which this grid should change epochs.
|
||||
"""
|
||||
return list(self.resolution_changes.keys())
|
||||
|
||||
|
||||
@dataclass
|
||||
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
|
||||
@ -185,8 +276,9 @@ class FullResolutionVoxelGrid(VoxelGridBase):
|
||||
)
|
||||
return interpolated.view(*recorded_shape[:-1], -1)
|
||||
|
||||
def get_shapes(self) -> Dict[str, Tuple]:
|
||||
return {"voxel_grid": (self.n_features, *self.resolution)}
|
||||
def get_shapes(self, epoch: int) -> Dict[str, Tuple]:
|
||||
width, height, depth = self.get_resolution(epoch)
|
||||
return {"voxel_grid": (self.n_features, width, height, depth)}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -212,7 +304,7 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
|
||||
|
||||
Each element of this sum has an extra dimension, which gets matrix-multiplied by an
|
||||
appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication
|
||||
brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements
|
||||
brings us to the desired "n_features" dimensionality. If basis_matrix=False the elements
|
||||
of different components are summed together to create (n_grids, n_components, 1) tensor.
|
||||
With some notation abuse, ignoring the interpolation operation, simplifying and denoting
|
||||
n_features as F, n_components as C and n_grids as G:
|
||||
@ -223,7 +315,7 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
|
||||
|
||||
Members:
|
||||
n_components: number of vector triplets, higher number gives better approximation.
|
||||
matrix_reduction: how to transform components. If matrix_reduction=True result
|
||||
basis_matrix: how to transform components. If matrix_reduction=True result
|
||||
matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
|
||||
by the basis_matrix of shape (n_grids, n_components, n_features). If
|
||||
matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
|
||||
@ -235,7 +327,7 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
|
||||
values_type: ClassVar[Type[VoxelGridValuesBase]] = CPFactorizedVoxelGridValues
|
||||
|
||||
n_components: int = 24
|
||||
matrix_reduction: bool = True
|
||||
basis_matrix: bool = True
|
||||
|
||||
# pyre-fixme[14]: `evaluate_local` overrides method defined in `VoxelGridBase`
|
||||
# inconsistently.
|
||||
@ -274,16 +366,17 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
|
||||
# (n_grids, ..., n_features)
|
||||
return result.view(*recorded_shape[:-1], -1)
|
||||
|
||||
def get_shapes(self) -> Dict[str, Tuple[int, int]]:
|
||||
if self.matrix_reduction is False and self.n_features != 1:
|
||||
raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")
|
||||
def get_shapes(self, epoch: int) -> Dict[str, Tuple[int, int]]:
|
||||
if self.basis_matrix is False and self.n_features != 1:
|
||||
raise ValueError("Cannot set basis_matrix=False and n_features to != 1")
|
||||
|
||||
width, height, depth = self.get_resolution(epoch=epoch)
|
||||
shape_dict = {
|
||||
"vector_components_x": (self.n_components, self.resolution[0]),
|
||||
"vector_components_y": (self.n_components, self.resolution[1]),
|
||||
"vector_components_z": (self.n_components, self.resolution[2]),
|
||||
"vector_components_x": (self.n_components, width),
|
||||
"vector_components_y": (self.n_components, height),
|
||||
"vector_components_z": (self.n_components, depth),
|
||||
}
|
||||
if self.matrix_reduction:
|
||||
if self.basis_matrix:
|
||||
shape_dict["basis_matrix"] = (self.n_components, self.n_features)
|
||||
return shape_dict
|
||||
|
||||
@ -321,7 +414,7 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
||||
|
||||
Each element of this sum has an extra dimension, which gets matrix-multiplied by an
|
||||
appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication
|
||||
brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements
|
||||
brings us to the desired "n_features" dimensionality. If basis_matrix=False the elements
|
||||
of different components are summed together to create (n_grids, n_components, 1) tensor.
|
||||
With some notation abuse, ignoring the interpolation operation, simplifying and denoting
|
||||
n_features as F, n_components as C (which can differ for each dimension) and n_grids as G:
|
||||
@ -338,7 +431,7 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
||||
all 3 directions specify a tuple of numbers of matrix_vector pairs for each
|
||||
coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
|
||||
either n_components or distribution_of_components, you cannot specify both.
|
||||
matrix_reduction: how to transform components. If matrix_reduction=True result
|
||||
basis_matrix: how to transform components. If matrix_reduction=True result
|
||||
matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
|
||||
by the basis_matrix of shape (n_grids, n_components, n_features). If
|
||||
matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
|
||||
@ -351,7 +444,7 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
||||
|
||||
n_components: Optional[int] = None
|
||||
distribution_of_components: Optional[Tuple[int, int, int]] = None
|
||||
matrix_reduction: bool = True
|
||||
basis_matrix: bool = True
|
||||
|
||||
# pyre-fixme[14]: `evaluate_local` overrides method defined in `VoxelGridBase`
|
||||
# inconsistently.
|
||||
@ -419,9 +512,9 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
||||
# (n_grids, ..., n_features)
|
||||
return result.view(*recorded_shape[:-1], -1)
|
||||
|
||||
def get_shapes(self) -> Dict[str, Tuple]:
|
||||
if self.matrix_reduction is False and self.n_features != 1:
|
||||
raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")
|
||||
def get_shapes(self, epoch: int) -> Dict[str, Tuple]:
|
||||
if self.basis_matrix is False and self.n_features != 1:
|
||||
raise ValueError("Cannot set basis_matrix=False and n_features to != 1")
|
||||
if self.distribution_of_components is None and self.n_components is None:
|
||||
raise ValueError(
|
||||
"You need to provide n_components or distribution_of_components"
|
||||
@ -446,36 +539,37 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
||||
else:
|
||||
calculated_distribution_of_components = self.distribution_of_components
|
||||
|
||||
width, height, depth = self.get_resolution(epoch=epoch)
|
||||
shape_dict = {
|
||||
"vector_components_x": (
|
||||
calculated_distribution_of_components[1],
|
||||
self.resolution[0],
|
||||
width,
|
||||
),
|
||||
"vector_components_y": (
|
||||
calculated_distribution_of_components[2],
|
||||
self.resolution[1],
|
||||
height,
|
||||
),
|
||||
"vector_components_z": (
|
||||
calculated_distribution_of_components[0],
|
||||
self.resolution[2],
|
||||
depth,
|
||||
),
|
||||
"matrix_components_xy": (
|
||||
calculated_distribution_of_components[0],
|
||||
self.resolution[0],
|
||||
self.resolution[1],
|
||||
width,
|
||||
height,
|
||||
),
|
||||
"matrix_components_yz": (
|
||||
calculated_distribution_of_components[1],
|
||||
self.resolution[1],
|
||||
self.resolution[2],
|
||||
height,
|
||||
depth,
|
||||
),
|
||||
"matrix_components_xz": (
|
||||
calculated_distribution_of_components[2],
|
||||
self.resolution[0],
|
||||
self.resolution[2],
|
||||
width,
|
||||
depth,
|
||||
),
|
||||
}
|
||||
if self.matrix_reduction:
|
||||
if self.basis_matrix:
|
||||
shape_dict["basis_matrix"] = (
|
||||
sum(calculated_distribution_of_components),
|
||||
self.n_features,
|
||||
@ -517,7 +611,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
super().__init__()
|
||||
run_auto_creation(self)
|
||||
n_grids = 1 # Voxel grid objects are batched. We need only a single grid.
|
||||
shapes = self.voxel_grid.get_shapes()
|
||||
shapes = self.voxel_grid.get_shapes(epoch=0)
|
||||
params = {
|
||||
name: torch.normal(
|
||||
mean=torch.zeros((n_grids, *shape)) + self.init_mean,
|
||||
@ -526,6 +620,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
for name, shape in shapes.items()
|
||||
}
|
||||
self.params = torch.nn.ParameterDict(params)
|
||||
self._register_load_state_dict_pre_hook(self._create_parameters_with_new_size)
|
||||
|
||||
def forward(self, points: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
@ -554,7 +649,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
voxel_size=tuple(self.extents),
|
||||
volume_translation=tuple(self.translation),
|
||||
# pyre-ignore[29]
|
||||
device=next(self.params.values()).device,
|
||||
device=next(val for val in self.params.values() if val is not None).device,
|
||||
)
|
||||
# pyre-fixme[29]: `Union[torch._tensor.Tensor,
|
||||
# torch.nn.modules.module.Module]` is not a function.
|
||||
@ -576,3 +671,83 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
return grid.get_output_dim(
|
||||
args["voxel_grid_" + args["voxel_grid_class_type"] + "_args"]
|
||||
)
|
||||
|
||||
def subscribe_to_epochs(self) -> Tuple[List[int], Callable[[int], bool]]:
|
||||
"""
|
||||
Method which expresses interest in subscribing to optimization epoch updates.
|
||||
|
||||
Returns:
|
||||
list of epochs on which to call a callable and callable to be called on
|
||||
particular epoch. The callable returns True if parameter change has
|
||||
happened else False and it must be supplied with one argument, epoch.
|
||||
"""
|
||||
return self.voxel_grid.get_resolution_change_epochs(), self._apply_epochs
|
||||
|
||||
def _apply_epochs(self, epoch: int) -> bool:
|
||||
"""
|
||||
Asks voxel_grid to change the resolution.
|
||||
This method is returned with subscribe_to_epochs and is the method that collects
|
||||
updates on training epochs, it is run on the training epochs that are requested.
|
||||
|
||||
Args:
|
||||
epoch: current training epoch used for voxel grids to know to which
|
||||
resolution to change
|
||||
Returns:
|
||||
True if parameter change has happened else False.
|
||||
"""
|
||||
# pyre-ignore[29]
|
||||
grid_values = self.voxel_grid.values_type(**self.params)
|
||||
grid_values, change = self.voxel_grid.change_resolution(epoch, grid_values)
|
||||
if change:
|
||||
# pyre-ignore[16]
|
||||
self.params = torch.nn.ParameterDict(
|
||||
{name: tensor for name, tensor in vars(grid_values).items()}
|
||||
)
|
||||
return change
|
||||
|
||||
def _create_parameters_with_new_size(
|
||||
self,
|
||||
state_dict: dict,
|
||||
prefix: str,
|
||||
local_metadata: dict,
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
'''
|
||||
Automatically ran before loading the parameters with `load_state_dict()`.
|
||||
Creates new parameters with the sizes of the ones in the loaded state dict.
|
||||
This is necessary because the parameters are changing throughout training and
|
||||
at the time of construction `VoxelGridModule` does not know the size of
|
||||
parameters which will be loaded.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
local_metadata (dict): a dict containing the metadata for this module.
|
||||
See
|
||||
strict (bool): whether to strictly enforce that the keys in
|
||||
:attr:`state_dict` with :attr:`prefix` match the names of
|
||||
parameters and buffers in this module
|
||||
missing_keys (list of str): if ``strict=True``, add missing keys to
|
||||
this list
|
||||
unexpected_keys (list of str): if ``strict=True``, add unexpected
|
||||
keys to this list
|
||||
error_msgs (list of str): error messages should be added to this
|
||||
list, and will be reported together in
|
||||
:meth:`~torch.nn.Module.load_state_dict`
|
||||
Returns:
|
||||
nothing
|
||||
"""
|
||||
'''
|
||||
new_params = {}
|
||||
# pyre-ignore[29]
|
||||
for name in self.params:
|
||||
key = prefix + "params." + name
|
||||
if key in state_dict:
|
||||
new_params[name] = torch.zeros_like(state_dict[key])
|
||||
# pyre-ignore[16]
|
||||
self.params = torch.nn.ParameterDict(new_params)
|
||||
|
@ -9,6 +9,7 @@ import unittest
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from pytorch3d.implicitron.models.implicit_function.utils import (
|
||||
interpolate_line,
|
||||
@ -62,11 +63,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
# be of shape (n_grids, n_points, n_features) and be filled with n_components
|
||||
# * value
|
||||
grid = CPFactorizedVoxelGrid(
|
||||
resolution=resolution,
|
||||
resolution_changes={0: resolution},
|
||||
n_components=n_components,
|
||||
n_features=n_features,
|
||||
)
|
||||
shapes = grid.get_shapes()
|
||||
shapes = grid.get_shapes(epoch=0)
|
||||
|
||||
params = grid.values_type(
|
||||
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
|
||||
@ -91,11 +92,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
# set everything to 'value' and do query for elements
|
||||
grid = VMFactorizedVoxelGrid(
|
||||
n_features=n_features,
|
||||
resolution=resolution,
|
||||
resolution_changes={0: resolution},
|
||||
n_components=n_components,
|
||||
distribution_of_components=distribution,
|
||||
)
|
||||
shapes = grid.get_shapes()
|
||||
shapes = grid.get_shapes(epoch=0)
|
||||
params = grid.values_type(
|
||||
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
|
||||
)
|
||||
@ -118,8 +119,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_points: int = 1,
|
||||
) -> None:
|
||||
# set everything to 'value' and do query for elements
|
||||
grid = FullResolutionVoxelGrid(n_features=n_features, resolution=resolution)
|
||||
shapes = grid.get_shapes()
|
||||
grid = FullResolutionVoxelGrid(
|
||||
n_features=n_features, resolution_changes={0: resolution}
|
||||
)
|
||||
shapes = grid.get_shapes(epoch=0)
|
||||
params = grid.values_type(
|
||||
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
|
||||
)
|
||||
@ -329,8 +332,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
test querying the voxel grids on some float positions
|
||||
"""
|
||||
with self.subTest("FullResolution"):
|
||||
grid = FullResolutionVoxelGrid(n_features=1, resolution=(1, 1, 1))
|
||||
params = grid.values_type(**grid.get_shapes())
|
||||
grid = FullResolutionVoxelGrid(
|
||||
n_features=1, resolution_changes={0: (1, 1, 1)}
|
||||
)
|
||||
params = grid.values_type(**grid.get_shapes(epoch=0))
|
||||
params.voxel_grid = torch.tensor(
|
||||
[
|
||||
[
|
||||
@ -377,9 +382,9 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
), grid.evaluate_local(points, params)
|
||||
with self.subTest("CP"):
|
||||
grid = CPFactorizedVoxelGrid(
|
||||
n_features=1, resolution=(1, 1, 1), n_components=3
|
||||
n_features=1, resolution_changes={0: (1, 1, 1)}, n_components=3
|
||||
)
|
||||
params = grid.values_type(**grid.get_shapes())
|
||||
params = grid.values_type(**grid.get_shapes(epoch=0))
|
||||
params.vector_components_x = torch.tensor(
|
||||
[
|
||||
[[1, 2], [10.5, 20.5]],
|
||||
@ -453,9 +458,9 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
with self.subTest("VM"):
|
||||
grid = VMFactorizedVoxelGrid(
|
||||
n_features=1, resolution=(1, 1, 1), n_components=3
|
||||
n_features=1, resolution_changes={0: (1, 1, 1)}, n_components=3
|
||||
)
|
||||
params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes())
|
||||
params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes(epoch=0))
|
||||
params.matrix_components_xy = torch.tensor(
|
||||
[
|
||||
[[[1, 2], [3, 4]], [[19, 20], [21, 22.0]]],
|
||||
@ -555,7 +560,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
with self.subTest(cls.__name__):
|
||||
n_grids = 3
|
||||
grid = cls(**kwargs)
|
||||
shapes = grid.get_shapes()
|
||||
shapes = grid.get_shapes(epoch=0)
|
||||
params = cls.values_type(
|
||||
**{
|
||||
k: torch.normal(mean=torch.zeros(n_grids, *shape), std=0.0001)
|
||||
@ -570,18 +575,18 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
test(
|
||||
FullResolutionVoxelGrid,
|
||||
resolution=(4, 6, 9),
|
||||
resolution_changes={0: (4, 6, 9)},
|
||||
n_features=10,
|
||||
)
|
||||
test(
|
||||
CPFactorizedVoxelGrid,
|
||||
resolution=(4, 6, 9),
|
||||
resolution_changes={0: (4, 6, 9)},
|
||||
n_features=10,
|
||||
n_components=3,
|
||||
)
|
||||
test(
|
||||
VMFactorizedVoxelGrid,
|
||||
resolution=(4, 6, 9),
|
||||
resolution_changes={0: (4, 6, 9)},
|
||||
n_features=10,
|
||||
n_components=3,
|
||||
)
|
||||
@ -609,3 +614,105 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0],
|
||||
rtol=0.0001,
|
||||
)
|
||||
|
||||
def test_resolution_change(self, n_times=10):
|
||||
for _ in range(n_times):
|
||||
n_grids, n_features, n_components = torch.randint(1, 3, (3,)).tolist()
|
||||
resolution = torch.randint(3, 10, (3,)).tolist()
|
||||
resolution2 = torch.randint(3, 10, (3,)).tolist()
|
||||
resolution_changes = {0: resolution, 1: resolution2}
|
||||
n_components *= 3
|
||||
for cls, kwargs in (
|
||||
(
|
||||
FullResolutionVoxelGrid,
|
||||
{
|
||||
"n_features": n_features,
|
||||
"resolution_changes": resolution_changes,
|
||||
},
|
||||
),
|
||||
(
|
||||
CPFactorizedVoxelGrid,
|
||||
{
|
||||
"n_features": n_features,
|
||||
"resolution_changes": resolution_changes,
|
||||
"n_components": n_components,
|
||||
},
|
||||
),
|
||||
(
|
||||
VMFactorizedVoxelGrid,
|
||||
{
|
||||
"n_features": n_features,
|
||||
"resolution_changes": resolution_changes,
|
||||
"n_components": n_components,
|
||||
},
|
||||
),
|
||||
):
|
||||
with self.subTest(cls.__name__):
|
||||
grid = cls(**kwargs)
|
||||
self.assertEqual(grid.get_resolution(epoch=0), resolution)
|
||||
shapes = grid.get_shapes(epoch=0)
|
||||
params = {
|
||||
name: torch.randn((n_grids, *shape))
|
||||
for name, shape in shapes.items()
|
||||
}
|
||||
grid_values = grid.values_type(**params)
|
||||
grid_values_changed_resolution, change = grid.change_resolution(
|
||||
epoch=1,
|
||||
grid_values=grid_values,
|
||||
mode="linear",
|
||||
)
|
||||
assert change
|
||||
self.assertEqual(grid.get_resolution(epoch=1), resolution2)
|
||||
shapes_changed_resolution = grid.get_shapes(epoch=1)
|
||||
for name, expected_shape in shapes_changed_resolution.items():
|
||||
shape = getattr(grid_values_changed_resolution, name).shape
|
||||
self.assertEqual(expected_shape, shape[1:])
|
||||
|
||||
with self.subTest("VoxelGridModule"):
|
||||
n_changes = 10
|
||||
grid = VoxelGridModule()
|
||||
resolution_changes = {i: (i + 2, i + 2, i + 2) for i in range(n_changes)}
|
||||
grid.voxel_grid = FullResolutionVoxelGrid(
|
||||
resolution_changes=resolution_changes
|
||||
)
|
||||
epochs, apply_func = grid.subscribe_to_epochs()
|
||||
self.assertEqual(list(range(n_changes)), list(epochs))
|
||||
for epoch in epochs:
|
||||
change = apply_func(epoch)
|
||||
assert change
|
||||
self.assertEqual(
|
||||
resolution_changes[epoch],
|
||||
grid.voxel_grid.get_resolution(epoch=epoch),
|
||||
)
|
||||
|
||||
def test_loading_state_dict(self):
|
||||
"""
|
||||
Test loading state dict after rescaling.
|
||||
|
||||
Create a voxel grid, rescale it and get the state_dict.
|
||||
Create a new voxel grid with the same args as the first one and load
|
||||
the state_dict and check if everything is ok.
|
||||
"""
|
||||
n_changes = 10
|
||||
|
||||
resolution_changes = {i: (i + 2, i + 2, i + 2) for i in range(n_changes)}
|
||||
cfg = DictConfig(
|
||||
{
|
||||
"voxel_grid_class_type": "VMFactorizedVoxelGrid",
|
||||
"voxel_grid_VMFactorizedVoxelGrid_args": {
|
||||
"resolution_changes": resolution_changes,
|
||||
"n_components": 48,
|
||||
},
|
||||
}
|
||||
)
|
||||
grid = VoxelGridModule(**cfg)
|
||||
epochs, apply_func = grid.subscribe_to_epochs()
|
||||
for epoch in epochs:
|
||||
apply_func(epoch)
|
||||
|
||||
loaded_grid = VoxelGridModule(**cfg)
|
||||
loaded_grid.load_state_dict(grid.state_dict())
|
||||
for name_loaded, param_loaded in loaded_grid.named_parameters():
|
||||
for name, param in grid.named_parameters():
|
||||
if name_loaded == name:
|
||||
torch.allclose(param_loaded, param)
|
||||
|
Loading…
x
Reference in New Issue
Block a user