mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Summary: Converts the directory specified to use the Ruff formatter in pyfmt ruff_dog If this diff causes merge conflicts when rebasing, please run `hg status -n -0 --change . -I '**/*.{py,pyi}' | xargs -0 arc pyfmt` on your diff, and amend any changes before rebasing onto latest. That should help reduce or eliminate any merge conflicts. allow-large-files Reviewed By: bottler Differential Revision: D66472063 fbshipit-source-id: 35841cb397e4f8e066e2159550d2f56b403b1bef
860 lines
31 KiB
Python
860 lines
31 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import unittest
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
from omegaconf import DictConfig, OmegaConf
|
|
|
|
from pytorch3d.implicitron.models.implicit_function.utils import (
|
|
interpolate_line,
|
|
interpolate_plane,
|
|
interpolate_volume,
|
|
)
|
|
from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
|
|
CPFactorizedVoxelGrid,
|
|
FullResolutionVoxelGrid,
|
|
VMFactorizedVoxelGrid,
|
|
VoxelGridModule,
|
|
)
|
|
|
|
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
|
from tests.common_testing import TestCaseMixin
|
|
|
|
|
|
class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
|
"""
|
|
Tests voxel grids by setting all elements to zero (after which queries
should also return zero) and by setting all elements to one and checking
the queried result. Also tests the interpolation by 'manually'
interpolating sample by sample and comparing with the batched implementation.
|
|
"""
|
|
|
|
def get_random_normalized_points(
    self, n_grids, n_points=None, dimension=3
) -> torch.Tensor:
    """
    Sample random query points with coordinates in [-1, 1).

    Args:
        n_grids: size of the leading dimension of the result.
        n_points: if given, the result has shape (n_grids, n_points, dimension).
            If None, between 1 and 4 middle dimensions, each of size 1 to 3,
            are drawn at random.
        dimension: size of the last (coordinate) dimension.

    Returns:
        Tensor of shape (n_grids, *middle, dimension).
    """
    # The random middle shape is always drawn, even when `n_points` is given,
    # so the amount of randomness consumed does not depend on the arguments.
    n_middle_dims = int(torch.randint(1, 5, size=(1,)))
    random_middle = torch.randint(1, 4, (n_middle_dims,)).tolist()
    middle = random_middle if n_points is None else [n_points]
    # torch.rand samples from [0, 1); rescale to [-1, 1)
    unit_samples = torch.rand(n_grids, *middle, dimension)
    return unit_samples * 2 - 1
|
|
|
|
def _test_query_with_constant_init_cp(
    self,
    n_grids: int,
    n_features: int,
    n_components: int,
    resolution: Tuple[int],
    value: float = 1,
) -> None:
    """
    Fill every parameter of a CPFactorizedVoxelGrid with `value` and query it
    at random points. The result is asserted to have shape
    (n_grids, *points.shape[1:-1], n_features) and to be filled with
    n_components * value.
    """
    voxel_grid = CPFactorizedVoxelGrid(
        resolution_changes={0: resolution},
        n_components=n_components,
        n_features=n_features,
    )
    constant_params = {
        name: torch.ones(n_grids, *shape) * value
        for name, shape in voxel_grid.get_shapes(epoch=0).items()
    }
    grid_values = voxel_grid.values_type(**constant_params)

    query = self.get_random_normalized_points(n_grids)
    actual = voxel_grid.evaluate_local(query, grid_values)
    expected = (
        torch.ones(n_grids, *query.shape[1:-1], n_features) * n_components * value
    )
    assert torch.allclose(actual, expected, rtol=0.0001)
