mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Summary: Converts the directory specified to use the Ruff formatter in pyfmt ruff_dog If this diff causes merge conflicts when rebasing, please run `hg status -n -0 --change . -I '**/*.{py,pyi}' | xargs -0 arc pyfmt` on your diff, and amend any changes before rebasing onto latest. That should help reduce or eliminate any merge conflicts. allow-large-files Reviewed By: bottler Differential Revision: D66472063 fbshipit-source-id: 35841cb397e4f8e066e2159550d2f56b403b1bef
967 lines
38 KiB
Python
967 lines
38 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import copy
|
|
import itertools
|
|
import random
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import torch
|
|
from pytorch3d.structures.volumes import VolumeLocator, Volumes
|
|
from pytorch3d.transforms import Scale
|
|
|
|
from .common_testing import TestCaseMixin
|
|
|
|
|
|
class TestVolumes(TestCaseMixin, unittest.TestCase):
|
|
def setUp(self) -> None:
    """Seed every random number generator used by the tests in this class."""
    seed = 42
    # `random`, NumPy and PyTorch each keep independent RNG state, so all
    # three are seeded with the same value.
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
|
|
|
|
@staticmethod
|
|
def _random_volume_list(
|
|
num_volumes, min_size, max_size, num_channels, device, rand_sizes=None
|
|
):
|
|
"""
|
|
Init a list of `num_volumes` random tensors of size [num_channels, *rand_size].
|
|
If `rand_sizes` is None, rand_size is a 3D long vector sampled
|
|
from [min_size, max_size]. Otherwise, rand_size should be a list
|
|
[rand_size_1, rand_size_2, ..., rand_size_num_volumes] where each
|
|
`rand_size_i` denotes the size of the corresponding `i`-th tensor.
|
|
"""
|
|
if rand_sizes is None:
|
|
rand_sizes = [
|
|
[random.randint(min_size, vs) for vs in max_size]
|
|
for _ in range(num_volumes)
|
|
]
|
|
|
|
volume_list = [
|
|
torch.randn(
|
|
size=[num_channels, *rand_size], device=device, dtype=torch.float32
|
|
)
|
|
for rand_size in rand_sizes
|
|
]
|
|
|
|
return volume_list, rand_sizes
|
|
|
|
def _check_indexed_volumes(self, v, selected, indices):
    """
    Assert that `selected` holds the entries of `v` picked out by `indices`.

    Args:
        v: The volumes object that was indexed.
        selected: The result of indexing `v`.
        indices: Iterable of `(selected_idx, source_idx)` pairs; entry
            `selected_idx` of `selected` must match entry `source_idx` of `v`.
    """
    for selected_idx, source_idx in indices:
        # densities
        picked_density = selected.densities()[selected_idx]
        source_density = v.densities()[source_idx]
        self.assertClose(picked_density, source_density)

        # local-to-world transform matrices of the two locators
        source_matrix = v.locator._local_to_world_transform.get_matrix()
        picked_matrix = selected.locator._local_to_world_transform.get_matrix()
        self.assertClose(source_matrix[source_idx], picked_matrix[selected_idx])

        # features are optional
        if selected.features() is None:
            continue
        self.assertClose(
            selected.features()[selected_idx], v.features()[source_idx]
        )
