pytorch3d/tests/test_volumes.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import itertools
import random
import unittest
import numpy as np
import torch
from pytorch3d.structures.volumes import VolumeLocator, Volumes
from pytorch3d.transforms import Scale
from .common_testing import TestCaseMixin
class TestVolumes(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
np.random.seed(42)
torch.manual_seed(42)
random.seed(42)
@staticmethod
def _random_volume_list(
num_volumes, min_size, max_size, num_channels, device, rand_sizes=None
):
"""
        Init a list of `num_volumes` random tensors of size [num_channels, *rand_size].
        If `rand_sizes` is None, each rand_size is a 3D vector whose i-th entry is
        drawn uniformly from [min_size, max_size[i]]. Otherwise, `rand_sizes` should
        be a list [rand_size_1, rand_size_2, ..., rand_size_num_volumes] where each
        `rand_size_i` denotes the size of the corresponding i-th tensor.
"""
if rand_sizes is None:
rand_sizes = [
[random.randint(min_size, vs) for vs in max_size]
for _ in range(num_volumes)
]
volume_list = [
torch.randn(
size=[num_channels, *rand_size], device=device, dtype=torch.float32
)
for rand_size in rand_sizes
]
return volume_list, rand_sizes
def _check_indexed_volumes(self, v, selected, indices):
for selectedIdx, index in indices:
self.assertClose(selected.densities()[selectedIdx], v.densities()[index])
self.assertClose(
v.locator._local_to_world_transform.get_matrix()[index],
selected.locator._local_to_world_transform.get_matrix()[selectedIdx],
)
if selected.features() is not None:
self.assertClose(selected.features()[selectedIdx], v.features()[index])
def test_get_item(
self,
num_volumes=5,
num_channels=4,
volume_size=(10, 13, 8),
dtype=torch.float32,
):
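        # Exercise the supported __getitem__ index types on Volumes (int, list of
        # ints, slice, boolean mask tensor, int64 index tensor) and check that
        # floating point indices raise IndexError.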
device = torch.device("cuda:0")
# make sure we have at least 3 volumes to prevent indexing crash
num_volumes = max(num_volumes, 3)
features = torch.randn(
size=[num_volumes, num_channels, *volume_size],
device=device,
dtype=torch.float32,
)
densities = torch.randn(
size=[num_volumes, 1, *volume_size], device=device, dtype=torch.float32
)
features_list, rand_sizes = TestVolumes._random_volume_list(
num_volumes, 3, volume_size, num_channels, device
)
densities_list, _ = TestVolumes._random_volume_list(
num_volumes, 3, volume_size, 1, device, rand_sizes=rand_sizes
)
volume_translation = -torch.randn(num_volumes, 3).type_as(features)
voxel_size = torch.rand(num_volumes, 1).type_as(features) + 0.5
for features_, densities_ in zip(
(None, features, features_list), (densities, densities, densities_list)
):
# init the volume structure
v = Volumes(
features=features_,
densities=densities_,
volume_translation=volume_translation,
voxel_size=voxel_size,
)
# int index
index = 1
v_selected = v[index]
self.assertEqual(len(v_selected), 1)
self._check_indexed_volumes(v, v_selected, [(0, 1)])
# list index
index = [1, 2]
v_selected = v[index]
self.assertEqual(len(v_selected), len(index))
self._check_indexed_volumes(v, v_selected, enumerate(index))
# slice index
index = slice(0, 2, 1)
v_selected = v[0:2]
self.assertEqual(len(v_selected), 2)
self._check_indexed_volumes(v, v_selected, [(0, 0), (1, 1)])
# bool tensor
index = (torch.rand(num_volumes) > 0.5).to(device)
            index[:2] = True  # make sure something is selected
v_selected = v[index]
self.assertEqual(len(v_selected), index.sum())
self._check_indexed_volumes(
v,
v_selected,
zip(
torch.arange(index.sum()),
torch.nonzero(index, as_tuple=False).squeeze(),
),
)
# int tensor
index = torch.tensor([1, 2], dtype=torch.int64, device=device)
v_selected = v[index]
self.assertEqual(len(v_selected), index.numel())
self._check_indexed_volumes(v, v_selected, enumerate(index.tolist()))
# invalid index
index = torch.tensor([1, 0, 1], dtype=torch.float32, device=device)
with self.assertRaises(IndexError):
v_selected = v[index]
index = 1.2 # floating point index
with self.assertRaises(IndexError):
v_selected = v[index]
def test_locator_init(self, batch_size=9, resolution=(3, 5, 7)):
with self.subTest("VolumeLocator init with all sizes equal"):
grid_sizes = [resolution for _ in range(batch_size)]
locator_tuple = VolumeLocator(
batch_size=batch_size, grid_sizes=resolution, device=torch.device("cpu")
)
locator_list = VolumeLocator(
batch_size=batch_size, grid_sizes=grid_sizes, device=torch.device("cpu")
)
locator_tensor = VolumeLocator(
batch_size=batch_size,
grid_sizes=torch.tensor(grid_sizes),
device=torch.device("cpu"),
)
expected_grid_sizes = torch.tensor(grid_sizes)
expected_resolution = resolution
assert torch.allclose(expected_grid_sizes, locator_tuple._grid_sizes)
assert torch.allclose(expected_grid_sizes, locator_list._grid_sizes)
assert torch.allclose(expected_grid_sizes, locator_tensor._grid_sizes)
self.assertEqual(expected_resolution, locator_tuple._resolution)
self.assertEqual(expected_resolution, locator_list._resolution)
self.assertEqual(expected_resolution, locator_tensor._resolution)
with self.subTest("VolumeLocator with different sizes in different grids"):
grid_sizes_list = [
torch.randint(low=1, high=42, size=(3,)) for _ in range(batch_size)
]
grid_sizes_tensor = torch.cat([el[None] for el in grid_sizes_list])
locator_list = VolumeLocator(
batch_size=batch_size,
grid_sizes=grid_sizes_list,
device=torch.device("cpu"),
)
locator_tensor = VolumeLocator(
batch_size=batch_size,
grid_sizes=grid_sizes_tensor,
device=torch.device("cpu"),
)
expected_grid_sizes = grid_sizes_tensor
expected_resolution = tuple(torch.max(expected_grid_sizes, dim=0).values)
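            # The locator's resolution is expected to be the per-axis maximum over
            # all per-volume grid sizes, i.e. the smallest common shape into which
            # every grid in the batch fits (checked by the assertions below).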
assert torch.allclose(expected_grid_sizes, locator_list._grid_sizes)
assert torch.allclose(expected_grid_sizes, locator_tensor._grid_sizes)
self.assertEqual(expected_resolution, locator_list._resolution)
self.assertEqual(expected_resolution, locator_tensor._resolution)
def test_coord_transforms(self, num_volumes=3, num_channels=4, dtype=torch.float32):
"""
        Test that the internal Transform3d stored in
        Volumes.locator._local_to_world_transform agrees with a transform built
        directly from the volume_translation and voxel_size used at initialization.
"""
device = torch.device("cuda:0")
# try for 10 sets of different random sizes/centers/voxel_sizes
for _ in range(10):
size = torch.randint(high=10, size=(3,), low=3).tolist()
densities = torch.randn(
size=[num_volumes, num_channels, *size],
device=device,
dtype=torch.float32,
)
# init the transformation params
volume_translation = torch.randn(num_volumes, 3)
voxel_size = torch.rand(num_volumes, 3) * 3.0 + 0.5
# get the corresponding Transform3d object
local_offset = torch.tensor(list(size), dtype=torch.float32, device=device)[
[2, 1, 0]
][None].repeat(num_volumes, 1)
local_to_world_transform = (
Scale(0.5 * local_offset - 0.5, device=device)
.scale(voxel_size)
.translate(-volume_translation)
)
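            # Written out, this composition maps normalized local (xyz) coords in
            # [-1, 1] to world coords as (assuming the default align_corners=True):
            #   world_xyz = local_xyz * (W - 1, H - 1, D - 1) / 2 * voxel_size
            #               - volume_translation
            # where (D, H, W) is the spatial size of the volume.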
# init the volume structures with the scale and translation,
# then get the coord grid in world coords
v_trans_vs = Volumes(
densities=densities,
voxel_size=voxel_size,
volume_translation=volume_translation,
)
grid_rot_trans_vs = v_trans_vs.get_coord_grid(world_coordinates=True)
# map the default local coords to the world coords
# with local_to_world_transform
v_default = Volumes(densities=densities)
grid_default_local = v_default.get_coord_grid(world_coordinates=False)
grid_default_world = local_to_world_transform.transform_points(
grid_default_local.view(num_volumes, -1, 3)
).view(num_volumes, *size, 3)
# check that both grids are the same
self.assertClose(grid_rot_trans_vs, grid_default_world, atol=1e-5)
# check that the transformations are the same
self.assertClose(
v_trans_vs.get_local_to_world_coords_transform().get_matrix(),
local_to_world_transform.get_matrix(),
atol=1e-5,
)
def test_coord_grid_convention(
self, num_volumes=3, num_channels=4, dtype=torch.float32
):
"""
Check that for a trivial volume with spatial size DxHxW=5x7x5:
1) xyz_world=(0, 0, 0) lands right in the middle of the volume
with xyz_local=(0, 0, 0).
        2) xyz_world=(-2, 3, -2) results in xyz_local=(-1, 1, -1).
        3) The central voxel of the volume coordinate grid
           has coords xyz_world=(0, 0, 0) and xyz_local=(0, 0, 0).
        4) grid_sample(world_coordinate_grid, local_coordinate_grid)
           returns world_coordinate_grid itself, i.e. the local coordinate
           grid matches the grid_sample coordinate convention.
"""
device = torch.device("cuda:0")
densities = torch.randn(
size=[num_volumes, num_channels, 5, 7, 5],
device=device,
dtype=torch.float32,
)
v_trivial = Volumes(densities=densities)
# check the case with x_world=(0,0,0)
pts_world = torch.zeros(num_volumes, 1, 3, device=device, dtype=torch.float32)
pts_local = v_trivial.world_to_local_coords(pts_world)
pts_local_expected = torch.zeros_like(pts_local)
self.assertClose(pts_local, pts_local_expected)
# check the case with x_world=(-2, 3, -2)
pts_world = torch.tensor([-2, 3, -2], device=device, dtype=torch.float32)[
None, None
].repeat(num_volumes, 1, 1)
pts_local = v_trivial.world_to_local_coords(pts_world)
pts_local_expected = torch.tensor(
[-1, 1, -1], device=device, dtype=torch.float32
)[None, None].repeat(num_volumes, 1, 1)
self.assertClose(pts_local, pts_local_expected)
        # check that the central voxel has coords x_world=(0, 0, 0) and x_local=(0, 0, 0)
grid_world = v_trivial.get_coord_grid(world_coordinates=True)
grid_local = v_trivial.get_coord_grid(world_coordinates=False)
for grid in (grid_world, grid_local):
x0 = grid[0, :, :, 2, 0]
y0 = grid[0, :, 3, :, 1]
z0 = grid[0, 2, :, :, 2]
for coord_line in (x0, y0, z0):
self.assertClose(coord_line, torch.zeros_like(coord_line), atol=1e-7)
# resample grid_world using grid_sampler with local coords
# -> make sure the resampled version is the same as original
grid_world_resampled = torch.nn.functional.grid_sample(
grid_world.permute(0, 4, 1, 2, 3), grid_local, align_corners=True
).permute(0, 2, 3, 4, 1)
self.assertClose(grid_world_resampled, grid_world, atol=1e-7)
for align_corners in [True, False]:
v_trivial = Volumes(densities=densities, align_corners=align_corners)
# check the case with x_world=(0,0,0)
pts_world = torch.zeros(
num_volumes, 1, 3, device=device, dtype=torch.float32
)
pts_local = v_trivial.world_to_local_coords(pts_world)
pts_local_expected = torch.zeros_like(pts_local)
self.assertClose(pts_local, pts_local_expected)
# check the case with x_world=(-2, 3, -2)
pts_world_tuple = [-2, 3, -2]
pts_world = torch.tensor(
pts_world_tuple, device=device, dtype=torch.float32
)[None, None].repeat(num_volumes, 1, 1)
pts_local = v_trivial.world_to_local_coords(pts_world)
pts_local_expected = torch.tensor(
[-1, 1, -1], device=device, dtype=torch.float32
)[None, None].repeat(num_volumes, 1, 1)
self.assertClose(pts_local, pts_local_expected)
            # check that the central voxel has coords x_world=(0, 0, 0) and x_local=(0, 0, 0)
grid_world = v_trivial.get_coord_grid(world_coordinates=True)
grid_local = v_trivial.get_coord_grid(world_coordinates=False)
for grid in (grid_world, grid_local):
x0 = grid[0, :, :, 2, 0]
y0 = grid[0, :, 3, :, 1]
z0 = grid[0, 2, :, :, 2]
for coord_line in (x0, y0, z0):
self.assertClose(
coord_line, torch.zeros_like(coord_line), atol=1e-7
)
# resample grid_world using grid_sampler with local coords
# -> make sure the resampled version is the same as original
grid_world_resampled = torch.nn.functional.grid_sample(
grid_world.permute(0, 4, 1, 2, 3),
grid_local,
align_corners=align_corners,
).permute(0, 2, 3, 4, 1)
self.assertClose(grid_world_resampled, grid_world, atol=1e-7)
def test_coord_grid_convention_heterogeneous(
self, num_channels=4, dtype=torch.float32
):
"""
Check that for a list of 2 trivial volumes with
spatial sizes DxHxW=(5x7x5, 3x5x5):
1) xyz_world=(0, 0, 0) lands right in the middle of the volume
with xyz_local=(0, 0, 0).
2) xyz_world=((-2, 3, -2), (-2, -2, 1)) results
in xyz_local=((-1, 1, -1), (-1, -1, 1)).
        3) The central voxel of the volume coordinate grid
           has coords xyz_world=(0, 0, 0) and xyz_local=(0, 0, 0).
        4) grid_sample(world_coordinate_grid, local_coordinate_grid)
           returns world_coordinate_grid itself, i.e. the local coordinate
           grid matches the grid_sample coordinate convention.
"""
device = torch.device("cuda:0")
sizes = [(5, 7, 5), (3, 5, 5)]
densities_list = [
torch.randn(size=[num_channels, *size], device=device, dtype=torch.float32)
for size in sizes
]
# init the volume
v_trivial = Volumes(densities=densities_list)
# check the border point locations
pts_world = torch.tensor(
[[-2.0, 3.0, -2.0], [-2.0, -2.0, 1.0]], device=device, dtype=torch.float32
)[:, None]
pts_local = v_trivial.world_to_local_coords(pts_world)
pts_local_expected = torch.tensor(
[[-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0]], device=device, dtype=torch.float32
)[:, None]
self.assertClose(pts_local, pts_local_expected)
        # check that the central voxel has coords x_world=(0, 0, 0) and x_local=(0, 0, 0)
grid_world = v_trivial.get_coord_grid(world_coordinates=True)
grid_local = v_trivial.get_coord_grid(world_coordinates=False)
for grid in (grid_world, grid_local):
x0 = grid[0, :, :, 2, 0]
y0 = grid[0, :, 3, :, 1]
z0 = grid[0, 2, :, :, 2]
for coord_line in (x0, y0, z0):
self.assertClose(coord_line, torch.zeros_like(coord_line), atol=1e-7)
x0 = grid[1, :, :, 2, 0]
y0 = grid[1, :, 2, :, 1]
z0 = grid[1, 1, :, :, 2]
for coord_line in (x0, y0, z0):
self.assertClose(coord_line, torch.zeros_like(coord_line), atol=1e-7)
# resample grid_world using grid_sampler with local coords
# -> make sure the resampled version is the same as original
for grid_world_, grid_local_, size in zip(grid_world, grid_local, sizes):
grid_world_crop = grid_world_[: size[0], : size[1], : size[2], :][None]
grid_local_crop = grid_local_[: size[0], : size[1], : size[2], :][None]
grid_world_crop_resampled = torch.nn.functional.grid_sample(
grid_world_crop.permute(0, 4, 1, 2, 3),
grid_local_crop,
align_corners=True,
).permute(0, 2, 3, 4, 1)
self.assertClose(grid_world_crop_resampled, grid_world_crop, atol=1e-7)
def test_coord_grid_transforms(
self, num_volumes=3, num_channels=4, dtype=torch.float32
):
"""
        Test that conversion between the local and world coordinates of the
        volume returns correct results.
"""
device = torch.device("cuda:0")
# try for 10 sets of different random sizes/centers/voxel_sizes
for _ in range(10):
size = torch.randint(high=10, size=(3,), low=3).tolist()
center = torch.randn(num_volumes, 3, dtype=torch.float32, device=device)
voxel_size = torch.rand(1, dtype=torch.float32, device=device) * 5.0 + 0.5
for densities in (
torch.randn(
size=[num_volumes, num_channels, *size],
device=device,
dtype=torch.float32,
),
TestVolumes._random_volume_list(
num_volumes, 3, size, num_channels, device, rand_sizes=None
)[0],
):
# init the volume structure
v = Volumes(
densities=densities,
voxel_size=voxel_size,
volume_translation=-center,
)
# get local coord grid
grid_local = v.get_coord_grid(world_coordinates=False)
# convert from world to local to world
grid_world = v.get_coord_grid(world_coordinates=True)
grid_local_2 = v.world_to_local_coords(grid_world)
grid_world_2 = v.local_to_world_coords(grid_local_2)
# assertions on shape and values of grid_world and grid_local
self.assertClose(grid_world, grid_world_2, atol=1e-5)
self.assertClose(grid_local, grid_local_2, atol=1e-5)
# check that the individual slices of the location grid have
# constant values along expected dimensions
for plane_dim in (1, 2, 3):
for grid_plane in grid_world.split(1, dim=plane_dim):
grid_coord_dim = {1: 2, 2: 1, 3: 0}[plane_dim]
grid_coord_plane = grid_plane.squeeze()[..., grid_coord_dim]
# check that all elements of grid_coord_plane are
# the same for each batch element
self.assertClose(
grid_coord_plane.reshape(num_volumes, -1).max(dim=1).values,
grid_coord_plane.reshape(num_volumes, -1).min(dim=1).values,
)
def test_clone(
self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32
):
"""
Test cloning of a `Volumes` object
"""
device = torch.device("cuda:0")
features = torch.randn(
size=[num_volumes, num_channels, *size], device=device, dtype=torch.float32
)
densities = torch.rand(
size=[num_volumes, 1, *size], device=device, dtype=torch.float32
)
for has_features in (True, False):
v = Volumes(
densities=densities, features=features if has_features else None
)
vnew = v.clone()
vnew._densities.data[0, 0, 0, 0, 0] += 1.0
self.assertNotAlmostEqual(
float(
(vnew.densities()[0, 0, 0, 0, 0] - v.densities()[0, 0, 0, 0, 0])
.abs()
.max()
),
0.0,
)
if has_features:
vnew._features.data[0, 0, 0, 0, 0] += 1.0
self.assertNotAlmostEqual(
float(
(vnew.features()[0, 0, 0, 0, 0] - v.features()[0, 0, 0, 0, 0])
.abs()
.max()
),
0.0,
)
def _check_vars_on_device(self, v, desired_device):
for var_name, var in vars(v).items():
if var_name != "device":
if var is not None:
self.assertTrue(
var.device.type == desired_device.type,
(var_name, var.device, desired_device),
)
else:
self.assertTrue(var.type == desired_device.type)
def test_to(
self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32
):
"""
        Test moving Volumes and VolumeLocator objects between GPU and CPU.
"""
features = torch.randn(
size=[num_volumes, num_channels, *size], dtype=torch.float32
)
densities = torch.rand(size=[num_volumes, 1, *size], dtype=dtype)
volumes = Volumes(densities=densities, features=features)
locator = VolumeLocator(
batch_size=5, grid_sizes=(3, 5, 7), device=volumes.device
)
for name, obj in (("VolumeLocator", locator), ("Volumes", volumes)):
with self.subTest(f"Moving {name} from/to gpu and cpu"):
# Test support for str and torch.device
cpu_device = torch.device("cpu")
converted_obj = obj.to("cpu")
self.assertEqual(cpu_device, converted_obj.device)
self.assertEqual(cpu_device, obj.device)
self.assertIs(obj, converted_obj)
converted_obj = obj.to(cpu_device)
self.assertEqual(cpu_device, converted_obj.device)
self.assertEqual(cpu_device, obj.device)
self.assertIs(obj, converted_obj)
cuda_device = torch.device("cuda:0")
converted_obj = obj.to("cuda:0")
self.assertEqual(cuda_device, converted_obj.device)
self.assertEqual(cpu_device, obj.device)
self.assertIsNot(obj, converted_obj)
converted_obj = obj.to(cuda_device)
self.assertEqual(cuda_device, converted_obj.device)
self.assertEqual(cpu_device, obj.device)
self.assertIsNot(obj, converted_obj)
with self.subTest("Test device placement of internal tensors of Volumes"):
features = features.to(cuda_device)
            densities = densities.to(cuda_device)
for features_ in (features, None):
volumes = Volumes(densities=densities, features=features_)
cpu_volumes = volumes.cpu()
cuda_volumes = cpu_volumes.cuda()
cuda_volumes2 = cuda_volumes.cuda()
cpu_volumes2 = cuda_volumes2.cpu()
for volumes1, volumes2 in itertools.combinations(
(volumes, cpu_volumes, cpu_volumes2, cuda_volumes, cuda_volumes2), 2
):
if volumes1 is cuda_volumes and volumes2 is cuda_volumes2:
# checks that we do not copy if the devices stay the same
assert_fun = self.assertIs
else:
assert_fun = self.assertSeparate
assert_fun(volumes1._densities, volumes2._densities)
if features_ is not None:
assert_fun(volumes1._features, volumes2._features)
for volumes_ in (volumes1, volumes2):
if volumes_ in (cpu_volumes, cpu_volumes2):
self._check_vars_on_device(volumes_, cpu_device)
else:
self._check_vars_on_device(volumes_, cuda_device)
with self.subTest("Test device placement of internal tensors of VolumeLocator"):
for device1, device2 in itertools.combinations(
(torch.device("cpu"), torch.device("cuda:0")), 2
):
locator = locator.to(device1)
locator = locator.to(device2)
self.assertEqual(locator._grid_sizes.device, device2)
self.assertEqual(locator._local_to_world_transform.device, device2)
def _check_padded(self, x_pad, x_list, grid_sizes):
"""
        Check that each padded tensor in x_pad, cropped to the corresponding
        per-volume grid size, equals the matching tensor in x_list.
"""
num_volumes = len(x_list)
for i in range(num_volumes):
self.assertClose(
x_pad[i][:, : grid_sizes[i][0], : grid_sizes[i][1], : grid_sizes[i][2]],
x_list[i],
)
def test_feature_density_setters(self):
"""
Tests getters and setters for padded/list representations.
"""
device = torch.device("cuda:0")
diff_device = torch.device("cpu")
num_volumes = 30
num_channels = 4
K = 20
densities = []
features = []
grid_sizes = []
diff_grid_sizes = []
for _ in range(num_volumes):
grid_size = torch.randint(K - 1, size=(3,)).long() + 1
densities.append(
torch.rand((1, *grid_size), device=device, dtype=torch.float32)
)
features.append(
torch.rand(
(num_channels, *grid_size), device=device, dtype=torch.float32
)
)
grid_sizes.append(grid_size)
diff_grid_size = (
copy.deepcopy(grid_size) + torch.randint(2, size=(3,)).long() + 1
)
diff_grid_sizes.append(diff_grid_size)
grid_sizes = torch.stack(grid_sizes).to(device)
diff_grid_sizes = torch.stack(diff_grid_sizes).to(device)
volumes = Volumes(densities=densities, features=features)
self.assertClose(volumes.get_grid_sizes(), grid_sizes)
# test the getters
features_padded = volumes.features()
densities_padded = volumes.densities()
features_list = volumes.features_list()
densities_list = volumes.densities_list()
for x_pad, x_list in zip(
(densities_padded, features_padded, densities_padded, features_padded),
(densities_list, features_list, densities, features),
):
self._check_padded(x_pad, x_list, grid_sizes)
# test feature setters
features_new = [
torch.rand((num_channels, *grid_size), device=device, dtype=torch.float32)
for grid_size in grid_sizes
]
volumes._set_features(features_new)
features_new_list = volumes.features_list()
features_new_padded = volumes.features()
for x_pad, x_list in zip(
(features_new_padded, features_new_padded),
(features_new, features_new_list),
):
self._check_padded(x_pad, x_list, grid_sizes)
# wrong features to update
bad_features_new = [
[
torch.rand(
(num_channels, *grid_size), device=diff_device, dtype=torch.float32
)
for grid_size in diff_grid_sizes
],
torch.rand(
(num_volumes, num_channels, K + 1, K, K),
device=device,
dtype=torch.float32,
),
None,
]
for bad_features_new_ in bad_features_new:
with self.assertRaises(ValueError):
                volumes._set_features(bad_features_new_)
# test density setters
densities_new = [
torch.rand((1, *grid_size), device=device, dtype=torch.float32)
for grid_size in grid_sizes
]
volumes._set_densities(densities_new)
densities_new_list = volumes.densities_list()
densities_new_padded = volumes.densities()
for x_pad, x_list in zip(
(densities_new_padded, densities_new_padded),
(densities_new, densities_new_list),
):
self._check_padded(x_pad, x_list, grid_sizes)
# wrong densities to update
bad_densities_new = [
[
torch.rand((1, *grid_size), device=diff_device, dtype=torch.float32)
for grid_size in diff_grid_sizes
],
torch.rand(
(num_volumes, 1, K + 1, K, K), device=device, dtype=torch.float32
),
None,
]
for bad_densities_new_ in bad_densities_new:
with self.assertRaises(ValueError):
volumes._set_densities(bad_densities_new_)
# test update_padded
volumes = Volumes(densities=densities, features=features)
volumes_updated = volumes.update_padded(
densities_new, new_features=features_new
)
densities_new_list = volumes_updated.densities_list()
densities_new_padded = volumes_updated.densities()
features_new_list = volumes_updated.features_list()
features_new_padded = volumes_updated.features()
for x_pad, x_list in zip(
(
densities_new_padded,
densities_new_padded,
features_new_padded,
features_new_padded,
),
(densities_new, densities_new_list, features_new, features_new_list),
):
self._check_padded(x_pad, x_list, grid_sizes)
self.assertIs(volumes.get_grid_sizes(), volumes_updated.get_grid_sizes())
self.assertIs(
volumes.get_local_to_world_coords_transform(),
volumes_updated.get_local_to_world_coords_transform(),
)
self.assertIs(volumes.device, volumes_updated.device)
def test_constructor_for_padded_lists(self):
"""
Tests constructor for padded/list representations.
"""
device = torch.device("cuda:0")
diff_device = torch.device("cpu")
num_volumes = 3
num_channels = 4
size = (6, 8, 10)
diff_size = (6, 8, 11)
# good ways to define densities
ok_densities = [
torch.randn(
size=[num_volumes, 1, *size], device=device, dtype=torch.float32
).unbind(0),
torch.randn(
size=[num_volumes, 1, *size], device=device, dtype=torch.float32
),
]
# bad ways to define features
bad_features = [
torch.randn(
size=[num_volumes + 1, num_channels, *size],
device=device,
dtype=torch.float32,
).unbind(0), # list with diff batch size
torch.randn(
size=[num_volumes + 1, num_channels, *size],
device=device,
dtype=torch.float32,
), # diff batch size
torch.randn(
size=[num_volumes, num_channels, *diff_size],
device=device,
dtype=torch.float32,
).unbind(0), # list with different size
torch.randn(
size=[num_volumes, num_channels, *diff_size],
device=device,
dtype=torch.float32,
), # different size
torch.randn(
size=[num_volumes, num_channels, *size],
device=diff_device,
dtype=torch.float32,
), # different device
torch.randn(
size=[num_volumes, num_channels, *size],
device=diff_device,
dtype=torch.float32,
).unbind(0), # list with different device
]
# good ways to define features
ok_features = [
torch.randn(
size=[num_volumes, num_channels, *size],
device=device,
dtype=torch.float32,
).unbind(0), # list of features of correct size
torch.randn(
size=[num_volumes, num_channels, *size],
device=device,
dtype=torch.float32,
),
]
for densities in ok_densities:
for features in bad_features:
self.assertRaises(
ValueError, Volumes, densities=densities, features=features
)
for features in ok_features:
Volumes(densities=densities, features=features)
def test_constructor(
self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32
):
"""
Test different ways of calling the `Volumes` constructor
"""
device = torch.device("cuda:0")
# all ways to define features
features = [
torch.randn(
size=[num_volumes, num_channels, *size],
device=device,
dtype=torch.float32,
), # padded tensor
torch.randn(
size=[num_volumes, num_channels, *size],
device=device,
dtype=torch.float32,
).unbind(0), # list of features
None, # no features
]
# bad ways to define features
bad_features = [
torch.randn(
size=[num_volumes, num_channels, 2, *size],
device=device,
dtype=torch.float32,
), # 6 dims
torch.randn(
size=[num_volumes, *size], device=device, dtype=torch.float32
), # 4 dims
torch.randn(
size=[num_volumes, *size], device=device, dtype=torch.float32
).unbind(0), # list of 4 dim tensors
]
# all ways to define densities
densities = [
torch.randn(
size=[num_volumes, 1, *size], device=device, dtype=torch.float32
), # padded tensor
torch.randn(
size=[num_volumes, 1, *size], device=device, dtype=torch.float32
).unbind(0), # list of densities
]
# bad ways to define densities
bad_densities = [
None, # omitted
torch.randn(
size=[num_volumes, 1, 1, *size], device=device, dtype=torch.float32
), # 6-dim tensor
torch.randn(
size=[num_volumes, 1, 1, *size], device=device, dtype=torch.float32
).unbind(0), # list of 5-dim densities
]
# all possible ways to define the voxels sizes
vox_sizes = [
torch.Tensor([1.0, 1.0, 1.0]),
[1.0, 1.0, 1.0],
torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes, 1),
torch.Tensor([1.0])[None].repeat(num_volumes, 1),
1.0,
torch.Tensor([1.0]),
]
# all possible ways to define the volume translations
vol_translations = [
torch.Tensor([1.0, 1.0, 1.0]),
[1.0, 1.0, 1.0],
torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes, 1),
]
# wrong ways to define voxel sizes
bad_vox_sizes = [
torch.Tensor([1.0, 1.0, 1.0, 1.0]),
[1.0, 1.0, 1.0, 1.0],
torch.Tensor([]),
None,
]
# wrong ways to define the volume translations
bad_vol_translations = [
torch.Tensor([1.0, 1.0]),
[1.0, 1.0],
1.0,
torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes + 1, 1),
]
def zip_with_ok_indicator(good, bad):
return zip([*good, *bad], [*([True] * len(good)), *([False] * len(bad))])
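        # e.g. zip_with_ok_indicator([a], [b, c]) yields (a, True), (b, False),
        # (c, False): every "good" entry is paired with True, every "bad" one
        # with False.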
for features_, features_ok in zip_with_ok_indicator(features, bad_features):
for densities_, densities_ok in zip_with_ok_indicator(
densities, bad_densities
):
for vox_size, size_ok in zip_with_ok_indicator(
vox_sizes, bad_vox_sizes
):
for vol_translation, trans_ok in zip_with_ok_indicator(
vol_translations, bad_vol_translations
):
if (
size_ok and trans_ok and features_ok and densities_ok
                        ):  # if all entries are good we check that this doesn't throw
Volumes(
features=features_,
densities=densities_,
voxel_size=vox_size,
volume_translation=vol_translation,
)
else: # otherwise we check for ValueError
self.assertRaises(
ValueError,
Volumes,
features=features_,
densities=densities_,
voxel_size=vox_size,
volume_translation=vol_translation,
)