|
|
|
|
def _test_query_with_constant_init_vm(
    self,
    n_grids: int,
    n_features: int,
    resolution: Tuple[int],
    n_components: Optional[int] = None,
    distribution: Optional[Tuple[int]] = None,
    value: float = 1,
    n_points: int = 1,
) -> None:
    """
    Fill every parameter of a VMFactorizedVoxelGrid with `value` and query it
    at random points. Every returned element is asserted to equal
    n_components * value, or sum(distribution) * value when `distribution`
    is given.

    NOTE: `n_points` is accepted but not used by this helper.
    """
    voxel_grid = VMFactorizedVoxelGrid(
        n_features=n_features,
        resolution_changes={0: resolution},
        n_components=n_components,
        distribution_of_components=distribution,
    )
    constant_params = {
        name: torch.ones(n_grids, *shape) * value
        for name, shape in voxel_grid.get_shapes(epoch=0).items()
    }
    grid_values = voxel_grid.values_type(**constant_params)

    if distribution is None:
        expected_element = n_components * value
    else:
        expected_element = sum(distribution) * value

    query = self.get_random_normalized_points(n_grids)
    actual = voxel_grid.evaluate_local(query, grid_values)
    expected = torch.ones(n_grids, *query.shape[1:-1], n_features) * expected_element
    assert torch.allclose(actual, expected)
|
|
|
|
def _test_query_with_constant_init_full(
    self,
    n_grids: int,
    n_features: int,
    resolution: Tuple[int],
    value: int = 1,
    n_points: int = 1,
) -> None:
    """
    Fill every parameter of a FullResolutionVoxelGrid with `value` and query
    it at random points. Every returned element is asserted to equal `value`.

    NOTE: `n_points` is accepted but not used by this helper.
    """
    voxel_grid = FullResolutionVoxelGrid(
        n_features=n_features, resolution_changes={0: resolution}
    )
    constant_params = {
        name: torch.ones(n_grids, *shape) * value
        for name, shape in voxel_grid.get_shapes(epoch=0).items()
    }
    grid_values = voxel_grid.values_type(**constant_params)

    query = self.get_random_normalized_points(n_grids)
    actual = voxel_grid.evaluate_local(query, grid_values)
    expected = torch.ones(n_grids, *query.shape[1:-1], n_features) * value
    assert torch.allclose(actual, expected)
|
|
|
|
def test_query_with_constant_init(self):
    """
    Query each voxel grid type after filling all of its parameters with the
    default constant (1), including resolutions that contain 1.
    """
    # (subtest label, checking helper, keyword arguments for the helper)
    cases = (
        (
            "Full",
            self._test_query_with_constant_init_full,
            {"n_grids": 5, "n_features": 6, "resolution": (3, 4, 5)},
        ),
        (
            "Full with 1 in dimensions",
            self._test_query_with_constant_init_full,
            {"n_grids": 5, "n_features": 1, "resolution": (33, 41, 1)},
        ),
        (
            "CP",
            self._test_query_with_constant_init_cp,
            {
                "n_grids": 5,
                "n_features": 6,
                "n_components": 7,
                "resolution": (3, 4, 5),
            },
        ),
        (
            "CP with 1 in dimensions",
            self._test_query_with_constant_init_cp,
            {
                "n_grids": 2,
                "n_features": 1,
                "n_components": 3,
                "resolution": (3, 1, 1),
            },
        ),
        (
            "VM with symetric distribution",
            self._test_query_with_constant_init_vm,
            {
                "n_grids": 6,
                "n_features": 9,
                "resolution": (2, 12, 2),
                "n_components": 12,
            },
        ),
        (
            "VM with distribution",
            self._test_query_with_constant_init_vm,
            {
                "n_grids": 5,
                "n_features": 1,
                "resolution": (5, 9, 7),
                "distribution": (33, 41, 1),
            },
        ),
    )
    for label, check, check_kwargs in cases:
        with self.subTest(label):
            check(**check_kwargs)
|
|
|
|
def test_query_with_zero_init(self):
    """
    Query each voxel grid type after filling all of its parameters with zero.
    """
    # (subtest label, checking helper, keyword arguments for the helper)
    cases = (
        (
            "Query testing with zero init CPFactorizedVoxelGrid",
            self._test_query_with_constant_init_cp,
            {
                "n_grids": 5,
                "n_features": 6,
                "n_components": 7,
                "resolution": (3, 2, 5),
                "value": 0,
            },
        ),
        (
            "Query testing with zero init VMFactorizedVoxelGrid",
            self._test_query_with_constant_init_vm,
            {
                "n_grids": 2,
                "n_features": 9,
                "resolution": (2, 11, 3),
                "n_components": 3,
                "value": 0,
            },
        ),
        (
            "Query testing with zero init FullResolutionVoxelGrid",
            self._test_query_with_constant_init_full,
            {"n_grids": 4, "n_features": 2, "resolution": (3, 3, 5), "value": 0},
        ),
    )
    for label, check, check_kwargs in cases:
        with self.subTest(label):
            check(**check_kwargs)