|
|
|
|
def test_get_item(
    self,
    num_volumes=5,
    num_channels=4,
    volume_size=(10, 13, 8),
    dtype=torch.float32,
):
    """
    Check `Volumes.__getitem__` for int, list, slice, bool-tensor and
    int-tensor indices, and that float indices raise `IndexError`.

    The check runs for volumes without features, with padded features and
    with list (heterogeneous-size) features/densities.

    Note: `dtype` is kept for signature compatibility; all tensors in this
    test are created as float32.
    """
    device = torch.device("cuda:0")

    # make sure we have at least 3 volumes to prevent indexing crash
    num_volumes = max(num_volumes, 3)

    features = torch.randn(
        size=[num_volumes, num_channels, *volume_size],
        device=device,
        dtype=torch.float32,
    )
    densities = torch.randn(
        size=[num_volumes, 1, *volume_size], device=device, dtype=torch.float32
    )

    features_list, rand_sizes = TestVolumes._random_volume_list(
        num_volumes, 3, volume_size, num_channels, device
    )
    densities_list, _ = TestVolumes._random_volume_list(
        num_volumes, 3, volume_size, 1, device, rand_sizes=rand_sizes
    )

    volume_translation = -torch.randn(num_volumes, 3).type_as(features)
    voxel_size = torch.rand(num_volumes, 1).type_as(features) + 0.5

    for features_, densities_ in zip(
        (None, features, features_list), (densities, densities, densities_list)
    ):
        # init the volume structure
        v = Volumes(
            features=features_,
            densities=densities_,
            volume_translation=volume_translation,
            voxel_size=voxel_size,
        )

        # int index
        index = 1
        v_selected = v[index]
        self.assertEqual(len(v_selected), 1)
        self._check_indexed_volumes(v, v_selected, [(0, 1)])

        # list index
        index = [1, 2]
        v_selected = v[index]
        self.assertEqual(len(v_selected), len(index))
        self._check_indexed_volumes(v, v_selected, enumerate(index))

        # slice index: actually index with the slice object and derive the
        # expected source indices from it instead of hard-coding them
        index = slice(0, 2, 1)
        expected_sources = list(range(num_volumes)[index])
        v_selected = v[index]
        self.assertEqual(len(v_selected), len(expected_sources))
        self._check_indexed_volumes(v, v_selected, enumerate(expected_sources))

        # bool tensor
        index = (torch.rand(num_volumes) > 0.5).to(device)
        index[:2] = True  # make sure smth is selected
        v_selected = v[index]
        # flatten() (unlike squeeze()) stays 1-D even for a single hit
        selected_sources = torch.nonzero(index, as_tuple=False).flatten().tolist()
        self.assertEqual(len(v_selected), len(selected_sources))
        self._check_indexed_volumes(v, v_selected, enumerate(selected_sources))

        # int tensor
        index = torch.tensor([1, 2], dtype=torch.int64, device=device)
        v_selected = v[index]
        self.assertEqual(len(v_selected), index.numel())
        self._check_indexed_volumes(v, v_selected, enumerate(index.tolist()))

        # invalid index: floating point tensor
        index = torch.tensor([1, 0, 1], dtype=torch.float32, device=device)
        with self.assertRaises(IndexError):
            v[index]
        # invalid index: floating point scalar
        index = 1.2
        with self.assertRaises(IndexError):
            v[index]
|
|
|
|
def test_locator_init(self, batch_size=9, resolution=(3, 5, 7)):
    """
    Check that `VolumeLocator` stores the same `_grid_sizes` and
    `_resolution` regardless of whether `grid_sizes` is given as a tuple,
    a list or a tensor, for both uniform and per-element grid sizes.
    """
    cpu = torch.device("cpu")

    with self.subTest("VolumeLocator init with all sizes equal"):
        grid_sizes = [resolution for _ in range(batch_size)]
        locator_tuple = VolumeLocator(
            batch_size=batch_size, grid_sizes=resolution, device=cpu
        )
        locator_list = VolumeLocator(
            batch_size=batch_size, grid_sizes=grid_sizes, device=cpu
        )
        locator_tensor = VolumeLocator(
            batch_size=batch_size,
            grid_sizes=torch.tensor(grid_sizes),
            device=cpu,
        )
        locators = (locator_tuple, locator_list, locator_tensor)
        expected_grid_sizes = torch.tensor(grid_sizes)
        expected_resolution = resolution
        # unittest assertions (not bare `assert`) so the checks survive
        # `python -O` and are reported through the test runner
        for locator in locators:
            self.assertTrue(
                torch.allclose(expected_grid_sizes, locator._grid_sizes)
            )
        for locator in locators:
            self.assertEqual(expected_resolution, locator._resolution)

    with self.subTest("VolumeLocator with different sizes in different grids"):
        grid_sizes_list = [
            torch.randint(low=1, high=42, size=(3,)) for _ in range(batch_size)
        ]
        grid_sizes_tensor = torch.cat([el[None] for el in grid_sizes_list])
        locator_list = VolumeLocator(
            batch_size=batch_size,
            grid_sizes=grid_sizes_list,
            device=cpu,
        )
        locator_tensor = VolumeLocator(
            batch_size=batch_size,
            grid_sizes=grid_sizes_tensor,
            device=cpu,
        )
        locators = (locator_list, locator_tensor)
        expected_grid_sizes = grid_sizes_tensor
        # the expected resolution is the per-dimension maximum over the batch
        expected_resolution = tuple(torch.max(expected_grid_sizes, dim=0).values)
        for locator in locators:
            self.assertTrue(
                torch.allclose(expected_grid_sizes, locator._grid_sizes)
            )
        for locator in locators:
            self.assertEqual(expected_resolution, locator._resolution)
