diff --git a/pytorch3d/structures/__init__.py b/pytorch3d/structures/__init__.py
index e83db39e..bfabd3ec 100644
--- a/pytorch3d/structures/__init__.py
+++ b/pytorch3d/structures/__init__.py
@@ -3,6 +3,7 @@ from .meshes import Meshes, join_meshes_as_batch, join_meshes_as_scene
 from .pointclouds import Pointclouds
 from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list
+from .volumes import Volumes
 
 __all__ = [k for k in globals().keys() if not k.startswith("_")]
diff --git a/pytorch3d/structures/volumes.py b/pytorch3d/structures/volumes.py
new file mode 100644
index 00000000..83597395
--- /dev/null
+++ b/pytorch3d/structures/volumes.py
@@ -0,0 +1,703 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+import copy
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from ..transforms import Scale, Transform3d
+from . import utils as struct_utils
+
+
+class Volumes(object):
+    """
+    This class provides functions for working with batches of volumetric grids
+    of possibly varying spatial sizes.
+
+    VOLUME DENSITIES
+
+    The Volumes class can be constructed either from a 5D tensor of
+    `densities` of size `batch x density_dim x depth x height x width` or
+    from a list of differently-sized 4D tensors `[D_1, ..., D_batch]`,
+    where each `D_i` is of size `[density_dim x depth_i x height_i x width_i]`.
+
+    In case the `Volumes` object is initialized from the list of `densities`,
+    the list of tensors is internally converted to a single 5D tensor by
+    zero-padding the relevant dimensions. Both list and padded representations can be
+    accessed with the `Volumes.densities()` or `Volumes.densities_list()` getters.
+    The sizes of the individual volumes in the structure can be retrieved
+    with the `Volumes.get_grid_sizes()` getter.
+
+    The `Volumes` class is immutable. I.e. after generating a `Volumes` object,
+    one cannot change its properties, such as `self._densities` or `self._features`,
+    anymore.
+
+
+    VOLUME FEATURES
+
+    While the `densities` field is intended to represent various measures of the
+    "density" of the volume cells (opacity, signed/unsigned distances
+    from the nearest surface, ...), one can additionally initialize the
+    object with the `features` argument. `features` are either a 5D tensor
+    of shape `batch x feature_dim x depth x height x width` or a list
+    of differently-sized 4D tensors `[F_1, ..., F_batch]`,
+    where each `F_i` is of size `[feature_dim x depth_i x height_i x width_i]`.
+    `features` are intended to describe other properties of volume cells,
+    such as per-voxel 3D vectors of RGB colors that can be later used
+    for rendering the volume.
+
+
+    VOLUME COORDINATES
+
+    Additionally, the `Volumes` class keeps track of the locations of the
+    centers of the volume cells in the local volume coordinates as well as in
+    the world coordinates.
+
+    Local coordinates:
+        - Represent the locations of the volume cells in the local coordinate
+          frame of the volume.
+        - The center of the voxel indexed with `[·, ·, 0, 0, 0]` in the volume
+          has its 3D local coordinate set to `[-1, -1, -1]`, while the voxel
+          at index `[·, ·, depth_i-1, height_i-1, width_i-1]` has its
+          3D local coordinate set to `[1, 1, 1]`.
+        - The first/second/third coordinate of each of the 3D per-voxel
+          XYZ vectors denotes the horizontal/vertical/depth-wise position
+          respectively. I.e. the order of the coordinate dimensions in the
+          volume is reversed w.r.t. the order of the 3D coordinate vectors.
+        - The intermediate coordinates between `[-1, -1, -1]` and `[1, 1, 1]`
+          are linearly interpolated over the spatial dimensions of the volume.
+        - Note that the convention is the same as for the 5D version of the
+          `torch.nn.functional.grid_sample` function called with
+          `align_corners==True`.
+        - Note that the local coordinate convention of `Volumes`
+          (+X = left to right, +Y = top to bottom, +Z = away from the user)
+          is *different* from the world coordinate convention of the
+          renderer for `Meshes` or `Pointclouds`
+          (+X = right to left, +Y = bottom to top, +Z = away from the user).
+
+    World coordinates:
+        - These define the locations of the centers of the volume cells
+          in the world coordinates.
+        - They are specified with the following mapping that converts
+          points `x_local` in the local coordinates to points `x_world`
+          in the world coordinates:
+          ```
+          x_world = (
+              x_local * (volume_size - 1) * 0.5 * voxel_size
+          ) - volume_translation,
+          ```
+          where `voxel_size` specifies the size of each voxel of the volume,
+          and `volume_translation` is the 3D offset of the central voxel of
+          the volume w.r.t. the origin of the world coordinate frame.
+          Both `voxel_size` and `volume_translation` are specified in
+          the world coordinate units. `volume_size` is the spatial size of
+          the volume in the form of a 3D vector `[width, height, depth]`.
+        - Given the above definition of `x_world`, one can derive the
+          inverse mapping from `x_world` to `x_local` as follows:
+          ```
+          x_local = (
+              (x_world + volume_translation) / (0.5 * voxel_size)
+          ) / (volume_size - 1)
+          ```
+        - For a trivial volume with `volume_translation==[0, 0, 0]`
+          and `voxel_size==1`, `x_world` would range
+          from `-(volume_size-1)/2` to `+(volume_size-1)/2`.
+
+    Coordinate tensors that denote the locations of each of the volume cells in
+    local / world coordinates (with shape `(depth x height x width x 3)`)
+    can be retrieved by calling the `Volumes.get_coord_grid()` getter with the
+    appropriate `world_coordinates` argument.
+
+    Internally, the mapping between `x_local` and `x_world` is represented
+    as a `Transform3d` object `Volumes._local_to_world_transform`.
+    Users can access the relevant transformations with the
+    `Volumes.get_world_to_local_coords_transform()` and
+    `Volumes.get_local_to_world_coords_transform()`
+    functions.
+
+    Example coordinate conversion:
+        - For a "trivial" volume with `voxel_size = 1.`,
+          `volume_translation=[0., 0., 0.]`, and the spatial size of
+          `DxHxW = 5x5x5`, the point `x_world = (-2, 0, 2)` gets mapped
+          to `x_local=(-1, 0, 1)`.
+        - For a "trivial" volume `v` with `voxel_size = 1.`,
+          `volume_translation=[0., 0., 0.]`, the following holds:
+          ```
+          torch.nn.functional.grid_sample(
+              v.densities(),
+              v.get_coord_grid(world_coordinates=False),
+              align_corners=True,
+          ) == v.densities(),
+          ```
+          i.e. sampling the volume at trivial local coordinates
+          (no scaling with `voxel_size` or shift with `volume_translation`)
+          results in the same volume.
+    """
+
+    def __init__(
+        self,
+        densities: Union[List[torch.Tensor], torch.Tensor],
+        features: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
+        voxel_size: Union[float, torch.Tensor, Tuple, List] = 1.0,
+        volume_translation: Union[torch.Tensor, Tuple, List] = (0.0, 0.0, 0.0),
+    ):
+        """
+        Args:
+            **densities**: Batch of input feature volume occupancies of shape
+                `(minibatch, density_dim, depth, height, width)`, or a list
+                of 4D tensors `[D_1, ..., D_minibatch]` where each `D_i` has
+                shape `(density_dim, depth_i, height_i, width_i)`.
+ Typically, each voxel contains a non-negative number + corresponding to its opaqueness. + **features**: Batch of input feature volumes of shape: + `(minibatch, feature_dim, depth, height, width)` or a list + of 4D tensors `[F_1, ..., F_minibatch]` where each `F_i` has + shape `(feature_dim, depth_i, height_i, width_i)`. + The field is optional and can be set to `None` in case features are + not required. + **voxel_size**: Denotes the size of each volume voxel in world units. + Has to be one of: + a) A scalar (square voxels) + b) 3-tuple or a 3-list of scalars + c) a Tensor of shape (3,) + d) a Tensor of shape (minibatch, 3) + e) a Tensor of shape (minibatch, 1) + f) a Tensor of shape (1,) (square voxels) + **volume_translation**: Denotes the 3D translation of the center + of the volume in world units. Has to be one of: + a) 3-tuple or a 3-list of scalars + b) a Tensor of shape (3,) + c) a Tensor of shape (minibatch, 3) + d) a Tensor of shape (1,) (square voxels) + """ + + # handle densities + densities, grid_sizes = self._convert_densities_features_to_tensor( + densities, "densities" + ) + + # take device from densities + self.device = densities.device + + # assign to the internal buffers + self._densities = densities + self._grid_sizes = grid_sizes + + # handle features + self._features = None + if features is not None: + self._set_features(features) + + # set the local_to_world transform + self._set_local_to_world_transform( + voxel_size=voxel_size, volume_translation=volume_translation + ) + + def _convert_densities_features_to_tensor( + self, x: Union[List[torch.Tensor], torch.Tensor], var_name: str + ): + """ + Handle the `densities` or `features` arguments to the constructor. + """ + if isinstance(x, (list, tuple)): + x_tensor = struct_utils.list_to_padded(x) + if any(x_.ndim != 4 for x_ in x): + raise ValueError( + f"`{var_name}` has to be a list of 4-dim tensors of shape: " + f"({var_name}_dim, height, width, depth)" + ) + if any(x_.shape[0] != x[0].shape[0] for x_ in x): + raise ValueError( + f"Each entry in the list of `{var_name}` has to have the " + "same number of channels (first dimension in the tensor)." + ) + x_shapes = torch.stack( + [ + torch.tensor( + list(x_.shape[1:]), dtype=torch.long, device=x_tensor.device + ) + for x_ in x + ], + dim=0, + ) + elif torch.is_tensor(x): + if x.ndim != 5: + raise ValueError( + f"`{var_name}` has to be a 5-dim tensor of shape: " + f"(minibatch, {var_name}_dim, height, width, depth)" + ) + x_tensor = x + x_shapes = torch.tensor( + list(x.shape[2:]), dtype=torch.long, device=x.device + )[None].repeat(x.shape[0], 1) + else: + raise ValueError( + f"{var_name} must be either a list or a tensor with " + f"shape (batch_size, {var_name}_dim, H, W, D)." + ) + return x_tensor, x_shapes + + def _vsize_translation_to_transform( + self, voxel_size, volume_translation, batch_size + ): + """ + Converts the `voxel_size` and `volume_translation` constructor arguments + to the internal `Transform3D` object `local_to_world_transform`. 
+ """ + volume_size_zyx = self.get_grid_sizes().float() + volume_size_xyz = volume_size_zyx[:, [2, 1, 0]] + + # x_local = ( + # (x_world + volume_translation) / (0.5 * voxel_size) + # ) / (volume_size - 1) + + # x_world = ( + # x_local * (volume_size - 1) * 0.5 * voxel_size + # ) - volume_translation + + local_to_world_transform = Scale( + (volume_size_xyz - 1) * voxel_size * 0.5, device=self.device + ).translate(-volume_translation) + + return local_to_world_transform + + def _handle_voxel_size(self, voxel_size, batch_size): + """ + Handle the `voxel_size` argument to the `Volumes` constructor. + """ + err_msg = ( + "voxel_size has to be either a 3-tuple of scalars, or a scalar, or" + " a torch.Tensor of shape (3,) or (1,) or (minibatch, 3) or (minibatch, 1)." + ) + if isinstance(voxel_size, (float, int)): + # convert a scalar to a 3-element tensor + voxel_size = (voxel_size,) * 3 + if torch.is_tensor(voxel_size): + if voxel_size.numel() == 1: + # convert a single-element tensor to a 3-element one + voxel_size = voxel_size.view(-1).repeat(3) + elif len(voxel_size.shape) == 2 and ( + voxel_size.shape[0] == batch_size and voxel_size.shape[1] == 1 + ): + voxel_size = voxel_size.repeat(1, 3) + return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg) + + def _handle_volume_translation(self, translation, batch_size): + """ + Handle the `volume_translation` argument to the `Volumes` constructor. + """ + err_msg = ( + "`volume_translation` has to be either a 3-tuple of scalars, or" + " a Tensor of shape (1,3) or (minibatch, 3) or (3,)`." + ) + return self._convert_volume_property_to_tensor(translation, batch_size, err_msg) + + def _convert_volume_property_to_tensor( + self, x, batch_size, err_msg + ) -> torch.Tensor: + """ + Handle the `volume_translation` or `voxel_size` argument to + the Volumes constructor. + Return a tensor of shape (N, 3) where N is the batch_size or 1 + if batch_size is None. + """ + if isinstance(x, (list, tuple)): + if len(x) != 3: + raise ValueError(err_msg) + x = torch.tensor(x, device=self.device, dtype=torch.float32)[None] + x = x.repeat((batch_size, 1)) + elif torch.is_tensor(x): + ok = ( + (x.shape[0] == 1 and x.shape[1] == 3) + or (x.shape[0] == 3 and len(x.shape) == 1) + or (x.shape[0] == batch_size and x.shape[1] == 3) + ) + if not ok: + raise ValueError(err_msg) + if x.device != self.device: + x = x.to(self.device) + if x.shape[0] == 3 and len(x.shape) == 1: + x = x[None] + if x.shape[0] == 1: + x = x.repeat((batch_size, 1)) + else: + raise ValueError(err_msg) + + return x + + def get_coord_grid(self, world_coordinates: bool = True) -> torch.Tensor: + """ + Return the 3D coordinate grid of the volumetric grid + in local (`world_coordinates=False`) or world coordinates + (`world_coordinates=True`). + + The grid records location of each center of the corresponding volume voxel. + + Local coordinates are scaled s.t. the values along one side of the + volume are in range [-1, 1]. + + Args: + **world_coordinates**: if `True`, the method + returns the grid in the world coordinates, + otherwise, in local coordinates. + + Returns: + **coordinate_grid**: The grid of coordinates of shape + `(minibatch, depth, height, width, 3)`, where `minibatch`, + `height`, `width` and `depth` are the batch size, height, width + and depth of the volume `features` or `densities`. + """ + # TODO(dnovotny): Implement caching of the coordinate grid. 
+        return self._calculate_coordinate_grid(world_coordinates=world_coordinates)
+
+    def _calculate_coordinate_grid(
+        self, world_coordinates: bool = True
+    ) -> torch.Tensor:
+        """
+        Calculate the 3D coordinate grid of the volumetric grid either
+        in local (`world_coordinates=False`) or
+        world coordinates (`world_coordinates=True`).
+        """
+
+        densities = self.densities()
+        ba, _, de, he, wi = densities.shape
+        grid_sizes = self.get_grid_sizes()
+
+        # generate coordinate axes
+        vol_axes = [
+            torch.linspace(-1.0, 1.0, r, dtype=torch.float32, device=self.device)
+            for r in (de, he, wi)
+        ]
+
+        # generate per-coord meshgrids
+        Z, Y, X = torch.meshgrid(vol_axes)
+
+        # stack the coord grids ... this order matches the coordinate convention
+        # of torch.nn.functional.grid_sample
+        vol_coords_local = torch.stack((X, Y, Z), dim=3)[None].repeat(ba, 1, 1, 1, 1)
+
+        # get grid sizes relative to the maximal volume size
+        grid_sizes_relative = (
+            torch.tensor([[de, he, wi]], device=grid_sizes.device, dtype=torch.float32)
+            - 1
+        ) / (grid_sizes - 1).float()
+
+        if (grid_sizes_relative != 1.0).any():
+            # if any of the relative sizes != 1.0, adjust the grid
+            grid_sizes_relative_reshape = grid_sizes_relative[:, [2, 1, 0]][
+                :, None, None, None
+            ]
+            vol_coords_local *= grid_sizes_relative_reshape
+            vol_coords_local += grid_sizes_relative_reshape - 1
+
+        if world_coordinates:
+            vol_coords = self.local_to_world_coords(vol_coords_local)
+        else:
+            vol_coords = vol_coords_local
+
+        return vol_coords
+
+    def get_local_to_world_coords_transform(self) -> Transform3d:
+        """
+        Return a Transform3d object that converts points in the
+        local coordinate frame of the volume to world coordinates.
+        Local volume coordinates are scaled s.t. the coordinates along one
+        side of the volume are in range [-1, 1].
+
+        Returns:
+            **local_to_world_transform**: A Transform3d object converting
+                points from local coordinates to the world coordinates.
+        """
+        return self._local_to_world_transform
+
+    def get_world_to_local_coords_transform(self) -> Transform3d:
+        """
+        Return a Transform3d object that converts points in the
+        world coordinates to the local coordinate frame of the volume.
+        Local volume coordinates are scaled s.t. the coordinates along one
+        side of the volume are in range [-1, 1].
+
+        Returns:
+            **world_to_local_transform**: A Transform3d object converting
+                points from world coordinates to local coordinates.
+        """
+        return self.get_local_to_world_coords_transform().inverse()
+
+    def world_to_local_coords(self, points_3d_world: torch.Tensor) -> torch.Tensor:
+        """
+        Convert a batch of 3D point coordinates `points_3d_world` of shape
+        (minibatch, ..., dim) in the world coordinates to
+        the local coordinate frame of the volume. Local volume
+        coordinates are scaled s.t. the coordinates along one side of the volume
+        are in range [-1, 1].
+
+        Args:
+            **points_3d_world**: A tensor of shape `(minibatch, ..., 3)`
+                containing the 3D coordinates of a set of points in the world
+                coordinates (defined by the `voxel_size` and
+                `volume_translation` constructor arguments) that will be
+                converted to the local volume coordinates
+                (ranging within [-1, 1]).
+
+        Returns:
+            **points_3d_local**: `points_3d_world` converted to the local
+                volume coordinates of shape `(minibatch, ..., 3)`.
+ """ + pts_shape = points_3d_world.shape + return ( + self.get_world_to_local_coords_transform() + .transform_points(points_3d_world.view(pts_shape[0], -1, 3)) + .view(pts_shape) + ) + + def local_to_world_coords(self, points_3d_local: torch.Tensor) -> torch.Tensor: + """ + Convert a batch of 3D point coordinates `points_3d_local` of shape + (minibatch, ..., dim) in the local coordinate frame of the volume + to the world coordinates. + + Args: + **points_3d_local**: A tensor of shape `(minibatch, ..., 3)` + containing the 3D coordinates of a set of points that will + be converted from the local volume coordinates (ranging + within [-1, 1]) to the world coordinates + defined by the `self.center` and `self.voxel_size` parameters. + + Returns: + **points_3d_world**: `points_3d_local` converted to the world + coordinates of the volume of shape `(minibatch, ..., 3)`. + """ + pts_shape = points_3d_local.shape + return ( + self.get_local_to_world_coords_transform() + .transform_points(points_3d_local.view(pts_shape[0], -1, 3)) + .view(pts_shape) + ) + + def __len__(self) -> int: + return self._densities.shape[0] + + def __getitem__( + self, index: Union[int, List[int], Tuple[int], slice, torch.Tensor] + ) -> "Volumes": + """ + Args: + index: Specifying the index of the volume to retrieve. + Can be an int, slice, list of ints or a boolean or a long tensor. + + Returns: + Volumes object with selected volumes. The tensors are not cloned. + """ + if isinstance(index, int): + index = torch.LongTensor([index]) + elif isinstance(index, (slice, list, tuple)): + pass + elif torch.is_tensor(index): + if index.dim() != 1 or index.dtype.is_floating_point: + raise IndexError(index) + else: + raise IndexError(index) + + new = self.__class__( + features=self.features()[index] if self._features is not None else None, + densities=self.densities()[index], + ) + # dont forget to update grid_sizes! + new._grid_sizes = self.get_grid_sizes()[index] + new._local_to_world_transform = self._local_to_world_transform[index] + return new + + def features(self) -> Optional[torch.Tensor]: + """ + Returns the features of the volume. + + Returns: + **features**: The tensor of volume features. + """ + return self._features + + def densities(self) -> torch.Tensor: + """ + Returns the densities of the volume. + + Returns: + **densities**: The tensor of volume densities. + """ + return self._densities + + def densities_list(self) -> List[torch.Tensor]: + """ + Get the list representation of the densities. + + Returns: + list of tensors of densities of shape (dim_i, D_i, H_i, W_i). + """ + return self._features_densities_list(self.densities()) + + def features_list(self) -> List[torch.Tensor]: + """ + Get the list representation of the features. + + Returns: + list of tensors of features of shape (dim_i, D_i, H_i, W_i) + or `None` for feature-less volumes. + """ + if self._features is None: + # No features provided so return None + return None + return self._features_densities_list(self.features()) + + def _features_densities_list(self, x) -> List[torch.Tensor]: + """ + Retrieve the list representation of features/densities. + + Args: + x: self.features() or self.densities() + + Returns: + list of tensors of features/densities of shape (dim_i, D_i, H_i, W_i). 
+ """ + x_dim = x.shape[1] + pad_sizes = torch.nn.functional.pad( + self.get_grid_sizes(), [1, 0], mode="constant", value=x_dim + ) + x_list = struct_utils.padded_to_list(x, pad_sizes.tolist()) + return x_list + + def get_grid_sizes(self) -> torch.LongTensor: + """ + Returns the sizes of individual volumetric grids in the structure. + + Returns: + **grid_sizes**: Tensor of spatial sizes of each of the volumes + of size (batchsize, 3), where i-th row holds (D_i, H_i, W_i). + """ + return self._grid_sizes + + def update_padded(self, new_densities, new_features=None) -> "Volumes": + """ + Returns a Volumes structure with updated padded tensors and copies of + the auxiliary tensors `self._local_to_world_transform`, + `device` and `self._grid_sizes`. This function allows for an update of + densities (and features) without having to explicitly + convert it to the list representation for heterogeneous batches. + + Args: + new_densities: FloatTensor of shape (N, dim_density, D, H, W) + new_features: (optional) FloatTensor of shape (N, dim_feature, D, H, W) + + Returns: + Volumes with updated features and densities + """ + new = copy.copy(self) + new._set_densities(new_densities) + if new_features is None: + new._features = None + else: + new._set_features(new_features) + return new + + def _set_features(self, features): + self._set_densities_features(features, "features") + + def _set_densities(self, densities): + self._set_densities_features(densities, "densities") + + def _set_densities_features(self, x, var_name): + x_tensor, grid_sizes = self._convert_densities_features_to_tensor(x, var_name) + if x_tensor.device != self.device: + raise ValueError( + f"`{var_name}` have to be on the same device as `self.densities`." + ) + if len(x_tensor.shape) != 5: + raise ValueError( + f"{var_name} has to be a 5-dim tensor of shape: " + f"(minibatch, {var_name}_dim, height, width, depth)" + ) + + if not ( + (self.get_grid_sizes().shape == grid_sizes.shape) + and torch.allclose(self.get_grid_sizes(), grid_sizes) + ): + raise ValueError( + f"The size of every grid in `{var_name}` has to match the size of" + "the corresponding `densities` grid." + ) + setattr(self, "_" + var_name, x_tensor) + + def _set_local_to_world_transform( + self, + voxel_size: Union[float, torch.Tensor, Tuple, List] = 1.0, + volume_translation: Union[torch.Tensor, Tuple, List] = (0.0, 0.0, 0.0), + ): + """ + Sets the internal representation of the transformation between the + world and local volume coordinates by specifying + `voxel_size` and `volume_translation` + + Args: + **voxel_size**: Denotes the size of input voxels. Has to be one of: + a) A scalar (square voxels) + b) 3-tuple or a 3-list of scalars + c) a Tensor of shape (3,) + d) a Tensor of shape (minibatch, 3) + e) a Tensor of shape (1,) (square voxels) + **volume_translation**: Denotes the 3D translation of the center + of the volume in world units. Has to be one of: + a) 3-tuple or a 3-list of scalars + b) a Tensor of shape (3,) + c) a Tensor of shape (minibatch, 3) + d) a Tensor of shape (1,) (square voxels) + """ + # handle voxel size and center + # here we force the tensors to lie on self.device + voxel_size = self._handle_voxel_size(voxel_size, len(self)) + volume_translation = self._handle_volume_translation( + volume_translation, len(self) + ) + self._local_to_world_transform = self._vsize_translation_to_transform( + voxel_size, volume_translation, len(self) + ) + + def clone(self) -> "Volumes": + """ + Deep copy of Volumes object. 
All internal tensors are cloned + individually. + + Returns: + new Volumes object. + """ + return copy.deepcopy(self) + + def to(self, device, copy: bool = False) -> "Volumes": + """ + Match the functionality of torch.Tensor.to() + If copy = True or the self Tensor is on a different device, the + returned tensor is a copy of self with the desired torch.device. + If copy = False and the self Tensor already has the correct torch.device, + then self is returned. + + Args: + **device**: Device id for the new tensor. + **copy**: Boolean indicator whether or not to clone self. Default False. + + Returns: + Volumes object. + """ + if not copy and self.device == device: + return self + other = self.clone() + if self.device != device: + other.device = device + other._densities = self._densities.to(device) + if self._features is not None: + other._features = self.features().to(device) + other._local_to_world_transform = ( + self.get_local_to_world_coords_transform().to(device) + ) + other._grid_sizes = self._grid_sizes.to(device) + return other + + def cpu(self) -> "Volumes": + return self.to(torch.device("cpu")) + + def cuda(self) -> "Volumes": + return self.to(torch.device("cuda")) diff --git a/tests/test_volumes.py b/tests/test_volumes.py new file mode 100644 index 00000000..fd7bc27e --- /dev/null +++ b/tests/test_volumes.py @@ -0,0 +1,849 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import copy +import itertools +import random +import unittest + +import numpy as np +import torch +from common_testing import TestCaseMixin +from pytorch3d.structures.volumes import Volumes +from pytorch3d.transforms import Scale + + +class TestVolumes(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + np.random.seed(42) + torch.manual_seed(42) + random.seed(42) + + @staticmethod + def _random_volume_list( + num_volumes, min_size, max_size, num_channels, device, rand_sizes=None + ): + """ + Init a list of `num_volumes` random tensors of size [num_channels, *rand_size]. + If `rand_sizes` is None, rand_size is a 3D long vector sampled + from [min_size, max_size]. Otherwise, rand_size should be a list + [rand_size_1, rand_size_2, ..., rand_size_num_volumes] where each + `rand_size_i` denotes the size of the corresponding `i`-th tensor. 
+ """ + if rand_sizes is None: + rand_sizes = [ + [random.randint(min_size, vs) for vs in max_size] + for _ in range(num_volumes) + ] + + volume_list = [ + torch.randn( + size=[num_channels, *rand_size], device=device, dtype=torch.float32 + ) + for rand_size in rand_sizes + ] + + return volume_list, rand_sizes + + def _check_indexed_volumes(self, v, selected, indices): + for selectedIdx, index in indices: + self.assertClose(selected.densities()[selectedIdx], v.densities()[index]) + self.assertClose( + v._local_to_world_transform.get_matrix()[index], + selected._local_to_world_transform.get_matrix()[selectedIdx], + ) + if selected.features() is not None: + self.assertClose(selected.features()[selectedIdx], v.features()[index]) + + def test_get_item( + self, + num_volumes=5, + num_channels=4, + volume_size=(10, 13, 8), + dtype=torch.float32, + ): + + device = torch.device("cuda:0") + + # make sure we have at least 3 volumes to prevent indexing crash + num_volumes = max(num_volumes, 3) + + features = torch.randn( + size=[num_volumes, num_channels, *volume_size], + device=device, + dtype=torch.float32, + ) + densities = torch.randn( + size=[num_volumes, 1, *volume_size], device=device, dtype=torch.float32 + ) + + features_list, rand_sizes = TestVolumes._random_volume_list( + num_volumes, 3, volume_size, num_channels, device + ) + densities_list, _ = TestVolumes._random_volume_list( + num_volumes, 3, volume_size, 1, device, rand_sizes=rand_sizes + ) + + volume_translation = -torch.randn(num_volumes, 3).type_as(features) + voxel_size = torch.rand(num_volumes, 1).type_as(features) + 0.5 + + for features_, densities_ in zip( + (None, features, features_list), (densities, densities, densities_list) + ): + + # init the volume structure + v = Volumes( + features=features_, + densities=densities_, + volume_translation=volume_translation, + voxel_size=voxel_size, + ) + + # int index + index = 1 + v_selected = v[index] + self.assertEqual(len(v_selected), 1) + self._check_indexed_volumes(v, v_selected, [(0, 1)]) + + # list index + index = [1, 2] + v_selected = v[index] + self.assertEqual(len(v_selected), len(index)) + self._check_indexed_volumes(v, v_selected, enumerate(index)) + + # slice index + index = slice(0, 2, 1) + v_selected = v[0:2] + self.assertEqual(len(v_selected), 2) + self._check_indexed_volumes(v, v_selected, [(0, 0), (1, 1)]) + + # bool tensor + index = (torch.rand(num_volumes) > 0.5).to(device) + index[:2] = True # make sure smth is selected + v_selected = v[index] + self.assertEqual(len(v_selected), index.sum()) + self._check_indexed_volumes( + v, + v_selected, + zip( + torch.arange(index.sum()), + torch.nonzero(index, as_tuple=False).squeeze(), + ), + ) + + # int tensor + index = torch.tensor([1, 2], dtype=torch.int64, device=device) + v_selected = v[index] + self.assertEqual(len(v_selected), index.numel()) + self._check_indexed_volumes(v, v_selected, enumerate(index.tolist())) + + # invalid index + index = torch.tensor([1, 0, 1], dtype=torch.float32, device=device) + with self.assertRaises(IndexError): + v_selected = v[index] + index = 1.2 # floating point index + with self.assertRaises(IndexError): + v_selected = v[index] + + def test_coord_transforms(self, num_volumes=3, num_channels=4, dtype=torch.float32): + """ + Test the correctness of the conversion between the internal + Transform3D Volumes._local_to_world_transform and the initialization + from the translation and voxel_size. 
+        """
+
+        device = torch.device("cuda:0")
+
+        # try for 10 sets of different random sizes/centers/voxel_sizes
+        for _ in range(10):
+
+            size = torch.randint(high=10, size=(3,), low=3).tolist()
+
+            densities = torch.randn(
+                size=[num_volumes, num_channels, *size],
+                device=device,
+                dtype=torch.float32,
+            )
+
+            # init the transformation params
+            volume_translation = torch.randn(num_volumes, 3)
+            voxel_size = torch.rand(num_volumes, 3) * 3.0 + 0.5
+
+            # get the corresponding Transform3d object
+            local_offset = torch.tensor(list(size), dtype=torch.float32, device=device)[
+                [2, 1, 0]
+            ][None].repeat(num_volumes, 1)
+            local_to_world_transform = (
+                Scale(0.5 * local_offset - 0.5, device=device)
+                .scale(voxel_size)
+                .translate(-volume_translation)
+            )
+
+            # init the volume structures with the scale and translation,
+            # then get the coord grid in world coords
+            v_trans_vs = Volumes(
+                densities=densities,
+                voxel_size=voxel_size,
+                volume_translation=volume_translation,
+            )
+            grid_rot_trans_vs = v_trans_vs.get_coord_grid(world_coordinates=True)
+
+            # map the default local coords to the world coords
+            # with local_to_world_transform
+            v_default = Volumes(densities=densities)
+            grid_default_local = v_default.get_coord_grid(world_coordinates=False)
+            grid_default_world = local_to_world_transform.transform_points(
+                grid_default_local.view(num_volumes, -1, 3)
+            ).view(num_volumes, *size, 3)
+
+            # check that both grids are the same
+            self.assertClose(grid_rot_trans_vs, grid_default_world, atol=1e-5)
+
+            # check that the transformations are the same
+            self.assertClose(
+                v_trans_vs.get_local_to_world_coords_transform().get_matrix(),
+                local_to_world_transform.get_matrix(),
+                atol=1e-5,
+            )
+
+    def test_coord_grid_convention(
+        self, num_volumes=3, num_channels=4, dtype=torch.float32
+    ):
+        """
+        Check that for a trivial volume with spatial size DxHxW=5x7x5:
+        1) xyz_world=(0, 0, 0) lands right in the middle of the volume
+           with xyz_local=(0, 0, 0).
+        2) xyz_world=(-2, 3, -2) results in xyz_local=(-1, 1, -1).
+        3) The central voxel of the volume coordinate grid
+           has coords x_world=(0, 0, 0) and x_local=(0, 0, 0).
+        4) grid_sampler(world_coordinate_grid, local_coordinate_grid)
+           is the same as world_coordinate_grid itself. I.e. the local coordinate
+           grid matches the grid_sampler coordinate convention.
+        """
+
+        device = torch.device("cuda:0")
+
+        densities = torch.randn(
+            size=[num_volumes, num_channels, 5, 7, 5],
+            device=device,
+            dtype=torch.float32,
+        )
+        v_trivial = Volumes(densities=densities)
+
+        # check the case with x_world=(0,0,0)
+        pts_world = torch.zeros(num_volumes, 1, 3, device=device, dtype=torch.float32)
+        pts_local = v_trivial.world_to_local_coords(pts_world)
+        pts_local_expected = torch.zeros_like(pts_local)
+        self.assertClose(pts_local, pts_local_expected)
+
+        # check the case with x_world=(-2, 3, -2)
+        pts_world = torch.tensor([-2, 3, -2], device=device, dtype=torch.float32)[
+            None, None
+        ].repeat(num_volumes, 1, 1)
+        pts_local = v_trivial.world_to_local_coords(pts_world)
+        pts_local_expected = torch.tensor(
+            [-1, 1, -1], device=device, dtype=torch.float32
+        )[None, None].repeat(num_volumes, 1, 1)
+        self.assertClose(pts_local, pts_local_expected)
+
+        # check that the central voxel has coords x_world=(0, 0, 0) and x_local=(0, 0, 0)
+        grid_world = v_trivial.get_coord_grid(world_coordinates=True)
+        grid_local = v_trivial.get_coord_grid(world_coordinates=False)
+        for grid in (grid_world, grid_local):
+            x0 = grid[0, :, :, 2, 0]
+            y0 = grid[0, :, 3, :, 1]
+            z0 = grid[0, 2, :, :, 2]
+            for coord_line in (x0, y0, z0):
+                self.assertClose(coord_line, torch.zeros_like(coord_line), atol=1e-7)
+
+        # resample grid_world using grid_sampler with local coords
+        # -> make sure the resampled version is the same as original
+        grid_world_resampled = torch.nn.functional.grid_sample(
+            grid_world.permute(0, 4, 1, 2, 3), grid_local, align_corners=True
+        ).permute(0, 2, 3, 4, 1)
+        self.assertClose(grid_world_resampled, grid_world, atol=1e-7)
+
+    def test_coord_grid_convention_heterogeneous(
+        self, num_channels=4, dtype=torch.float32
+    ):
+        """
+        Check that for a list of 2 trivial volumes with
+        spatial sizes DxHxW=(5x7x5, 3x5x5):
+        1) xyz_world=(0, 0, 0) lands right in the middle of the volume
+           with xyz_local=(0, 0, 0).
+        2) xyz_world=((-2, 3, -2), (-2, -2, 1)) results
+           in xyz_local=((-1, 1, -1), (-1, -1, 1)).
+        3) The central voxel of the volume coordinate grid
+           has coords x_world=(0, 0, 0) and x_local=(0, 0, 0).
+        4) grid_sampler(world_coordinate_grid, local_coordinate_grid)
+           is the same as world_coordinate_grid itself. I.e. the local coordinate
+           grid matches the grid_sampler coordinate convention.
+ """ + + device = torch.device("cuda:0") + + sizes = [(5, 7, 5), (3, 5, 5)] + + densities_list = [ + torch.randn(size=[num_channels, *size], device=device, dtype=torch.float32) + for size in sizes + ] + + # init the volume + v_trivial = Volumes(densities=densities_list) + + # check the border point locations + pts_world = torch.tensor( + [[-2.0, 3.0, -2.0], [-2.0, -2.0, 1.0]], device=device, dtype=torch.float32 + )[:, None] + pts_local = v_trivial.world_to_local_coords(pts_world) + pts_local_expected = torch.tensor( + [[-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0]], device=device, dtype=torch.float32 + )[:, None] + self.assertClose(pts_local, pts_local_expected) + + # check that the central voxel has coords x_world=(0, 0, 0) and x_local(0, 0, 0) + grid_world = v_trivial.get_coord_grid(world_coordinates=True) + grid_local = v_trivial.get_coord_grid(world_coordinates=False) + for grid in (grid_world, grid_local): + x0 = grid[0, :, :, 2, 0] + y0 = grid[0, :, 3, :, 1] + z0 = grid[0, 2, :, :, 2] + for coord_line in (x0, y0, z0): + self.assertClose(coord_line, torch.zeros_like(coord_line), atol=1e-7) + x0 = grid[1, :, :, 2, 0] + y0 = grid[1, :, 2, :, 1] + z0 = grid[1, 1, :, :, 2] + for coord_line in (x0, y0, z0): + self.assertClose(coord_line, torch.zeros_like(coord_line), atol=1e-7) + + # resample grid_world using grid_sampler with local coords + # -> make sure the resampled version is the same as original + for grid_world_, grid_local_, size in zip(grid_world, grid_local, sizes): + grid_world_crop = grid_world_[: size[0], : size[1], : size[2], :][None] + grid_local_crop = grid_local_[: size[0], : size[1], : size[2], :][None] + grid_world_crop_resampled = torch.nn.functional.grid_sample( + grid_world_crop.permute(0, 4, 1, 2, 3), + grid_local_crop, + align_corners=True, + ).permute(0, 2, 3, 4, 1) + self.assertClose(grid_world_crop_resampled, grid_world_crop, atol=1e-7) + + def test_coord_grid_transforms( + self, num_volumes=3, num_channels=4, dtype=torch.float32 + ): + """ + Test whether conversion between local-world coordinates of the + volume returns correct results. 
+ """ + + device = torch.device("cuda:0") + + # try for 10 sets of different random sizes/centers/voxel_sizes + for _ in range(10): + + size = torch.randint(high=10, size=(3,), low=3).tolist() + + center = torch.randn(num_volumes, 3, dtype=torch.float32, device=device) + voxel_size = torch.rand(1, dtype=torch.float32, device=device) * 5.0 + 0.5 + + for densities in ( + torch.randn( + size=[num_volumes, num_channels, *size], + device=device, + dtype=torch.float32, + ), + TestVolumes._random_volume_list( + num_volumes, 3, size, num_channels, device, rand_sizes=None + )[0], + ): + + # init the volume structure + v = Volumes( + densities=densities, + voxel_size=voxel_size, + volume_translation=-center, + ) + + # get local coord grid + grid_local = v.get_coord_grid(world_coordinates=False) + + # convert from world to local to world + grid_world = v.get_coord_grid(world_coordinates=True) + grid_local_2 = v.world_to_local_coords(grid_world) + grid_world_2 = v.local_to_world_coords(grid_local_2) + + # assertions on shape and values of grid_world and grid_local + self.assertClose(grid_world, grid_world_2, atol=1e-5) + self.assertClose(grid_local, grid_local_2, atol=1e-5) + + # check that the individual slices of the location grid have + # constant values along expected dimensions + for plane_dim in (1, 2, 3): + for grid_plane in grid_world.split(1, dim=plane_dim): + grid_coord_dim = {1: 2, 2: 1, 3: 0}[plane_dim] + grid_coord_plane = grid_plane.squeeze()[..., grid_coord_dim] + # check that all elements of grid_coord_plane are + # the same for each batch element + self.assertClose( + grid_coord_plane.reshape(num_volumes, -1).max(dim=1).values, + grid_coord_plane.reshape(num_volumes, -1).min(dim=1).values, + ) + + def test_clone( + self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32 + ): + """ + Test cloning of a `Volumes` object + """ + + device = torch.device("cuda:0") + + features = torch.randn( + size=[num_volumes, num_channels, *size], device=device, dtype=torch.float32 + ) + densities = torch.rand( + size=[num_volumes, 1, *size], device=device, dtype=torch.float32 + ) + + for has_features in (True, False): + v = Volumes( + densities=densities, features=features if has_features else None + ) + vnew = v.clone() + vnew._densities.data[0, 0, 0, 0, 0] += 1.0 + self.assertNotAlmostEqual( + float( + (vnew.densities()[0, 0, 0, 0, 0] - v.densities()[0, 0, 0, 0, 0]) + .abs() + .max() + ), + 0.0, + ) + + if has_features: + vnew._features.data[0, 0, 0, 0, 0] += 1.0 + self.assertNotAlmostEqual( + float( + (vnew.features()[0, 0, 0, 0, 0] - v.features()[0, 0, 0, 0, 0]) + .abs() + .max() + ), + 0.0, + ) + + def _check_vars_on_device(self, v, desired_device): + for var_name, var in vars(v).items(): + if var_name != "device": + if var is not None: + self.assertTrue(var.device.type == desired_device.type) + else: + self.assertTrue(var.type == desired_device.type) + + def test_to( + self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32 + ): + """ + Test the moving of the volumes from/to gpu and cpu + """ + + device = torch.device("cuda:0") + device_cpu = torch.device("cpu") + + features = torch.randn( + size=[num_volumes, num_channels, *size], device=device, dtype=torch.float32 + ) + densities = torch.rand(size=[num_volumes, 1, *size], device=device, dtype=dtype) + + for features_ in (features, None): + v = Volumes(densities=densities, features=features_) + + v_cpu = v.cpu() + v_cuda = v_cpu.cuda() + v_cuda_2 = v_cuda.cuda() + v_cpu_2 = v_cuda_2.cpu() + + for v1, v2 in 
itertools.combinations( + (v, v_cpu, v_cpu_2, v_cuda, v_cuda_2), 2 + ): + if v1 is v_cuda and v2 is v_cuda_2: + # checks that we do not copy if the devices stay the same + assert_fun = self.assertIs + else: + assert_fun = self.assertSeparate + assert_fun(v1._densities, v2._densities) + if features_ is not None: + assert_fun(v1._features, v2._features) + for v_ in (v1, v2): + if v_ in (v_cpu, v_cpu_2): + self._check_vars_on_device(v_, device_cpu) + else: + self._check_vars_on_device(v_, device) + + def _check_padded(self, x_pad, x_list, grid_sizes): + """ + Check that padded tensors x_pad are the same as x_list tensors. + """ + num_volumes = len(x_list) + for i in range(num_volumes): + self.assertClose( + x_pad[i][:, : grid_sizes[i][0], : grid_sizes[i][1], : grid_sizes[i][2]], + x_list[i], + ) + + def test_feature_density_setters(self): + """ + Tests getters and setters for padded/list representations. + """ + + device = torch.device("cuda:0") + diff_device = torch.device("cpu") + + num_volumes = 30 + num_channels = 4 + K = 20 + + densities = [] + features = [] + grid_sizes = [] + diff_grid_sizes = [] + + for _ in range(num_volumes): + grid_size = torch.randint(K - 1, size=(3,)).long() + 1 + densities.append( + torch.rand((1, *grid_size), device=device, dtype=torch.float32) + ) + features.append( + torch.rand( + (num_channels, *grid_size), device=device, dtype=torch.float32 + ) + ) + grid_sizes.append(grid_size) + + diff_grid_size = ( + copy.deepcopy(grid_size) + torch.randint(2, size=(3,)).long() + 1 + ) + diff_grid_sizes.append(diff_grid_size) + grid_sizes = torch.stack(grid_sizes).to(device) + diff_grid_sizes = torch.stack(diff_grid_sizes).to(device) + + volumes = Volumes(densities=densities, features=features) + self.assertClose(volumes.get_grid_sizes(), grid_sizes) + + # test the getters + features_padded = volumes.features() + densities_padded = volumes.densities() + features_list = volumes.features_list() + densities_list = volumes.densities_list() + for x_pad, x_list in zip( + (densities_padded, features_padded, densities_padded, features_padded), + (densities_list, features_list, densities, features), + ): + self._check_padded(x_pad, x_list, grid_sizes) + + # test feature setters + features_new = [ + torch.rand((num_channels, *grid_size), device=device, dtype=torch.float32) + for grid_size in grid_sizes + ] + volumes._set_features(features_new) + features_new_list = volumes.features_list() + features_new_padded = volumes.features() + for x_pad, x_list in zip( + (features_new_padded, features_new_padded), + (features_new, features_new_list), + ): + self._check_padded(x_pad, x_list, grid_sizes) + + # wrong features to update + bad_features_new = [ + [ + torch.rand( + (num_channels, *grid_size), device=diff_device, dtype=torch.float32 + ) + for grid_size in diff_grid_sizes + ], + torch.rand( + (num_volumes, num_channels, K + 1, K, K), + device=device, + dtype=torch.float32, + ), + None, + ] + for bad_features_new_ in bad_features_new: + with self.assertRaises(ValueError): + volumes._set_densities(bad_features_new_) + + # test density setters + densities_new = [ + torch.rand((1, *grid_size), device=device, dtype=torch.float32) + for grid_size in grid_sizes + ] + volumes._set_densities(densities_new) + densities_new_list = volumes.densities_list() + densities_new_padded = volumes.densities() + for x_pad, x_list in zip( + (densities_new_padded, densities_new_padded), + (densities_new, densities_new_list), + ): + self._check_padded(x_pad, x_list, grid_sizes) + + # wrong densities to 
update + bad_densities_new = [ + [ + torch.rand((1, *grid_size), device=diff_device, dtype=torch.float32) + for grid_size in diff_grid_sizes + ], + torch.rand( + (num_volumes, 1, K + 1, K, K), device=device, dtype=torch.float32 + ), + None, + ] + for bad_densities_new_ in bad_densities_new: + with self.assertRaises(ValueError): + volumes._set_densities(bad_densities_new_) + + # test update_padded + volumes = Volumes(densities=densities, features=features) + volumes_updated = volumes.update_padded( + densities_new, new_features=features_new + ) + densities_new_list = volumes_updated.densities_list() + densities_new_padded = volumes_updated.densities() + features_new_list = volumes_updated.features_list() + features_new_padded = volumes_updated.features() + for x_pad, x_list in zip( + ( + densities_new_padded, + densities_new_padded, + features_new_padded, + features_new_padded, + ), + (densities_new, densities_new_list, features_new, features_new_list), + ): + self._check_padded(x_pad, x_list, grid_sizes) + self.assertIs(volumes.get_grid_sizes(), volumes_updated.get_grid_sizes()) + self.assertIs( + volumes.get_local_to_world_coords_transform(), + volumes_updated.get_local_to_world_coords_transform(), + ) + self.assertIs(volumes.device, volumes_updated.device) + + def test_constructor_for_padded_lists(self): + """ + Tests constructor for padded/list representations. + """ + + device = torch.device("cuda:0") + diff_device = torch.device("cpu") + + num_volumes = 3 + num_channels = 4 + size = (6, 8, 10) + diff_size = (6, 8, 11) + + # good ways to define densities + ok_densities = [ + torch.randn( + size=[num_volumes, 1, *size], device=device, dtype=torch.float32 + ).unbind(0), + torch.randn( + size=[num_volumes, 1, *size], device=device, dtype=torch.float32 + ), + ] + + # bad ways to define features + bad_features = [ + torch.randn( + size=[num_volumes + 1, num_channels, *size], + device=device, + dtype=torch.float32, + ).unbind( + 0 + ), # list with diff batch size + torch.randn( + size=[num_volumes + 1, num_channels, *size], + device=device, + dtype=torch.float32, + ), # diff batch size + torch.randn( + size=[num_volumes, num_channels, *diff_size], + device=device, + dtype=torch.float32, + ).unbind( + 0 + ), # list with different size + torch.randn( + size=[num_volumes, num_channels, *diff_size], + device=device, + dtype=torch.float32, + ), # different size + torch.randn( + size=[num_volumes, num_channels, *size], + device=diff_device, + dtype=torch.float32, + ), # different device + torch.randn( + size=[num_volumes, num_channels, *size], + device=diff_device, + dtype=torch.float32, + ).unbind( + 0 + ), # list with different device + ] + + # good ways to define features + ok_features = [ + torch.randn( + size=[num_volumes, num_channels, *size], + device=device, + dtype=torch.float32, + ).unbind( + 0 + ), # list of features of correct size + torch.randn( + size=[num_volumes, num_channels, *size], + device=device, + dtype=torch.float32, + ), + ] + + for densities in ok_densities: + for features in bad_features: + self.assertRaises( + ValueError, Volumes, densities=densities, features=features + ) + for features in ok_features: + Volumes(densities=densities, features=features) + + def test_constructor( + self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32 + ): + """ + Test different ways of calling the `Volumes` constructor + """ + + device = torch.device("cuda:0") + + # all ways to define features + features = [ + torch.randn( + size=[num_volumes, num_channels, *size], + 
device=device, + dtype=torch.float32, + ), # padded tensor + torch.randn( + size=[num_volumes, num_channels, *size], + device=device, + dtype=torch.float32, + ).unbind( + 0 + ), # list of features + None, # no features + ] + + # bad ways to define features + bad_features = [ + torch.randn( + size=[num_volumes, num_channels, 2, *size], + device=device, + dtype=torch.float32, + ), # 6 dims + torch.randn( + size=[num_volumes, *size], device=device, dtype=torch.float32 + ), # 4 dims + torch.randn( + size=[num_volumes, *size], device=device, dtype=torch.float32 + ).unbind( + 0 + ), # list of 4 dim tensors + ] + + # all ways to define densities + densities = [ + torch.randn( + size=[num_volumes, 1, *size], device=device, dtype=torch.float32 + ), # padded tensor + torch.randn( + size=[num_volumes, 1, *size], device=device, dtype=torch.float32 + ).unbind( + 0 + ), # list of densities + ] + + # bad ways to define densities + bad_densities = [ + None, # omitted + torch.randn( + size=[num_volumes, 1, 1, *size], device=device, dtype=torch.float32 + ), # 6-dim tensor + torch.randn( + size=[num_volumes, 1, 1, *size], device=device, dtype=torch.float32 + ).unbind( + 0 + ), # list of 5-dim densities + ] + + # all possible ways to define the voxels sizes + vox_sizes = [ + torch.Tensor([1.0, 1.0, 1.0]), + [1.0, 1.0, 1.0], + torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes, 1), + torch.Tensor([1.0])[None].repeat(num_volumes, 1), + 1.0, + torch.Tensor([1.0]), + ] + + # all possible ways to define the volume translations + vol_translations = [ + torch.Tensor([1.0, 1.0, 1.0]), + [1.0, 1.0, 1.0], + torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes, 1), + ] + + # wrong ways to define voxel sizes + bad_vox_sizes = [ + torch.Tensor([1.0, 1.0, 1.0, 1.0]), + [1.0, 1.0, 1.0, 1.0], + torch.Tensor([]), + None, + ] + + # wrong ways to define the volume translations + bad_vol_translations = [ + torch.Tensor([1.0, 1.0]), + [1.0, 1.0], + 1.0, + torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes + 1, 1), + ] + + def zip_with_ok_indicator(good, bad): + return zip([*good, *bad], [*([True] * len(good)), *([False] * len(bad))]) + + for features_, features_ok in zip_with_ok_indicator(features, bad_features): + for densities_, densities_ok in zip_with_ok_indicator( + densities, bad_densities + ): + for vox_size, size_ok in zip_with_ok_indicator( + vox_sizes, bad_vox_sizes + ): + for vol_translation, trans_ok in zip_with_ok_indicator( + vol_translations, bad_vol_translations + ): + if ( + size_ok and trans_ok and features_ok and densities_ok + ): # if all entries are good we check that this doesnt throw + Volumes( + features=features_, + densities=densities_, + voxel_size=vox_size, + volume_translation=vol_translation, + ) + + else: # otherwise we check for ValueError + self.assertRaises( + ValueError, + Volumes, + features=features_, + densities=densities_, + voxel_size=vox_size, + volume_translation=vol_translation, + )
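Below the patch, a minimal usage sketch of the API it adds (this is an illustrative note, not part of the diff); the grid size, `voxel_size` value, and test point are assumptions chosen so the local/world conversion can be checked by hand against the formulas in the `Volumes` class docstring:

```
import torch
from pytorch3d.structures import Volumes

# two volumes with 1 density channel, 3 feature channels and a 5x5x5 grid (D, H, W)
densities = torch.rand(2, 1, 5, 5, 5)
features = torch.rand(2, 3, 5, 5, 5)

v = Volumes(
    densities=densities,
    features=features,
    voxel_size=2.0,  # each voxel spans 2 world units
    volume_translation=(0.0, 0.0, 0.0),
)

# world-coordinate centers of all voxels: shape (2, 5, 5, 5, 3)
grid_world = v.get_coord_grid(world_coordinates=True)

# round-trip a point through the world <-> local mapping:
# x_local = ((x_world + volume_translation) / (0.5 * voxel_size)) / (volume_size - 1)
pts_world = torch.tensor([[[4.0, 0.0, -4.0]]]).repeat(2, 1, 1)
pts_local = v.world_to_local_coords(pts_world)  # -> [[1.0, 0.0, -1.0]] per volume
pts_back = v.local_to_world_coords(pts_local)
assert torch.allclose(pts_back, pts_world, atol=1e-5)

# for a "trivial" volume (voxel_size=1, no translation), sampling the densities
# at the local coordinate grid reproduces the densities, as stated in the docstring
v_trivial = Volumes(densities=densities)
resampled = torch.nn.functional.grid_sample(
    v_trivial.densities(),
    v_trivial.get_coord_grid(world_coordinates=False),
    align_corners=True,
)
assert torch.allclose(resampled, v_trivial.densities(), atol=1e-5)
```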