|
|
|
|
def setUp(self):
    """
    Seed torch's RNG and call expand_args_fields on every voxel grid class
    used by the tests.
    """
    torch.manual_seed(42)
    for configurable in (
        FullResolutionVoxelGrid,
        CPFactorizedVoxelGrid,
        VMFactorizedVoxelGrid,
        VoxelGridModule,
    ):
        expand_args_fields(configurable)
|
|
|
|
def _interpolate_1D(
|
|
self, points: torch.Tensor, vectors: torch.Tensor
|
|
) -> torch.Tensor:
|
|
"""
|
|
interpolate vector from points, which are (batch, 1) and individual point is in [-1, 1]
|
|
"""
|
|
result = []
|
|
_, _, width = vectors.shape
|
|
# transform from [-1, 1] to [0, width-1]
|
|
points = (points + 1) / 2 * (width - 1)
|
|
for vector, row in zip(vectors, points):
|
|
newrow = []
|
|
for x in row:
|
|
xf, xc = int(torch.floor(x)), int(torch.ceil(x))
|
|
itemf, itemc = vector[:, xf], vector[:, xc]
|
|
tmp = itemf * (xc - x) + itemc * (x - xf)
|
|
newrow.append(tmp[None, None, :])
|
|
result.append(torch.cat(newrow, dim=1))
|
|
return torch.cat(result)
|
|
|
|
def _interpolate_2D(
|
|
self, points: torch.Tensor, matrices: torch.Tensor
|
|
) -> torch.Tensor:
|
|
"""
|
|
interpolate matrix from points, which are (batch, 2) and individual point is in [-1, 1]
|
|
"""
|
|
result = []
|
|
n_grids, _, width, height = matrices.shape
|
|
points = (points + 1) / 2 * (torch.tensor([[[width, height]]]) - 1)
|
|
for matrix, row in zip(matrices, points):
|
|
newrow = []
|
|
for x, y in row:
|
|
xf, xc = int(torch.floor(x)), int(torch.ceil(x))
|
|
yf, yc = int(torch.floor(y)), int(torch.ceil(y))
|
|
itemff, itemfc = matrix[:, xf, yf], matrix[:, xf, yc]
|
|
itemcf, itemcc = matrix[:, xc, yf], matrix[:, xc, yc]
|
|
itemf = itemff * (xc - x) + itemcf * (x - xf)
|
|
itemc = itemfc * (xc - x) + itemcc * (x - xf)
|
|
tmp = itemf * (yc - y) + itemc * (y - yf)
|
|
newrow.append(tmp[None, None, :])
|
|
result.append(torch.cat(newrow, dim=1))
|
|
return torch.cat(result)
|
|
|
|
def _interpolate_3D(
    self, points: torch.Tensor, tensors: torch.Tensor
) -> torch.Tensor:
    """
    Reference (slow, point by point) trilinear interpolation of volumes,
    composed from `_interpolate_2D` (on the two z-slices bracketing the
    point) and `_interpolate_1D` (blending those two results along z).

    Args:
        points: tensor of shape (n_grids, n_points, 3) holding (x, y, z);
            every coordinate must lie in [-1, 1].
        tensors: tensor of shape (n_grids, n_features, width, height, depth).

    Returns:
        Tensor of shape (n_grids, n_points, n_features).
    """
    _, _, width, height, depth = tensors.shape
    # transform from [-1, 1] to index space; only the z index is used here,
    # x and y are rescaled inside _interpolate_2D
    batch_normalized_points = (
        (points + 1) / 2 * (torch.tensor([[[width, height, depth]]]) - 1)
    )

    result = []
    for tensor, grid_points, grid_normalized in zip(
        tensors, points, batch_normalized_points
    ):
        newrow = []
        for point, normalized in zip(grid_points, grid_normalized):
            nz = normalized[2]
            zf, zc = int(torch.floor(nz)), int(torch.ceil(nz))
            xy = point[None, None, :2]
            itemf = self._interpolate_2D(
                points=xy, matrices=tensor[None, :, :, :, zf]
            )
            itemc = self._interpolate_2D(
                points=xy, matrices=tensor[None, :, :, :, zc]
            )
            # Position of the point between the two bracketing slices,
            # expressed in the [-1, 1] convention expected by
            # _interpolate_1D. Passing the raw z coordinate here is only
            # correct when depth == 2.
            local_z = (nz - zf) * 2 - 1
            tmp = self._interpolate_1D(
                points=local_z.reshape(1, 1, 1),
                vectors=torch.cat((itemf, itemc), dim=1).permute(0, 2, 1),
            )
            newrow.append(tmp)
        result.append(torch.cat(newrow, dim=1))
    return torch.cat(result)