|
|
|
|
def test_coord_transforms(self, num_volumes=3, num_channels=4, dtype=torch.float32):
    """
    Verify that the local-to-world transform built by `Volumes` from
    `volume_translation` and `voxel_size` matches a Transform3d composed
    by hand (scale by half-extent, scale by voxel size, translate).
    """
    device = torch.device("cuda:0")
    num_trials = 10

    # each trial draws a fresh random size / translation / voxel size
    for _ in range(num_trials):
        size = torch.randint(high=10, size=(3,), low=3).tolist()

        densities = torch.randn(
            size=[num_volumes, num_channels, *size],
            device=device,
            dtype=torch.float32,
        )

        # transformation parameters
        volume_translation = torch.randn(num_volumes, 3)
        voxel_size = torch.rand(num_volumes, 3) * 3.0 + 0.5

        # reference transform: the spatial size is taken in reversed order
        reversed_size = torch.tensor(
            size[::-1], dtype=torch.float32, device=device
        )
        local_offset = reversed_size[None].repeat(num_volumes, 1)
        half_extent_scale = Scale(0.5 * local_offset - 0.5, device=device)
        local_to_world_transform = half_extent_scale.scale(voxel_size).translate(
            -volume_translation
        )

        # world-coordinate grid of a volume initialized with the
        # scale and translation
        v_trans_vs = Volumes(
            densities=densities,
            voxel_size=voxel_size,
            volume_translation=volume_translation,
        )
        grid_rot_trans_vs = v_trans_vs.get_coord_grid(world_coordinates=True)

        # local grid of a default volume pushed through the reference transform
        v_default = Volumes(densities=densities)
        grid_default_local = v_default.get_coord_grid(world_coordinates=False)
        flat_local = grid_default_local.view(num_volumes, -1, 3)
        flat_world = local_to_world_transform.transform_points(flat_local)
        grid_default_world = flat_world.view(num_volumes, *size, 3)

        # both grids have to coincide
        self.assertClose(grid_rot_trans_vs, grid_default_world, atol=1e-5)

        # and so do the transformation matrices
        actual_matrix = v_trans_vs.get_local_to_world_coords_transform().get_matrix()
        self.assertClose(
            actual_matrix,
            local_to_world_transform.get_matrix(),
            atol=1e-5,
        )
|
|
)
|
|
|
|
def test_coord_grid_convention(
    self, num_volumes=3, num_channels=4, dtype=torch.float32
):
    """
    Check that for a trivial volume with spatial size DxHxW=5x7x5:
        1) xyz_world=(0, 0, 0) lands right in the middle of the volume
            with xyz_local=(0, 0, 0).
        2) xyz_world=(-2, 3, -2) results in xyz_local=(-1, 1, -1).
        3) The central voxel of the volume coordinate grid
            has coords x_world=(0, 0, 0) and x_local=(0, 0, 0)
        4) grid_sampler(world_coordinate_grid, local_coordinate_grid)
            is the same as world_coordinate_grid itself. I.e. the local coordinate
            grid matches the grid_sampler coordinate convention.

    The checks run once for a default-constructed `Volumes` (resampled with
    `align_corners=True`) and then for explicit `align_corners` in
    (True, False).
    """
    device = torch.device("cuda:0")

    densities = torch.randn(
        size=[num_volumes, num_channels, 5, 7, 5],
        device=device,
        dtype=torch.float32,
    )

    # default construction
    self._check_trivial_volume_convention(
        Volumes(densities=densities), num_volumes, device, align_corners=True
    )

    # explicit align_corners settings
    for align_corners in [True, False]:
        self._check_trivial_volume_convention(
            Volumes(densities=densities, align_corners=align_corners),
            num_volumes,
            device,
            align_corners=align_corners,
        )

