mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-18 21:30:35 +08:00
Add rescaling to voxel grids
Summary: Any module can be subscribed to step updates from the training loop. Once the training loop publishes a step the voxel grid changes its dimensions. During the construction of VoxelGridModule and its parameters it does not know which is the resolution that will be loaded from checkpoint, so before the checkpoint loading a hook runs which changes the VoxelGridModule's parameters to match shapes of the loaded checkpoint. Reviewed By: bottler Differential Revision: D39026775 fbshipit-source-id: 0d359ea5c8d2eda11d773d79c7513c83585d5f17
This commit is contained in:
committed by
Facebook GitHub Bot
parent
efea540bbc
commit
5005f09118
@@ -9,6 +9,7 @@ import unittest
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from pytorch3d.implicitron.models.implicit_function.utils import (
|
||||
interpolate_line,
|
||||
@@ -62,11 +63,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
# be of shape (n_grids, n_points, n_features) and be filled with n_components
|
||||
# * value
|
||||
grid = CPFactorizedVoxelGrid(
|
||||
resolution=resolution,
|
||||
resolution_changes={0: resolution},
|
||||
n_components=n_components,
|
||||
n_features=n_features,
|
||||
)
|
||||
shapes = grid.get_shapes()
|
||||
shapes = grid.get_shapes(epoch=0)
|
||||
|
||||
params = grid.values_type(
|
||||
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
|
||||
@@ -91,11 +92,11 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
# set everything to 'value' and do query for elements
|
||||
grid = VMFactorizedVoxelGrid(
|
||||
n_features=n_features,
|
||||
resolution=resolution,
|
||||
resolution_changes={0: resolution},
|
||||
n_components=n_components,
|
||||
distribution_of_components=distribution,
|
||||
)
|
||||
shapes = grid.get_shapes()
|
||||
shapes = grid.get_shapes(epoch=0)
|
||||
params = grid.values_type(
|
||||
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
|
||||
)
|
||||
@@ -118,8 +119,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
n_points: int = 1,
|
||||
) -> None:
|
||||
# set everything to 'value' and do query for elements
|
||||
grid = FullResolutionVoxelGrid(n_features=n_features, resolution=resolution)
|
||||
shapes = grid.get_shapes()
|
||||
grid = FullResolutionVoxelGrid(
|
||||
n_features=n_features, resolution_changes={0: resolution}
|
||||
)
|
||||
shapes = grid.get_shapes(epoch=0)
|
||||
params = grid.values_type(
|
||||
**{k: torch.ones(n_grids, *shapes[k]) * value for k in shapes}
|
||||
)
|
||||
@@ -329,8 +332,10 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
test querying the voxel grids on some float positions
|
||||
"""
|
||||
with self.subTest("FullResolution"):
|
||||
grid = FullResolutionVoxelGrid(n_features=1, resolution=(1, 1, 1))
|
||||
params = grid.values_type(**grid.get_shapes())
|
||||
grid = FullResolutionVoxelGrid(
|
||||
n_features=1, resolution_changes={0: (1, 1, 1)}
|
||||
)
|
||||
params = grid.values_type(**grid.get_shapes(epoch=0))
|
||||
params.voxel_grid = torch.tensor(
|
||||
[
|
||||
[
|
||||
@@ -377,9 +382,9 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
), grid.evaluate_local(points, params)
|
||||
with self.subTest("CP"):
|
||||
grid = CPFactorizedVoxelGrid(
|
||||
n_features=1, resolution=(1, 1, 1), n_components=3
|
||||
n_features=1, resolution_changes={0: (1, 1, 1)}, n_components=3
|
||||
)
|
||||
params = grid.values_type(**grid.get_shapes())
|
||||
params = grid.values_type(**grid.get_shapes(epoch=0))
|
||||
params.vector_components_x = torch.tensor(
|
||||
[
|
||||
[[1, 2], [10.5, 20.5]],
|
||||
@@ -453,9 +458,9 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
with self.subTest("VM"):
|
||||
grid = VMFactorizedVoxelGrid(
|
||||
n_features=1, resolution=(1, 1, 1), n_components=3
|
||||
n_features=1, resolution_changes={0: (1, 1, 1)}, n_components=3
|
||||
)
|
||||
params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes())
|
||||
params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes(epoch=0))
|
||||
params.matrix_components_xy = torch.tensor(
|
||||
[
|
||||
[[[1, 2], [3, 4]], [[19, 20], [21, 22.0]]],
|
||||
@@ -555,7 +560,7 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
with self.subTest(cls.__name__):
|
||||
n_grids = 3
|
||||
grid = cls(**kwargs)
|
||||
shapes = grid.get_shapes()
|
||||
shapes = grid.get_shapes(epoch=0)
|
||||
params = cls.values_type(
|
||||
**{
|
||||
k: torch.normal(mean=torch.zeros(n_grids, *shape), std=0.0001)
|
||||
@@ -570,18 +575,18 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
test(
|
||||
FullResolutionVoxelGrid,
|
||||
resolution=(4, 6, 9),
|
||||
resolution_changes={0: (4, 6, 9)},
|
||||
n_features=10,
|
||||
)
|
||||
test(
|
||||
CPFactorizedVoxelGrid,
|
||||
resolution=(4, 6, 9),
|
||||
resolution_changes={0: (4, 6, 9)},
|
||||
n_features=10,
|
||||
n_components=3,
|
||||
)
|
||||
test(
|
||||
VMFactorizedVoxelGrid,
|
||||
resolution=(4, 6, 9),
|
||||
resolution_changes={0: (4, 6, 9)},
|
||||
n_features=10,
|
||||
n_components=3,
|
||||
)
|
||||
@@ -609,3 +614,105 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
grid.voxel_grid.evaluate_local(local_point[None], grid_values)[0, 0, 0],
|
||||
rtol=0.0001,
|
||||
)
|
||||
|
||||
def test_resolution_change(self, n_times=10):
|
||||
for _ in range(n_times):
|
||||
n_grids, n_features, n_components = torch.randint(1, 3, (3,)).tolist()
|
||||
resolution = torch.randint(3, 10, (3,)).tolist()
|
||||
resolution2 = torch.randint(3, 10, (3,)).tolist()
|
||||
resolution_changes = {0: resolution, 1: resolution2}
|
||||
n_components *= 3
|
||||
for cls, kwargs in (
|
||||
(
|
||||
FullResolutionVoxelGrid,
|
||||
{
|
||||
"n_features": n_features,
|
||||
"resolution_changes": resolution_changes,
|
||||
},
|
||||
),
|
||||
(
|
||||
CPFactorizedVoxelGrid,
|
||||
{
|
||||
"n_features": n_features,
|
||||
"resolution_changes": resolution_changes,
|
||||
"n_components": n_components,
|
||||
},
|
||||
),
|
||||
(
|
||||
VMFactorizedVoxelGrid,
|
||||
{
|
||||
"n_features": n_features,
|
||||
"resolution_changes": resolution_changes,
|
||||
"n_components": n_components,
|
||||
},
|
||||
),
|
||||
):
|
||||
with self.subTest(cls.__name__):
|
||||
grid = cls(**kwargs)
|
||||
self.assertEqual(grid.get_resolution(epoch=0), resolution)
|
||||
shapes = grid.get_shapes(epoch=0)
|
||||
params = {
|
||||
name: torch.randn((n_grids, *shape))
|
||||
for name, shape in shapes.items()
|
||||
}
|
||||
grid_values = grid.values_type(**params)
|
||||
grid_values_changed_resolution, change = grid.change_resolution(
|
||||
epoch=1,
|
||||
grid_values=grid_values,
|
||||
mode="linear",
|
||||
)
|
||||
assert change
|
||||
self.assertEqual(grid.get_resolution(epoch=1), resolution2)
|
||||
shapes_changed_resolution = grid.get_shapes(epoch=1)
|
||||
for name, expected_shape in shapes_changed_resolution.items():
|
||||
shape = getattr(grid_values_changed_resolution, name).shape
|
||||
self.assertEqual(expected_shape, shape[1:])
|
||||
|
||||
with self.subTest("VoxelGridModule"):
|
||||
n_changes = 10
|
||||
grid = VoxelGridModule()
|
||||
resolution_changes = {i: (i + 2, i + 2, i + 2) for i in range(n_changes)}
|
||||
grid.voxel_grid = FullResolutionVoxelGrid(
|
||||
resolution_changes=resolution_changes
|
||||
)
|
||||
epochs, apply_func = grid.subscribe_to_epochs()
|
||||
self.assertEqual(list(range(n_changes)), list(epochs))
|
||||
for epoch in epochs:
|
||||
change = apply_func(epoch)
|
||||
assert change
|
||||
self.assertEqual(
|
||||
resolution_changes[epoch],
|
||||
grid.voxel_grid.get_resolution(epoch=epoch),
|
||||
)
|
||||
|
||||
def test_loading_state_dict(self):
|
||||
"""
|
||||
Test loading state dict after rescaling.
|
||||
|
||||
Create a voxel grid, rescale it and get the state_dict.
|
||||
Create a new voxel grid with the same args as the first one and load
|
||||
the state_dict and check if everything is ok.
|
||||
"""
|
||||
n_changes = 10
|
||||
|
||||
resolution_changes = {i: (i + 2, i + 2, i + 2) for i in range(n_changes)}
|
||||
cfg = DictConfig(
|
||||
{
|
||||
"voxel_grid_class_type": "VMFactorizedVoxelGrid",
|
||||
"voxel_grid_VMFactorizedVoxelGrid_args": {
|
||||
"resolution_changes": resolution_changes,
|
||||
"n_components": 48,
|
||||
},
|
||||
}
|
||||
)
|
||||
grid = VoxelGridModule(**cfg)
|
||||
epochs, apply_func = grid.subscribe_to_epochs()
|
||||
for epoch in epochs:
|
||||
apply_func(epoch)
|
||||
|
||||
loaded_grid = VoxelGridModule(**cfg)
|
||||
loaded_grid.load_state_dict(grid.state_dict())
|
||||
for name_loaded, param_loaded in loaded_grid.named_parameters():
|
||||
for name, param in grid.named_parameters():
|
||||
if name_loaded == name:
|
||||
torch.allclose(param_loaded, param)
|
||||
|
||||
Reference in New Issue
Block a user