|
|
|
|
def test_interpolation(self):
    """
    Compare the point-by-point reference interpolations of this class with
    the batched interpolate_line / interpolate_plane / interpolate_volume.
    """
    # (subtest label, point dimension, data shape, reference fn, batched fn)
    cases = (
        ("1D interpolation", 1, (4, 3, 2), self._interpolate_1D, interpolate_line),
        (
            "2D interpolation",
            2,
            (4, 2, 3, 5),
            self._interpolate_2D,
            interpolate_plane,
        ),
        (
            "3D interpolation",
            3,
            (4, 5, 2, 7, 2),
            self._interpolate_3D,
            interpolate_volume,
        ),
    )
    for label, dimension, data_shape, reference, batched in cases:
        with self.subTest(label):
            # points are drawn before the data, in the same order per case
            points = self.get_random_normalized_points(
                n_grids=4, n_points=5, dimension=dimension
            )
            data = torch.randn(size=data_shape)
            from_reference = reference(points, data)
            from_batched = batched(
                points,
                data,
                align_corners=True,
                padding_mode="zeros",
                mode="bilinear",
            )
            assert torch.allclose(
                from_reference, from_batched, rtol=0.0001, atol=0.0001
            )
|
|
|
|
def test_floating_point_query(self):
    """
    Query each voxel grid type at hand-picked non-integer positions and
    compare against precomputed expected values.
    """
    tolerance = {"rtol": 0.0001, "atol": 0.0001}

    with self.subTest("FullResolution"):
        grid = FullResolutionVoxelGrid(
            n_features=1, resolution_changes={0: (1, 1, 1)}
        )
        params = grid.values_type(**grid.get_shapes(epoch=0))
        params.voxel_grid = torch.tensor(
            [
                [
                    [[[1, 3], [5, 7]], [[9, 11], [13, 15]]],
                    [[[2, 4], [6, 8]], [[10, 12], [14, 16]]],
                ],
                [
                    [[[17, 18], [19, 20]], [[21, 22], [23, 24]]],
                    [[[25, 26], [27, 28]], [[29, 30], [31, 32]]],
                ],
            ],
            dtype=torch.float,
        )
        # positions given in index units; dividing by the scale and mapping
        # through * 2 - 1 converts them to local [-1, 1] coordinates
        index_positions = torch.tensor(
            [
                [
                    [1, 0, 1],
                    [0.5, 1, 1],
                    [1 / 3, 1 / 3, 2 / 3],
                ],
                [
                    [0, 1, 1],
                    [0, 0.5, 1],
                    [1 / 4, 1 / 4, 3 / 4],
                ],
            ]
        )
        points = index_positions / torch.tensor([[1.0, 1, 1]]) * 2 - 1
        expected_result = torch.tensor(
            [
                [[11, 12], [11, 12], [6.333333, 7.3333333]],
                [[20, 28], [19, 27], [19.25, 27.25]],
            ]
        )

        assert torch.allclose(
            grid.evaluate_local(points, params), expected_result, **tolerance
        ), grid.evaluate_local(points, params)

    with self.subTest("CP"):
        grid = CPFactorizedVoxelGrid(
            n_features=1, resolution_changes={0: (1, 1, 1)}, n_components=3
        )
        params = grid.values_type(**grid.get_shapes(epoch=0))
        params.vector_components_x = torch.tensor(
            [
                [[1, 2], [10.5, 20.5]],
                [[10, 20], [2, 4]],
            ]
        )
        params.vector_components_y = torch.tensor(
            [
                [[3, 4, 5], [30.5, 40.5, 50.5]],
                [[30, 40, 50], [1, 3, 5]],
            ]
        )
        params.vector_components_z = torch.tensor(
            [
                [[6, 7, 8, 9], [60.5, 70.5, 80.5, 90.5]],
                [[60, 70, 80, 90], [6, 7, 8, 9]],
            ]
        )
        params.basis_matrix = torch.tensor(
            [
                [[2.0], [2.0]],
                [[1.0], [2.0]],
            ]
        )
        index_positions = torch.tensor(
            [
                [
                    [0, 2, 2],
                    [1, 2, 0.25],
                    [0.5, 0.5, 1],
                    [1 / 3, 2 / 3, 2 + 1 / 3],
                ],
                [
                    [1, 0, 1],
                    [0.5, 2, 2],
                    [1, 0.5, 0.5],
                    [1 / 4, 3 / 4, 2 + 1 / 4],
                ],
            ]
        )
        points = index_positions / torch.tensor([[[1.0, 2, 3]]]) * 2 - 1
        expected_result_matrix = torch.tensor(
            [
                [[85450.25], [130566.5], [77658.75], [86285.422]],
                [[42056], [60240], [45604], [38775]],
            ]
        )
        expected_result_sum = torch.tensor(
            [
                [[42725.125], [65283.25], [38829.375], [43142.711]],
                [[42028], [60120], [45552], [38723.4375]],
            ]
        )
        with self.subTest("CP with basis_matrix reduction"):
            assert torch.allclose(
                grid.evaluate_local(points, params),
                expected_result_matrix,
                **tolerance,
            )
        # with the basis matrix removed the second expectation applies
        del params.basis_matrix
        with self.subTest("CP with sum reduction"):
            assert torch.allclose(
                grid.evaluate_local(points, params),
                expected_result_sum,
                **tolerance,
            )

    with self.subTest("VM"):
        grid = VMFactorizedVoxelGrid(
            n_features=1, resolution_changes={0: (1, 1, 1)}, n_components=3
        )
        params = VMFactorizedVoxelGrid.values_type(**grid.get_shapes(epoch=0))
        params.matrix_components_xy = torch.tensor(
            [
                [[[1, 2], [3, 4]], [[19, 20], [21, 22.0]]],
                [[[35, 36], [37, 38]], [[39, 40], [41, 42]]],
            ]
        )
        params.matrix_components_xz = torch.tensor(
            [
                [[[7, 8], [9, 10]], [[25, 26], [27, 28.0]]],
                [[[43, 44], [45, 46]], [[47, 48], [49, 50]]],
            ]
        )
        params.matrix_components_yz = torch.tensor(
            [
                [[[13, 14], [15, 16]], [[31, 32], [33, 34.0]]],
                [[[51, 52], [53, 54]], [[55, 56], [57, 58.0]]],
            ]
        )

        params.vector_components_z = torch.tensor(
            [
                [[5, 6], [23, 24.0]],
                [[59, 60], [61, 62]],
            ]
        )
        params.vector_components_y = torch.tensor(
            [
                [[11, 12], [29, 30.0]],
                [[63, 64], [65, 66]],
            ]
        )
        params.vector_components_x = torch.tensor(
            [
                [[17, 18], [35, 36.0]],
                [[67, 68], [69, 70.0]],
            ]
        )

        params.basis_matrix = torch.tensor(
            [
                [2, 2, 2, 2, 2, 2.0],
                [1, 2, 1, 2, 1, 2.0],
            ]
        )[:, :, None]
        index_positions = torch.tensor(
            [
                [
                    [1, 0, 1],
                    [0.5, 1, 1],
                    [1 / 3, 1 / 3, 2 / 3],
                ],
                [
                    [0, 1, 0],
                    [0, 0, 0],
                    [0, 1, 0],
                ],
            ]
        )
        points = index_positions / torch.tensor([[[1.0, 1, 1]]]) * 2 - 1
        expected_result_matrix = torch.tensor(
            [
                [[5696], [5854], [5484.888]],
                [[27377], [26649], [27377]],
            ]
        )
        expected_result_sum = torch.tensor(
            [
                [[2848], [2927], [2742.444]],
                [[17902], [17420], [17902]],
            ]
        )
        with self.subTest("VM with basis_matrix reduction"):
            assert torch.allclose(
                grid.evaluate_local(points, params),
                expected_result_matrix,
                **tolerance,
            )
        # with the basis matrix removed the second expectation applies
        del params.basis_matrix
        with self.subTest("VM with sum reduction"):
            assert torch.allclose(
                grid.evaluate_local(points, params),
                expected_result_sum,
                **tolerance,
            ), grid.evaluate_local(points, params)