def _check_trivial_volume_convention(
    self, v_trivial, num_volumes, device, align_corners
):
    """
    Run the four coordinate-convention checks on a 5x7x5 trivial volume.

    Args:
        v_trivial: `Volumes` batch of `num_volumes` elements, spatial size 5x7x5.
        num_volumes: Batch size of `v_trivial`.
        device: Device on which the probe points are created.
        align_corners: Value forwarded to `grid_sample` when resampling.
    """
    # check the case with x_world=(0,0,0)
    pts_world = torch.zeros(num_volumes, 1, 3, device=device, dtype=torch.float32)
    pts_local = v_trivial.world_to_local_coords(pts_world)
    pts_local_expected = torch.zeros_like(pts_local)
    self.assertClose(pts_local, pts_local_expected)

    # check the case with x_world=(-2, 3, -2)
    pts_world = torch.tensor([-2, 3, -2], device=device, dtype=torch.float32)[
        None, None
    ].repeat(num_volumes, 1, 1)
    pts_local = v_trivial.world_to_local_coords(pts_world)
    pts_local_expected = torch.tensor(
        [-1, 1, -1], device=device, dtype=torch.float32
    )[None, None].repeat(num_volumes, 1, 1)
    self.assertClose(pts_local, pts_local_expected)

    # check that the central voxel has coords x_world=(0, 0, 0) and x_local(0, 0, 0)
    grid_world = v_trivial.get_coord_grid(world_coordinates=True)
    grid_local = v_trivial.get_coord_grid(world_coordinates=False)
    for grid in (grid_world, grid_local):
        x0 = grid[0, :, :, 2, 0]
        y0 = grid[0, :, 3, :, 1]
        z0 = grid[0, 2, :, :, 2]
        for coord_line in (x0, y0, z0):
            self.assertClose(coord_line, torch.zeros_like(coord_line), atol=1e-7)

    # resample grid_world using grid_sampler with local coords
    # -> make sure the resampled version is the same as original
    grid_world_resampled = torch.nn.functional.grid_sample(
        grid_world.permute(0, 4, 1, 2, 3),
        grid_local,
        align_corners=align_corners,
    ).permute(0, 2, 3, 4, 1)
    self.assertClose(grid_world_resampled, grid_world, atol=1e-7)
|
|
|
|
def test_coord_grid_convention_heterogeneous(
    self, num_channels=4, dtype=torch.float32
):
    """
    Coordinate-convention checks for a list of two trivial volumes with
    spatial sizes DxHxW=(5x7x5, 3x5x5):
        1) xyz_world=((-2, 3, -2), (-2, -2, 1)) maps to
            xyz_local=((-1, 1, -1), (-1, -1, 1)).
        2) The central voxel of each volume's coordinate grid has
            x_world=(0, 0, 0) and x_local=(0, 0, 0).
        3) grid_sampler(world_coordinate_grid, local_coordinate_grid)
            reproduces world_coordinate_grid, i.e. the local grid follows
            the grid_sampler coordinate convention.
    """
    device = torch.device("cuda:0")
    sizes = [(5, 7, 5), (3, 5, 5)]

    densities_list = []
    for size in sizes:
        densities_list.append(
            torch.randn(size=[num_channels, *size], device=device, dtype=torch.float32)
        )

    # init the volume
    v_trivial = Volumes(densities=densities_list)

    # border point locations
    world_border = [[-2.0, 3.0, -2.0], [-2.0, -2.0, 1.0]]
    local_border = [[-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0]]
    pts_world = torch.tensor(world_border, device=device, dtype=torch.float32)[:, None]
    pts_local_expected = torch.tensor(
        local_border, device=device, dtype=torch.float32
    )[:, None]
    pts_local = v_trivial.world_to_local_coords(pts_world)
    self.assertClose(pts_local, pts_local_expected)

    # the central voxel must sit at the origin in both coordinate frames
    grid_world = v_trivial.get_coord_grid(world_coordinates=True)
    grid_local = v_trivial.get_coord_grid(world_coordinates=False)
    for grid in (grid_world, grid_local):
        for batch_idx, (depth, height, width) in enumerate(sizes):
            # central index along each axis: (2, 3, 2) and (1, 2, 2)
            mid_d, mid_h, mid_w = depth // 2, height // 2, width // 2
            x0 = grid[batch_idx, :, :, mid_w, 0]
            y0 = grid[batch_idx, :, mid_h, :, 1]
            z0 = grid[batch_idx, mid_d, :, :, 2]
            for coord_line in (x0, y0, z0):
                self.assertClose(
                    coord_line, torch.zeros_like(coord_line), atol=1e-7
                )

    # resample each (cropped) world grid with its local grid and make sure
    # the result equals the input
    for grid_world_, grid_local_, size in zip(grid_world, grid_local, sizes):
        depth, height, width = size
        grid_world_crop = grid_world_[:depth, :height, :width, :][None]
        grid_local_crop = grid_local_[:depth, :height, :width, :][None]
        resampled = torch.nn.functional.grid_sample(
            grid_world_crop.permute(0, 4, 1, 2, 3),
            grid_local_crop,
            align_corners=True,
        )
        grid_world_crop_resampled = resampled.permute(0, 2, 3, 4, 1)
        self.assertClose(grid_world_crop_resampled, grid_world_crop, atol=1e-7)
