Add rescaling to voxel grids

Summary: Any module can be subscribed to step updates from the training loop. Once the training loop publishes a step the voxel grid changes its dimensions. During the construction of VoxelGridModule and its parameters it does not know which is the resolution that will be loaded from checkpoint, so before the checkpoint loading a hook runs which changes the VoxelGridModule's parameters to match shapes of the loaded checkpoint.

Reviewed By: bottler

Differential Revision: D39026775

fbshipit-source-id: 0d359ea5c8d2eda11d773d79c7513c83585d5f17
This commit is contained in:
Darijan Gudelj
2022-09-28 05:23:22 -07:00
committed by Facebook GitHub Bot
parent efea540bbc
commit 5005f09118
2 changed files with 333 additions and 51 deletions

View File

@@ -9,6 +9,7 @@ import unittest
from typing import Optional, Tuple
import torch
from omegaconf import DictConfig
from pytorch3d.implicitron.models.implicit_function.utils import (
interpolate_line,
@@ -62,11 +63,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
# be of shape (n_grids, n_points, n_features) and be filled with n_components
# * value
grid = CPFactorizedVoxelGrid(
resolution=resolution,
resolution_changes={0: resolution},
n_components=n_components,
n_features=n_features,
)
shapes = grid.get_shapes()
shapes = grid.get_shapes(epoch=0)
params = grid.values_type(
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
@@ -91,11 +92,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
# set everything to 'value' and do query for elements
grid = VMFactorizedVoxelGrid(
n_features=n_features,
resolution=resolution,
resolution_changes={0: resolution},
n_components=n_components,
distribution_of_components=distribution,
)
shapes = grid.get_shapes()
shapes = grid.get_shapes(epoch=0)
params = grid.values_type(
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
)
@@ -118,8 +119,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
n_points: int = 1,
) -> None:
# set everything to 'value' and do query for elements
grid = FullResolutionVoxelGrid(n_features=n_features, resolution=resolution)
shapes = grid.get_shapes()
grid = FullResolutionVoxelGrid(
n_features=n_features, resolution_changes={0: resolution}
)
shapes = grid.get_shapes(epoch=0)
params = grid.values_type(
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
)
@@ -329,8 +332,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
test querying the voxel grids on some float positions
"""
with self.subTest("FullResolution"):
grid = FullResolutionVoxelGrid(n_features=1, resolution=(1, 1, 1))
params = grid.values_type(**grid.get_shapes())
grid = FullResolutionVoxelGrid(
n_features=1, resolution_changes={0: (1, 1, 1)}
)
params = grid.values_type(**grid.get_shapes(epoch=0))
params.voxel_grid = torch.tensor(
[
[
@@ -377,9 +382,9 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
), grid.evaluate_local(points, params)
with self.subTest("CP"):
grid = CPFactorizedVoxelGrid(
n_features=1, resolution=(1, 1, 1), n_components=3
n_features=1, resolution_changes={0: (1, 1, 1)}, n_components=3
)
params = grid.values_type(**grid.get_shapes())
params = grid.values_type(**grid.get_shapes(epoch=0))
params.vector_components_x = torch.tensor(
[
[[1, 2], [10.5, 20.5]],
@@ -453,9 +458,9 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
with self.subTest("VM"):
grid = VMFactorizedVoxelGrid(
n_features=1, resolution=(1, 1, 1), n_components=3
n_features=1, resolution_changes={0: (1, 1, 1)}, n_components=3
)
params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes())
params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes(epoch=0))
params.matrix_components_xy = torch.tensor(
[
[[[1, 2], [3, 4]], [[19, 20], [21, 22.0]]],
@@ -555,7 +560,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
with self.subTest(cls.__name__):
n_grids = 3
grid = cls(**kwargs)
shapes = grid.get_shapes()
shapes = grid.get_shapes(epoch=0)
params = cls.values_type(
**{
k: torch.normal(mean=torch.zeros(n_grids, *shape), std=0.0001)
@@ -570,18 +575,18 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
test(
FullResolutionVoxelGrid,
resolution=(4, 6, 9),
resolution_changes={0: (4, 6, 9)},
n_features=10,
)
test(
CPFactorizedVoxelGrid,
resolution=(4, 6, 9),
resolution_changes={0: (4, 6, 9)},
n_features=10,
n_components=3,
)
test(
VMFactorizedVoxelGrid,
resolution=(4, 6, 9),
resolution_changes={0: (4, 6, 9)},
n_features=10,
n_components=3,
)
@@ -609,3 +614,105 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0],
rtol=0.0001,
)
def test_resolution_change(self, n_times=10):
for _ in range(n_times):
n_grids, n_features, n_components = torch.randint(1, 3, (3,)).tolist()
resolution = torch.randint(3, 10, (3,)).tolist()
resolution2 = torch.randint(3, 10, (3,)).tolist()
resolution_changes = {0: resolution, 1: resolution2}
n_components *= 3
for cls, kwargs in (
(
FullResolutionVoxelGrid,
{
"n_features": n_features,
"resolution_changes": resolution_changes,
},
),
(
CPFactorizedVoxelGrid,
{
"n_features": n_features,
"resolution_changes": resolution_changes,
"n_components": n_components,
},
),
(
VMFactorizedVoxelGrid,
{
"n_features": n_features,
"resolution_changes": resolution_changes,
"n_components": n_components,
},
),
):
with self.subTest(cls.__name__):
grid = cls(**kwargs)
self.assertEqual(grid.get_resolution(epoch=0), resolution)
shapes = grid.get_shapes(epoch=0)
params = {
name: torch.randn((n_grids, *shape))
for name, shape in shapes.items()
}
grid_values = grid.values_type(**params)
grid_values_changed_resolution, change = grid.change_resolution(
epoch=1,
grid_values=grid_values,
mode="linear",
)
assert change
self.assertEqual(grid.get_resolution(epoch=1), resolution2)
shapes_changed_resolution = grid.get_shapes(epoch=1)
for name, expected_shape in shapes_changed_resolution.items():
shape = getattr(grid_values_changed_resolution, name).shape
self.assertEqual(expected_shape, shape[1:])
with self.subTest("VoxelGridModule"):
n_changes = 10
grid = VoxelGridModule()
resolution_changes = {i: (i + 2, i + 2, i + 2) for i in range(n_changes)}
grid.voxel_grid = FullResolutionVoxelGrid(
resolution_changes=resolution_changes
)
epochs, apply_func = grid.subscribe_to_epochs()
self.assertEqual(list(range(n_changes)), list(epochs))
for epoch in epochs:
change = apply_func(epoch)
assert change
self.assertEqual(
resolution_changes[epoch],
grid.voxel_grid.get_resolution(epoch=epoch),
)
def test_loading_state_dict(self):
"""
Test loading state dict after rescaling.
Create a voxel grid, rescale it and get the state_dict.
Create a new voxel grid with the same args as the first one and load
the state_dict and check if everything is ok.
"""
n_changes = 10
resolution_changes = {i: (i + 2, i + 2, i + 2) for i in range(n_changes)}
cfg = DictConfig(
{
"voxel_grid_class_type": "VMFactorizedVoxelGrid",
"voxel_grid_VMFactorizedVoxelGrid_args": {
"resolution_changes": resolution_changes,
"n_components": 48,
},
}
)
grid = VoxelGridModule(**cfg)
epochs, apply_func = grid.subscribe_to_epochs()
for epoch in epochs:
apply_func(epoch)
loaded_grid = VoxelGridModule(**cfg)
loaded_grid.load_state_dict(grid.state_dict())
for name_loaded, param_loaded in loaded_grid.named_parameters():
for name, param in grid.named_parameters():
if name_loaded == name:
torch.allclose(param_loaded, param)