|
|
|
|
def test_forward_with_small_init_std(self):
    """
    Test that the grid returns values of small magnitude if its parameters
    are initialized with zero mean and a small standard deviation.
    """
    n_grids = 3
    # Bound on the magnitude of every returned element. A scalar is used so
    # the comparison does not rely on the output shape.
    max_abs_expected = 1e-2

    def check(cls, **kwargs):
        with self.subTest(cls.__name__):
            grid = cls(**kwargs)
            shapes = grid.get_shapes(epoch=0)
            params = cls.values_type(
                **{
                    k: torch.normal(mean=torch.zeros(n_grids, *shape), std=0.0001)
                    for k, shape in shapes.items()
                }
            )
            points = self.get_random_normalized_points(n_grids=n_grids, n_points=3)
            values = grid.evaluate_local(points, params)
            # Compare absolute values: checking only `values < bound` would
            # let arbitrarily large negative outputs pass.
            assert torch.all(values.abs() < max_abs_expected), values.abs().max()

    check(
        FullResolutionVoxelGrid,
        resolution_changes={0: (4, 6, 9)},
        n_features=10,
    )
    check(
        CPFactorizedVoxelGrid,
        resolution_changes={0: (4, 6, 9)},
        n_features=10,
        n_components=3,
    )
    check(
        VMFactorizedVoxelGrid,
        resolution_changes={0: (4, 6, 9)},
        n_features=10,
        n_components=3,
    )