|
|
|
|
def test_coord_grid_transforms(
    self, num_volumes=3, num_channels=4, dtype=torch.float32
):
    """
    Round-trip the coordinate grid world -> local -> world and check that
    both conversions are consistent, for padded and list densities.
    """
    device = torch.device("cuda:0")
    num_trials = 10
    # (grid dimension that is split, coordinate expected to be constant on it)
    plane_to_coord = ((1, 2), (2, 1), (3, 0))

    # each trial draws a fresh random size / center / voxel size
    for _ in range(num_trials):
        size = torch.randint(high=10, size=(3,), low=3).tolist()

        center = torch.randn(num_volumes, 3, dtype=torch.float32, device=device)
        voxel_size = torch.rand(1, dtype=torch.float32, device=device) * 5.0 + 0.5

        padded_densities = torch.randn(
            size=[num_volumes, num_channels, *size],
            device=device,
            dtype=torch.float32,
        )
        list_densities, _ = TestVolumes._random_volume_list(
            num_volumes, 3, size, num_channels, device, rand_sizes=None
        )

        for densities in (padded_densities, list_densities):
            # init the volume structure
            v = Volumes(
                densities=densities,
                voxel_size=voxel_size,
                volume_translation=-center,
            )

            # local grid, world grid, and the round trip through both converters
            grid_local = v.get_coord_grid(world_coordinates=False)
            grid_world = v.get_coord_grid(world_coordinates=True)
            grid_local_2 = v.world_to_local_coords(grid_world)
            grid_world_2 = v.local_to_world_coords(grid_local_2)

            self.assertClose(grid_world, grid_world_2, atol=1e-5)
            self.assertClose(grid_local, grid_local_2, atol=1e-5)

            # every slice of the world grid must have a constant value of
            # the matching coordinate within each batch element
            for plane_dim, grid_coord_dim in plane_to_coord:
                for grid_plane in grid_world.split(1, dim=plane_dim):
                    grid_coord_plane = grid_plane.squeeze()[..., grid_coord_dim]
                    per_volume = grid_coord_plane.reshape(num_volumes, -1)
                    # max == min  <=>  all values are identical
                    self.assertClose(
                        per_volume.max(dim=1).values,
                        per_volume.min(dim=1).values,
                    )
|
|
)
|
|
|
|
def test_clone(
    self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32
):
    """
    Check that `Volumes.clone()` returns tensors that are independent of
    the original: modifying the clone must not modify the source.
    """
    device = torch.device("cuda:0")

    features = torch.randn(
        size=[num_volumes, num_channels, *size], device=device, dtype=torch.float32
    )
    densities = torch.rand(
        size=[num_volumes, 1, *size], device=device, dtype=torch.float32
    )

    # the single element that gets perturbed in the clone
    corner = (0, 0, 0, 0, 0)

    for has_features in (True, False):
        features_ = features if has_features else None
        v = Volumes(densities=densities, features=features_)
        vnew = v.clone()

        # perturb the clone's densities; the original must stay untouched
        vnew._densities.data[corner] += 1.0
        density_gap = (vnew.densities()[corner] - v.densities()[corner]).abs().max()
        self.assertNotAlmostEqual(float(density_gap), 0.0)

        if not has_features:
            continue

        # same check for the features
        vnew._features.data[corner] += 1.0
        feature_gap = (vnew.features()[corner] - v.features()[corner]).abs().max()
        self.assertNotAlmostEqual(float(feature_gap), 0.0)
|
|
|
|
def _check_vars_on_device(self, v, desired_device):
|
|
for var_name, var in vars(v).items():
|
|
if var_name != "device":
|
|
if var is not None:
|
|
self.assertTrue(
|
|
var.device.type == desired_device.type,
|
|
(var_name, var.device, desired_device),
|
|
)
|
|
else:
|
|
self.assertTrue(var.type == desired_device.type)
|
|
|
|
def test_to(
    self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32
):
    """
    Test the moving of the volumes from/to gpu and cpu.

    Covers `.to()` with both string and `torch.device` arguments for
    `Volumes` and `VolumeLocator`, and the device placement of their
    internal tensors after `.cpu()` / `.cuda()` / `.to()`.
    """
    features = torch.randn(
        size=[num_volumes, num_channels, *size], dtype=torch.float32
    )
    densities = torch.rand(size=[num_volumes, 1, *size], dtype=dtype)
    volumes = Volumes(densities=densities, features=features)
    locator = VolumeLocator(
        batch_size=5, grid_sizes=(3, 5, 7), device=volumes.device
    )

    for name, obj in (("VolumeLocator", locator), ("Volumes", volumes)):
        with self.subTest(f"Moving {name} from/to gpu and cpu"):
            # Test support for str and torch.device
            cpu_device = torch.device("cpu")

            # moving to the current device must return the same object
            converted_obj = obj.to("cpu")
            self.assertEqual(cpu_device, converted_obj.device)
            self.assertEqual(cpu_device, obj.device)
            self.assertIs(obj, converted_obj)

            converted_obj = obj.to(cpu_device)
            self.assertEqual(cpu_device, converted_obj.device)
            self.assertEqual(cpu_device, obj.device)
            self.assertIs(obj, converted_obj)

            cuda_device = torch.device("cuda:0")

            # moving to another device must return a new object and
            # leave the source where it was
            converted_obj = obj.to("cuda:0")
            self.assertEqual(cuda_device, converted_obj.device)
            self.assertEqual(cpu_device, obj.device)
            self.assertIsNot(obj, converted_obj)

            converted_obj = obj.to(cuda_device)
            self.assertEqual(cuda_device, converted_obj.device)
            self.assertEqual(cpu_device, obj.device)
            self.assertIsNot(obj, converted_obj)

    with self.subTest("Test device placement of internal tensors of Volumes"):
        features = features.to(cuda_device)
        # move the densities themselves (previously the features tensor
        # was assigned here by mistake)
        densities = densities.to(cuda_device)

        for features_ in (features, None):
            volumes = Volumes(densities=densities, features=features_)

            cpu_volumes = volumes.cpu()
            cuda_volumes = cpu_volumes.cuda()
            cuda_volumes2 = cuda_volumes.cuda()
            cpu_volumes2 = cuda_volumes2.cpu()

            for volumes1, volumes2 in itertools.combinations(
                (volumes, cpu_volumes, cpu_volumes2, cuda_volumes, cuda_volumes2), 2
            ):
                if volumes1 is cuda_volumes and volumes2 is cuda_volumes2:
                    # checks that we do not copy if the devices stay the same
                    assert_fun = self.assertIs
                else:
                    assert_fun = self.assertSeparate
                assert_fun(volumes1._densities, volumes2._densities)
                if features_ is not None:
                    assert_fun(volumes1._features, volumes2._features)
                for volumes_ in (volumes1, volumes2):
                    if volumes_ in (cpu_volumes, cpu_volumes2):
                        self._check_vars_on_device(volumes_, cpu_device)
                    else:
                        self._check_vars_on_device(volumes_, cuda_device)

    with self.subTest("Test device placement of internal tensors of VolumeLocator"):
        # NOTE(review): combinations() of two devices yields only the
        # (cpu, cuda:0) pair, so the cuda -> cpu direction is not covered
        # here; confirm whether permutations() was intended.
        for device1, device2 in itertools.combinations(
            (torch.device("cpu"), torch.device("cuda:0")), 2
        ):
            locator = locator.to(device1)
            locator = locator.to(device2)
            self.assertEqual(locator._grid_sizes.device, device2)
            self.assertEqual(locator._local_to_world_transform.device, device2)