|
|
|
|
def test_voxel_grid_module_location(self, n_times=10):
    """
    Check that VoxelGridModule converts world coordinates to local ones
    consistently (e.g. without permuting dimensions).

    A random local point is evaluated directly through the module's voxel
    grid, and the matching world point (local * extents / 2) is evaluated
    through the module itself; both results must agree.
    """
    for _ in range(n_times):
        extents = tuple(torch.randint(1, 50, size=(3,)).tolist())

        module = VoxelGridModule(extents=extents)
        local_point = torch.rand(1, 3) * 2 - 1
        world_point = local_point * torch.tensor(extents) / 2
        grid_values = module.voxel_grid.values_type(**module.params)

        through_module = module(world_point)[0, 0]
        through_grid = module.voxel_grid.evaluate_local(
            local_point[None], grid_values
        )[0, 0, 0]
        assert torch.allclose(
            through_module, through_grid, rtol=0.0001, atol=0.0001
        )
|
|
|
|
def test_resolution_change(self, n_times=10):
    """
    Check that change_resolution yields values whose shapes match
    get_shapes for the new epoch, for each voxel grid type, and that
    VoxelGridModule applies the scheduled resolution changes epoch by epoch.
    """
    for _ in range(n_times):
        n_grids, n_features, n_components = torch.randint(1, 3, (3,)).tolist()
        resolution = torch.randint(3, 10, (3,)).tolist()
        resolution2 = torch.randint(3, 10, (3,)).tolist()
        resolution_changes = {0: resolution, 1: resolution2}
        n_components *= 3

        dense_kwargs = {
            "n_features": n_features,
            "resolution_changes": resolution_changes,
        }
        variants = (
            (FullResolutionVoxelGrid, dense_kwargs),
            (CPFactorizedVoxelGrid, {**dense_kwargs, "n_components": n_components}),
            (VMFactorizedVoxelGrid, {**dense_kwargs, "n_components": n_components}),
        )
        for cls, kwargs in variants:
            with self.subTest(cls.__name__):
                grid = cls(**kwargs)
                self.assertEqual(grid.get_resolution(epoch=0), resolution)
                random_params = {
                    name: torch.randn((n_grids, *shape))
                    for name, shape in grid.get_shapes(epoch=0).items()
                }
                grid_values = grid.values_type(**random_params)
                rescaled_values, change = grid.change_resolution(
                    epoch=1,
                    grid_values=grid_values,
                    mode="linear",
                )
                assert change
                self.assertEqual(grid.get_resolution(epoch=1), resolution2)
                for name, expected_shape in grid.get_shapes(epoch=1).items():
                    shape = getattr(rescaled_values, name).shape
                    # the leading dimension is n_grids
                    self.assertEqual(expected_shape, shape[1:])

    with self.subTest("VoxelGridModule"):
        n_changes = 10
        grid = VoxelGridModule()
        resolution_changes = {i: (i + 2, i + 2, i + 2) for i in range(n_changes)}
        grid.voxel_grid = FullResolutionVoxelGrid(
            resolution_changes=resolution_changes
        )
        epochs, apply_func = grid.subscribe_to_epochs()
        self.assertEqual(list(range(n_changes)), list(epochs))
        for epoch in epochs:
            change = apply_func(epoch)
            assert change
            self.assertEqual(
                resolution_changes[epoch],
                grid.voxel_grid.get_resolution(epoch=epoch),
            )
|
|
|
|
def _get_min_max_tuple(
|
|
self, n=4, denominator_base=2, max_exponent=6, add_edge_cases=True
|
|
):
|
|
if add_edge_cases:
|
|
n -= 2
|
|
|
|
def get_pair():
|
|
def get_one():
|
|
sign = -1 if torch.rand((1,)) < 0.5 else 1
|
|
exponent = int(torch.randint(1, max_exponent, (1,)))
|
|
denominator = denominator_base**exponent
|
|
numerator = int(torch.randint(1, denominator, (1,)))
|
|
return sign * numerator / denominator * 1.0
|
|
|
|
while True:
|
|
a, b = get_one(), get_one()
|
|
if a < b:
|
|
return a, b
|
|
|
|
for _ in range(n):
|
|
a, b, c = get_pair(), get_pair(), get_pair()
|
|
yield torch.tensor((a[0], b[0], c[0])), torch.tensor((a[1], b[1], c[1]))
|
|
if add_edge_cases:
|
|
yield torch.tensor((-1.0, -1.0, -1.0)), torch.tensor((1.0, 1.0, 1.0))
|
|
yield torch.tensor([0.0, 0.0, 0.0]), torch.tensor([1.0, 1.0, 1.0])
|
|
|
|
def test_cropping_voxel_grids(self, n_times=1):
    """
    Check that cropping a voxel grid preserves the values inside the crop.

    If the grid is 1d and we crop at A and B
        ---------A---------B---
    and choose a point p between them
        ---------A-----p---B---
    then p can be written as
        p = A + (B-A) * p_c
    where p_c is the local coordinate of p in the cropped grid. The test
    checks that the grid evaluated at p agrees with the cropped grid
    evaluated at p_c.
    """
    # The generator is consumed lazily on purpose: it draws from the same
    # RNG as the loop body.
    for points_min, points_max in self._get_min_max_tuple(n=10):
        _, n_features, n_components = torch.randint(1, 3, (3,)).tolist()
        n_grids = 1
        n_components *= 3
        resolution_changes = {0: (128 + 1, 128 + 1, 128 + 1)}

        dense_kwargs = {
            "n_features": n_features,
            "resolution_changes": resolution_changes,
        }
        variants = (
            (FullResolutionVoxelGrid, dense_kwargs),
            (CPFactorizedVoxelGrid, {**dense_kwargs, "n_components": n_components}),
            (VMFactorizedVoxelGrid, {**dense_kwargs, "n_components": n_components}),
        )
        for cls, kwargs in variants:
            with self.subTest(
                f"{cls.__name__} points {points_min} and {points_max}"
            ):
                grid = cls(**kwargs)
                random_params = {
                    name: torch.normal(
                        mean=torch.zeros((n_grids, *shape)),
                        std=1,
                    )
                    for name, shape in grid.get_shapes(epoch=0).items()
                }
                grid_values = grid.values_type(**random_params)

                grid_values_cropped = grid.crop_local(
                    points_min, points_max, grid_values
                )

                # p_c in [0, 1) first, used to place p inside the crop box
                unit_coords = torch.rand((1, n_times, 3))
                points_local = (
                    points_min[None, None]
                    + (points_max - points_min)[None, None] * unit_coords
                )
                # then rescaled to the [-1, 1) range of the cropped grid
                points_local_cropped = (unit_coords - 0.5) * 2

                pred = grid.evaluate_local(points_local, grid_values)
                pred_cropped = grid.evaluate_local(
                    points_local_cropped, grid_values_cropped
                )

                assert torch.allclose(pred, pred_cropped, rtol=1e-4, atol=1e-4), (
                    pred,
                    pred_cropped,
                    points_local,
                    points_local_cropped,
                )