|
|
|
|
def _check_padded(self, x_pad, x_list, grid_sizes):
|
|
"""
|
|
Check that padded tensors x_pad are the same as x_list tensors.
|
|
"""
|
|
num_volumes = len(x_list)
|
|
for i in range(num_volumes):
|
|
self.assertClose(
|
|
x_pad[i][:, : grid_sizes[i][0], : grid_sizes[i][1], : grid_sizes[i][2]],
|
|
x_list[i],
|
|
)
|
|
|
|
def test_feature_density_setters(self):
    """
    Tests getters and setters for padded/list representations.

    Covers `features()` / `densities()` and their `*_list()` variants,
    `_set_features` / `_set_densities` with valid and invalid inputs, and
    `update_padded`.
    """
    device = torch.device("cuda:0")
    diff_device = torch.device("cpu")

    num_volumes = 30
    num_channels = 4
    K = 20

    densities = []
    features = []
    grid_sizes = []
    diff_grid_sizes = []

    for _ in range(num_volumes):
        grid_size = torch.randint(K - 1, size=(3,)).long() + 1
        densities.append(
            torch.rand((1, *grid_size), device=device, dtype=torch.float32)
        )
        features.append(
            torch.rand(
                (num_channels, *grid_size), device=device, dtype=torch.float32
            )
        )
        grid_sizes.append(grid_size)

        # a grid size that is strictly larger in every dimension
        diff_grid_size = (
            copy.deepcopy(grid_size) + torch.randint(2, size=(3,)).long() + 1
        )
        diff_grid_sizes.append(diff_grid_size)
    grid_sizes = torch.stack(grid_sizes).to(device)
    diff_grid_sizes = torch.stack(diff_grid_sizes).to(device)

    volumes = Volumes(densities=densities, features=features)
    self.assertClose(volumes.get_grid_sizes(), grid_sizes)

    # test the getters
    features_padded = volumes.features()
    densities_padded = volumes.densities()
    features_list = volumes.features_list()
    densities_list = volumes.densities_list()
    for x_pad, x_list in zip(
        (densities_padded, features_padded, densities_padded, features_padded),
        (densities_list, features_list, densities, features),
    ):
        self._check_padded(x_pad, x_list, grid_sizes)

    # test feature setters
    features_new = [
        torch.rand((num_channels, *grid_size), device=device, dtype=torch.float32)
        for grid_size in grid_sizes
    ]
    volumes._set_features(features_new)
    features_new_list = volumes.features_list()
    features_new_padded = volumes.features()
    for x_pad, x_list in zip(
        (features_new_padded, features_new_padded),
        (features_new, features_new_list),
    ):
        self._check_padded(x_pad, x_list, grid_sizes)

    # wrong features to update
    bad_features_new = [
        [
            torch.rand(
                (num_channels, *grid_size), device=diff_device, dtype=torch.float32
            )
            for grid_size in diff_grid_sizes
        ],
        torch.rand(
            (num_volumes, num_channels, K + 1, K, K),
            device=device,
            dtype=torch.float32,
        ),
        None,
    ]
    for bad_features_new_ in bad_features_new:
        with self.assertRaises(ValueError):
            # exercise the *feature* setter here (the density setter is
            # covered by the loop over `bad_densities_new` below)
            volumes._set_features(bad_features_new_)

    # test density setters
    densities_new = [
        torch.rand((1, *grid_size), device=device, dtype=torch.float32)
        for grid_size in grid_sizes
    ]
    volumes._set_densities(densities_new)
    densities_new_list = volumes.densities_list()
    densities_new_padded = volumes.densities()
    for x_pad, x_list in zip(
        (densities_new_padded, densities_new_padded),
        (densities_new, densities_new_list),
    ):
        self._check_padded(x_pad, x_list, grid_sizes)

    # wrong densities to update
    bad_densities_new = [
        [
            torch.rand((1, *grid_size), device=diff_device, dtype=torch.float32)
            for grid_size in diff_grid_sizes
        ],
        torch.rand(
            (num_volumes, 1, K + 1, K, K), device=device, dtype=torch.float32
        ),
        None,
    ]
    for bad_densities_new_ in bad_densities_new:
        with self.assertRaises(ValueError):
            volumes._set_densities(bad_densities_new_)

    # test update_padded
    volumes = Volumes(densities=densities, features=features)
    volumes_updated = volumes.update_padded(
        densities_new, new_features=features_new
    )
    densities_new_list = volumes_updated.densities_list()
    densities_new_padded = volumes_updated.densities()
    features_new_list = volumes_updated.features_list()
    features_new_padded = volumes_updated.features()
    for x_pad, x_list in zip(
        (
            densities_new_padded,
            densities_new_padded,
            features_new_padded,
            features_new_padded,
        ),
        (densities_new, densities_new_list, features_new, features_new_list),
    ):
        self._check_padded(x_pad, x_list, grid_sizes)
    # the updated object is expected to share grid sizes, transform and
    # device with the source object (identity, not just equality)
    self.assertIs(volumes.get_grid_sizes(), volumes_updated.get_grid_sizes())
    self.assertIs(
        volumes.get_local_to_world_coords_transform(),
        volumes_updated.get_local_to_world_coords_transform(),
    )
    self.assertIs(volumes.device, volumes_updated.device)
|
|
|
|
def test_constructor_for_padded_lists(self):
    """
    Check that the `Volumes` constructor accepts matching padded/list
    densities and features, and raises `ValueError` when the features
    disagree with the densities in batch size, spatial size or device.
    """
    device = torch.device("cuda:0")
    diff_device = torch.device("cpu")

    num_volumes = 3
    num_channels = 4
    size = (6, 8, 10)
    diff_size = (6, 8, 11)

    def random_batch(batch, channels, spatial, dev):
        # float32 normal tensor of shape [batch, channels, *spatial] on `dev`
        return torch.randn(
            size=[batch, channels, *spatial], device=dev, dtype=torch.float32
        )

    # good ways to define densities: list, then padded tensor
    ok_densities = [
        random_batch(num_volumes, 1, size, device).unbind(0),
        random_batch(num_volumes, 1, size, device),
    ]

    # bad ways to define features
    bad_features = [
        # list with diff batch size
        random_batch(num_volumes + 1, num_channels, size, device).unbind(0),
        # diff batch size
        random_batch(num_volumes + 1, num_channels, size, device),
        # list with different size
        random_batch(num_volumes, num_channels, diff_size, device).unbind(0),
        # different size
        random_batch(num_volumes, num_channels, diff_size, device),
        # different device
        random_batch(num_volumes, num_channels, size, diff_device),
        # list with different device
        random_batch(num_volumes, num_channels, size, diff_device).unbind(0),
    ]

    # good ways to define features: list of correct size, then padded tensor
    ok_features = [
        random_batch(num_volumes, num_channels, size, device).unbind(0),
        random_batch(num_volumes, num_channels, size, device),
    ]

    for densities in ok_densities:
        for features in bad_features:
            with self.assertRaises(ValueError):
                Volumes(densities=densities, features=features)
        for features in ok_features:
            # must construct without raising
            Volumes(densities=densities, features=features)