|
|
|
|
def test_cropping_voxel_grid_module(self, n_times=1):
    """
    Check that VoxelGridModule.crop_self leaves the value at a point inside
    the crop box (approximately) unchanged.

    NOTE: `n_times` is accepted but not used by this test.
    """
    # The generator is consumed lazily on purpose: it draws from the same
    # RNG as the loop body.
    for points_min, points_max in self._get_min_max_tuple(n=5, max_exponent=5):
        extents = torch.ones((3,)) * 2
        translation = torch.ones((3,)) * 0.2
        # shift the crop box by the translation configured for the module
        points_min += translation
        points_max += translation

        overrides = DictConfig(
            {
                "extents": tuple(float(e) for e in extents),
                "translation": tuple(float(t) for t in translation),
                "voxel_grid_FullResolutionVoxelGrid_args": {
                    "resolution_changes": {0: (128 + 1, 128 + 1, 128 + 1)}
                },
            }
        )
        cfg = OmegaConf.merge(get_default_args(VoxelGridModule), overrides)
        module = VoxelGridModule(**cfg)

        box_size = points_max - points_min
        points = (torch.rand(3) * box_size + points_min)[None]
        result = module(points)
        module.crop_self(points_min, points_max)
        result_cropped = module(points)

        assert torch.allclose(result, result_cropped, rtol=0.001, atol=0.001), (
            result,
            result_cropped,
        )
|
|
|
|
def test_loading_state_dict(self):
    """
    Test loading state dict after rescaling.

    Create a voxel grid, rescale it and get the state_dict.
    Create a new voxel grid with the same args as the first one, load
    the state_dict and check that parameters with the same name are equal.
    """
    n_changes = 10

    resolution_changes = {i: (i + 2, i + 2, i + 2) for i in range(n_changes)}
    cfg = DictConfig(
        {
            "voxel_grid_class_type": "VMFactorizedVoxelGrid",
            "voxel_grid_VMFactorizedVoxelGrid_args": {
                "resolution_changes": resolution_changes,
                "n_components": 48,
            },
        }
    )
    grid = VoxelGridModule(**cfg)
    epochs, apply_func = grid.subscribe_to_epochs()
    for epoch in epochs:
        apply_func(epoch)

    loaded_grid = VoxelGridModule(**cfg)
    loaded_grid.load_state_dict(grid.state_dict())

    # Index the source parameters by name instead of scanning them once per
    # loaded parameter.
    source_params = dict(grid.named_parameters())
    for name, param_loaded in loaded_grid.named_parameters():
        if name not in source_params:
            continue
        param = source_params[name]
        self.assertEqual(param_loaded.shape, param.shape, name)
        # The comparison result must be asserted; a bare torch.allclose(...)
        # call verifies nothing.
        assert torch.allclose(param_loaded, param), name
|