|
|
|
|
def test_constructor(
    self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32
):
    """
    Exercise the `Volumes` constructor with every combination of valid and
    invalid features, densities, voxel sizes and volume translations.

    A combination made only of valid arguments must construct without
    raising; any combination with at least one invalid argument must raise
    `ValueError`.
    """
    device = torch.device("cuda:0")

    def randn_on_device(*shape):
        # float32 normal tensor of the given shape on the test device
        return torch.randn(size=list(shape), device=device, dtype=torch.float32)

    def tag_validity(good, bad):
        # pair each candidate with a flag telling whether it is valid
        return [(item, True) for item in good] + [(item, False) for item in bad]

    # all ways to define features
    features = [
        randn_on_device(num_volumes, num_channels, *size),  # padded tensor
        randn_on_device(num_volumes, num_channels, *size).unbind(0),  # list
        None,  # no features
    ]

    # bad ways to define features
    bad_features = [
        randn_on_device(num_volumes, num_channels, 2, *size),  # 6 dims
        randn_on_device(num_volumes, *size),  # 4 dims
        randn_on_device(num_volumes, *size).unbind(0),  # list of 4 dim tensors
    ]

    # all ways to define densities
    densities = [
        randn_on_device(num_volumes, 1, *size),  # padded tensor
        randn_on_device(num_volumes, 1, *size).unbind(0),  # list of densities
    ]

    # bad ways to define densities
    bad_densities = [
        None,  # omitted
        randn_on_device(num_volumes, 1, 1, *size),  # 6-dim tensor
        randn_on_device(num_volumes, 1, 1, *size).unbind(0),  # list of 5-dim
    ]

    # all possible ways to define the voxels sizes
    vox_sizes = [
        torch.Tensor([1.0, 1.0, 1.0]),
        [1.0, 1.0, 1.0],
        torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes, 1),
        torch.Tensor([1.0])[None].repeat(num_volumes, 1),
        1.0,
        torch.Tensor([1.0]),
    ]

    # all possible ways to define the volume translations
    vol_translations = [
        torch.Tensor([1.0, 1.0, 1.0]),
        [1.0, 1.0, 1.0],
        torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes, 1),
    ]

    # wrong ways to define voxel sizes
    bad_vox_sizes = [
        torch.Tensor([1.0, 1.0, 1.0, 1.0]),
        [1.0, 1.0, 1.0, 1.0],
        torch.Tensor([]),
        None,
    ]

    # wrong ways to define the volume translations
    bad_vol_translations = [
        torch.Tensor([1.0, 1.0]),
        [1.0, 1.0],
        1.0,
        torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes + 1, 1),
    ]

    # product() iterates with the last argument varying fastest, i.e. in
    # the same order as four nested loops over these lists
    all_cases = itertools.product(
        tag_validity(features, bad_features),
        tag_validity(densities, bad_densities),
        tag_validity(vox_sizes, bad_vox_sizes),
        tag_validity(vol_translations, bad_vol_translations),
    )
    for feature_case, density_case, vox_case, translation_case in all_cases:
        features_, features_ok = feature_case
        densities_, densities_ok = density_case
        vox_size, size_ok = vox_case
        vol_translation, trans_ok = translation_case

        constructor_kwargs = {
            "features": features_,
            "densities": densities_,
            "voxel_size": vox_size,
            "volume_translation": vol_translation,
        }

        if all((size_ok, trans_ok, features_ok, densities_ok)):
            # if all entries are good we check that this doesnt throw
            Volumes(**constructor_kwargs)
        else:
            # otherwise we check for ValueError
            with self.assertRaises(ValueError):
                Volumes(**constructor_kwargs)
|