mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Volumes data structure.
Summary: Implemented a data structure for volumes. Reviewed By: gkioxari Differential Revision: D20342920 fbshipit-source-id: ccc23eaa183ed8a4e9cd7674b4dcf31e8a65c3c6
This commit is contained in:
parent
1e4a2e8624
commit
03ee1dbf82
@ -3,6 +3,7 @@
|
||||
from .meshes import Meshes, join_meshes_as_batch, join_meshes_as_scene
|
||||
from .pointclouds import Pointclouds
|
||||
from .utils import list_to_packed, list_to_padded, packed_to_list, padded_to_list
|
||||
from .volumes import Volumes
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
703
pytorch3d/structures/volumes.py
Normal file
703
pytorch3d/structures/volumes.py
Normal file
@ -0,0 +1,703 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
import copy
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..transforms import Scale, Transform3d
|
||||
from . import utils as struct_utils
|
||||
|
||||
|
||||
class Volumes(object):
|
||||
"""
|
||||
This class provides functions for working with batches of volumetric grids
|
||||
of possibly varying spatial sizes.
|
||||
|
||||
VOLUME DENSITIES
|
||||
|
||||
The Volumes class can be either constructed from a 5D tensor of
|
||||
`densities` of size `batch x density_dim x depth x height x width` or
|
||||
from a list of differently-sized 4D tensors `[D_1, ..., D_batch]`,
|
||||
where each `D_i` is of size `[density_dim x depth_i x height_i x width_i]`.
|
||||
|
||||
In case the `Volumes` object is initialized from the list of `densities`,
|
||||
the list of tensors is internally converted to a single 5D tensor by
|
||||
zero-padding the relevant dimensions. Both list and padded representations can be
|
||||
accessed with the `Volumes.densities()` or `Volumes.densities_list()` getters.
|
||||
The sizes of the individual volumes in the structure can be retrieved
|
||||
with the `Volumes.get_grid_sizes()` getter.
|
||||
|
||||
The `Volumes` class is immutable. I.e. after generating a `Volumes` object,
|
||||
one cannot change its properties, such as `self._densities` or `self._features`
|
||||
anymore.
|
||||
|
||||
|
||||
VOLUME FEATURES
|
||||
|
||||
While the `densities` field is intended to represent various measures of the
|
||||
"density" of the volume cells (opacity, signed/unsigned distances
|
||||
from the nearest surface, ...), one can additionally initialize the
|
||||
object with the `features` argument. `features` are either a 5D tensor
|
||||
of shape `batch x feature_dim x depth x height x width` or a list of
|
||||
of differently-sized 4D tensors `[F_1, ..., F_batch]`,
|
||||
where each `F_i` is of size `[feature_dim x depth_i x height_i x width_i]`.
|
||||
`features` are intended to describe other properties of volume cells,
|
||||
such as per-voxel 3D vectors of RGB colors that can be later used
|
||||
for rendering the volume.
|
||||
|
||||
|
||||
VOLUME COORDINATES
|
||||
|
||||
Additionally, the `Volumes` class keeps track of the locations of the
|
||||
centers of the volume cells in the local volume coordinates as well as in
|
||||
the world coordinates.
|
||||
|
||||
Local coordinates:
|
||||
- Represent the locations of the volume cells in the local coordinate
|
||||
frame of the volume.
|
||||
- The center of the voxel indexed with `[·, ·, 0, 0, 0]` in the volume
|
||||
has its 3D local coordinate set to `[-1, -1, -1]`, while the voxel
|
||||
at index `[·, ·, depth_i-1, height_i-1, width_i-1]` has its
|
||||
3D local coordinate set to `[1, 1, 1]`.
|
||||
- The first/second/third coordinate of each of the 3D per-voxel
|
||||
XYZ vector denotes the horizontal/vertical/depth-wise position
|
||||
respectively. I.e the order of the coordinate dimensions in the
|
||||
volume is reversed w.r.t. the order of the 3D coordinate vectors.
|
||||
- The intermediate coordinates between `[-1, -1, -1]` and `[1, 1, 1]`.
|
||||
are linearly interpolated over the spatial dimensions of the volume.
|
||||
- Note that the convention is the same as for the 5D version of the
|
||||
`torch.nn.functional.grid_sample` function called with
|
||||
`align_corners==True`.
|
||||
- Note that the local coordinate convention of `Volumes`
|
||||
(+X = left to right, +Y = top to bottom, +Z = away from the user)
|
||||
is *different* from the world coordinate convention of the
|
||||
renderer for `Meshes` or `Pointclouds`
|
||||
(+X = right to left, +Y = bottom to top, +Z = away from the user).
|
||||
|
||||
World coordinates:
|
||||
- These define the locations of the centers of the volume cells
|
||||
in the world coordinates.
|
||||
- They are specifiied with the following mapping that converts
|
||||
points `x_local` in the local coordinates to points `x_world`
|
||||
in the world coordinates:
|
||||
```
|
||||
x_world = (
|
||||
x_local * (volume_size - 1) * 0.5 * voxel_size
|
||||
) - volume_translation,
|
||||
```
|
||||
here `voxel_size` specifies the size of each voxel of the volume,
|
||||
and `volume_translation` is the 3D offset of the central voxel of
|
||||
the volume w.r.t. the origin of the world coordinate frame.
|
||||
Both `voxel_size` and `volume_translation` are specified in
|
||||
the world coordinate units. `volume_size` is the spatial size of
|
||||
the volume in form of a 3D vector `[width, height, depth]`.
|
||||
- Given the above definition of `x_world`, one can derive the
|
||||
inverse mapping from `x_world` to `x_local` as follows:
|
||||
```
|
||||
x_local = (
|
||||
(x_world + volume_translation) / (0.5 * voxel_size)
|
||||
) / (volume_size - 1)
|
||||
```
|
||||
- For a trivial volume with `volume_translation==[0, 0, 0]`
|
||||
with `voxel_size=-1`, `x_world` would range
|
||||
from -(volume_size-1)/2` to `+(volume_size-1)/2`.
|
||||
|
||||
Coordinate tensors that denote the locations of each of the volume cells in
|
||||
local / world coordinates (with shape `(depth x height x width x 3)`)
|
||||
can be retrieved by calling the `Volumes.get_coord_grid()` getter with the
|
||||
appropriate `world_coordinates` argument.
|
||||
|
||||
Internally, the mapping between `x_local` and `x_world` is represented
|
||||
as a `Transform3D` object `Volumes._local_to_world_transform`.
|
||||
Users can access the relevant transformations with the
|
||||
`Volumes.get_world_to_local_coords_transform()` and
|
||||
`Volumes.get_local_to_world_coords_transform()`
|
||||
functions.
|
||||
|
||||
Example coordinate conversion:
|
||||
- For a "trivial" volume with `voxel_size = 1.`,
|
||||
`volume_translation=[0., 0., 0.]`, and the spatial size of
|
||||
`DxHxW = 5x5x5`, the point `x_world = (-2, 0, 2)` gets mapped
|
||||
to `x_local=(-1, 0, 1)`.
|
||||
- For a "trivial" volume `v` with `voxel_size = 1.`,
|
||||
`volume_translation=[0., 0., 0.]`, the following holds:
|
||||
```
|
||||
torch.nn.functional.grid_sample(
|
||||
v.densities(),
|
||||
v.get_coord_grid(world_coordinates=False),
|
||||
align_corners=True,
|
||||
) == v.densities(),
|
||||
```
|
||||
i.e. sampling the volume at trivial local coordinates
|
||||
(no scaling with `voxel_size`` or shift with `volume_translation`)
|
||||
results in the same volume.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
densities: Union[List[torch.Tensor], torch.Tensor],
|
||||
features: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
voxel_size: Union[float, torch.Tensor, Tuple, List] = 1.0,
|
||||
volume_translation: Union[torch.Tensor, Tuple, List] = (0.0, 0.0, 0.0),
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
**densities**: Batch of input feature volume occupancies of shape
|
||||
`(minibatch, density_dim, depth, height, width)`, or a list
|
||||
of 4D tensors `[D_1, ..., D_minibatch]` where each `D_i` has
|
||||
shape `(density_dim, depth_i, height_i, width_i)`.
|
||||
Typically, each voxel contains a non-negative number
|
||||
corresponding to its opaqueness.
|
||||
**features**: Batch of input feature volumes of shape:
|
||||
`(minibatch, feature_dim, depth, height, width)` or a list
|
||||
of 4D tensors `[F_1, ..., F_minibatch]` where each `F_i` has
|
||||
shape `(feature_dim, depth_i, height_i, width_i)`.
|
||||
The field is optional and can be set to `None` in case features are
|
||||
not required.
|
||||
**voxel_size**: Denotes the size of each volume voxel in world units.
|
||||
Has to be one of:
|
||||
a) A scalar (square voxels)
|
||||
b) 3-tuple or a 3-list of scalars
|
||||
c) a Tensor of shape (3,)
|
||||
d) a Tensor of shape (minibatch, 3)
|
||||
e) a Tensor of shape (minibatch, 1)
|
||||
f) a Tensor of shape (1,) (square voxels)
|
||||
**volume_translation**: Denotes the 3D translation of the center
|
||||
of the volume in world units. Has to be one of:
|
||||
a) 3-tuple or a 3-list of scalars
|
||||
b) a Tensor of shape (3,)
|
||||
c) a Tensor of shape (minibatch, 3)
|
||||
d) a Tensor of shape (1,) (square voxels)
|
||||
"""
|
||||
|
||||
# handle densities
|
||||
densities, grid_sizes = self._convert_densities_features_to_tensor(
|
||||
densities, "densities"
|
||||
)
|
||||
|
||||
# take device from densities
|
||||
self.device = densities.device
|
||||
|
||||
# assign to the internal buffers
|
||||
self._densities = densities
|
||||
self._grid_sizes = grid_sizes
|
||||
|
||||
# handle features
|
||||
self._features = None
|
||||
if features is not None:
|
||||
self._set_features(features)
|
||||
|
||||
# set the local_to_world transform
|
||||
self._set_local_to_world_transform(
|
||||
voxel_size=voxel_size, volume_translation=volume_translation
|
||||
)
|
||||
|
||||
def _convert_densities_features_to_tensor(
|
||||
self, x: Union[List[torch.Tensor], torch.Tensor], var_name: str
|
||||
):
|
||||
"""
|
||||
Handle the `densities` or `features` arguments to the constructor.
|
||||
"""
|
||||
if isinstance(x, (list, tuple)):
|
||||
x_tensor = struct_utils.list_to_padded(x)
|
||||
if any(x_.ndim != 4 for x_ in x):
|
||||
raise ValueError(
|
||||
f"`{var_name}` has to be a list of 4-dim tensors of shape: "
|
||||
f"({var_name}_dim, height, width, depth)"
|
||||
)
|
||||
if any(x_.shape[0] != x[0].shape[0] for x_ in x):
|
||||
raise ValueError(
|
||||
f"Each entry in the list of `{var_name}` has to have the "
|
||||
"same number of channels (first dimension in the tensor)."
|
||||
)
|
||||
x_shapes = torch.stack(
|
||||
[
|
||||
torch.tensor(
|
||||
list(x_.shape[1:]), dtype=torch.long, device=x_tensor.device
|
||||
)
|
||||
for x_ in x
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
elif torch.is_tensor(x):
|
||||
if x.ndim != 5:
|
||||
raise ValueError(
|
||||
f"`{var_name}` has to be a 5-dim tensor of shape: "
|
||||
f"(minibatch, {var_name}_dim, height, width, depth)"
|
||||
)
|
||||
x_tensor = x
|
||||
x_shapes = torch.tensor(
|
||||
list(x.shape[2:]), dtype=torch.long, device=x.device
|
||||
)[None].repeat(x.shape[0], 1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{var_name} must be either a list or a tensor with "
|
||||
f"shape (batch_size, {var_name}_dim, H, W, D)."
|
||||
)
|
||||
return x_tensor, x_shapes
|
||||
|
||||
def _vsize_translation_to_transform(
|
||||
self, voxel_size, volume_translation, batch_size
|
||||
):
|
||||
"""
|
||||
Converts the `voxel_size` and `volume_translation` constructor arguments
|
||||
to the internal `Transform3D` object `local_to_world_transform`.
|
||||
"""
|
||||
volume_size_zyx = self.get_grid_sizes().float()
|
||||
volume_size_xyz = volume_size_zyx[:, [2, 1, 0]]
|
||||
|
||||
# x_local = (
|
||||
# (x_world + volume_translation) / (0.5 * voxel_size)
|
||||
# ) / (volume_size - 1)
|
||||
|
||||
# x_world = (
|
||||
# x_local * (volume_size - 1) * 0.5 * voxel_size
|
||||
# ) - volume_translation
|
||||
|
||||
local_to_world_transform = Scale(
|
||||
(volume_size_xyz - 1) * voxel_size * 0.5, device=self.device
|
||||
).translate(-volume_translation)
|
||||
|
||||
return local_to_world_transform
|
||||
|
||||
def _handle_voxel_size(self, voxel_size, batch_size):
|
||||
"""
|
||||
Handle the `voxel_size` argument to the `Volumes` constructor.
|
||||
"""
|
||||
err_msg = (
|
||||
"voxel_size has to be either a 3-tuple of scalars, or a scalar, or"
|
||||
" a torch.Tensor of shape (3,) or (1,) or (minibatch, 3) or (minibatch, 1)."
|
||||
)
|
||||
if isinstance(voxel_size, (float, int)):
|
||||
# convert a scalar to a 3-element tensor
|
||||
voxel_size = (voxel_size,) * 3
|
||||
if torch.is_tensor(voxel_size):
|
||||
if voxel_size.numel() == 1:
|
||||
# convert a single-element tensor to a 3-element one
|
||||
voxel_size = voxel_size.view(-1).repeat(3)
|
||||
elif len(voxel_size.shape) == 2 and (
|
||||
voxel_size.shape[0] == batch_size and voxel_size.shape[1] == 1
|
||||
):
|
||||
voxel_size = voxel_size.repeat(1, 3)
|
||||
return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg)
|
||||
|
||||
def _handle_volume_translation(self, translation, batch_size):
|
||||
"""
|
||||
Handle the `volume_translation` argument to the `Volumes` constructor.
|
||||
"""
|
||||
err_msg = (
|
||||
"`volume_translation` has to be either a 3-tuple of scalars, or"
|
||||
" a Tensor of shape (1,3) or (minibatch, 3) or (3,)`."
|
||||
)
|
||||
return self._convert_volume_property_to_tensor(translation, batch_size, err_msg)
|
||||
|
||||
def _convert_volume_property_to_tensor(
|
||||
self, x, batch_size, err_msg
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Handle the `volume_translation` or `voxel_size` argument to
|
||||
the Volumes constructor.
|
||||
Return a tensor of shape (N, 3) where N is the batch_size or 1
|
||||
if batch_size is None.
|
||||
"""
|
||||
if isinstance(x, (list, tuple)):
|
||||
if len(x) != 3:
|
||||
raise ValueError(err_msg)
|
||||
x = torch.tensor(x, device=self.device, dtype=torch.float32)[None]
|
||||
x = x.repeat((batch_size, 1))
|
||||
elif torch.is_tensor(x):
|
||||
ok = (
|
||||
(x.shape[0] == 1 and x.shape[1] == 3)
|
||||
or (x.shape[0] == 3 and len(x.shape) == 1)
|
||||
or (x.shape[0] == batch_size and x.shape[1] == 3)
|
||||
)
|
||||
if not ok:
|
||||
raise ValueError(err_msg)
|
||||
if x.device != self.device:
|
||||
x = x.to(self.device)
|
||||
if x.shape[0] == 3 and len(x.shape) == 1:
|
||||
x = x[None]
|
||||
if x.shape[0] == 1:
|
||||
x = x.repeat((batch_size, 1))
|
||||
else:
|
||||
raise ValueError(err_msg)
|
||||
|
||||
return x
|
||||
|
||||
def get_coord_grid(self, world_coordinates: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Return the 3D coordinate grid of the volumetric grid
|
||||
in local (`world_coordinates=False`) or world coordinates
|
||||
(`world_coordinates=True`).
|
||||
|
||||
The grid records location of each center of the corresponding volume voxel.
|
||||
|
||||
Local coordinates are scaled s.t. the values along one side of the
|
||||
volume are in range [-1, 1].
|
||||
|
||||
Args:
|
||||
**world_coordinates**: if `True`, the method
|
||||
returns the grid in the world coordinates,
|
||||
otherwise, in local coordinates.
|
||||
|
||||
Returns:
|
||||
**coordinate_grid**: The grid of coordinates of shape
|
||||
`(minibatch, depth, height, width, 3)`, where `minibatch`,
|
||||
`height`, `width` and `depth` are the batch size, height, width
|
||||
and depth of the volume `features` or `densities`.
|
||||
"""
|
||||
# TODO(dnovotny): Implement caching of the coordinate grid.
|
||||
return self._calculate_coordinate_grid(world_coordinates=world_coordinates)
|
||||
|
||||
def _calculate_coordinate_grid(
|
||||
self, world_coordinates: bool = True
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the 3D coordinate grid of the volumetric grid either in
|
||||
in local (`world_coordinates=False`) or
|
||||
world coordinates (`world_coordinates=True`) .
|
||||
"""
|
||||
|
||||
densities = self.densities()
|
||||
ba, _, de, he, wi = densities.shape
|
||||
grid_sizes = self.get_grid_sizes()
|
||||
|
||||
# generate coordinate axes
|
||||
vol_axes = [
|
||||
torch.linspace(-1.0, 1.0, r, dtype=torch.float32, device=self.device)
|
||||
for r in (de, he, wi)
|
||||
]
|
||||
|
||||
# generate per-coord meshgrids
|
||||
Z, Y, X = torch.meshgrid(vol_axes)
|
||||
|
||||
# stack the coord grids ... this order matches the coordinate convention
|
||||
# of torch.nn.grid_sample
|
||||
vol_coords_local = torch.stack((X, Y, Z), dim=3)[None].repeat(ba, 1, 1, 1, 1)
|
||||
|
||||
# get grid sizes relative to the maximal volume size
|
||||
grid_sizes_relative = (
|
||||
torch.tensor([[de, he, wi]], device=grid_sizes.device, dtype=torch.float32)
|
||||
- 1
|
||||
) / (grid_sizes - 1).float()
|
||||
|
||||
if (grid_sizes_relative != 1.0).any():
|
||||
# if any of the relative sizes != 1.0, adjust the grid
|
||||
grid_sizes_relative_reshape = grid_sizes_relative[:, [2, 1, 0]][
|
||||
:, None, None, None
|
||||
]
|
||||
vol_coords_local *= grid_sizes_relative_reshape
|
||||
vol_coords_local += grid_sizes_relative_reshape - 1
|
||||
|
||||
if world_coordinates:
|
||||
vol_coords = self.local_to_world_coords(vol_coords_local)
|
||||
else:
|
||||
vol_coords = vol_coords_local
|
||||
|
||||
return vol_coords
|
||||
|
||||
def get_local_to_world_coords_transform(self) -> Transform3d:
|
||||
"""
|
||||
Return a Transform3d object that converts points in the
|
||||
the local coordinate frame of the volume to world coordinates.
|
||||
Local volume coordinates are scaled s.t. the coordinates along one
|
||||
side of the volume are in range [-1, 1].
|
||||
|
||||
Returns:
|
||||
**local_to_world_transform**: A Transform3d object converting
|
||||
points from local coordinates to the world coordinates.
|
||||
"""
|
||||
return self._local_to_world_transform
|
||||
|
||||
def get_world_to_local_coords_transform(self) -> Transform3d:
|
||||
"""
|
||||
Return a Transform3d object that converts points in the
|
||||
world coordinates to the local coordinate frame of the volume.
|
||||
Local volume coordinates are scaled s.t. the coordinates along one
|
||||
side of the volume are in range [-1, 1].
|
||||
|
||||
Returns:
|
||||
**world_to_local_transform**: A Transform3d object converting
|
||||
points from world coordinates to local coordinates.
|
||||
"""
|
||||
return self.get_local_to_world_coords_transform().inverse()
|
||||
|
||||
def world_to_local_coords(self, points_3d_world: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert a batch of 3D point coordinates `points_3d_world` of shape
|
||||
(minibatch, ..., dim) in the world coordinates to
|
||||
the local coordinate frame of the volume. Local volume
|
||||
coordinates are scaled s.t. the coordinates along one side of the volume
|
||||
are in range [-1, 1].
|
||||
|
||||
Args:
|
||||
**points_3d_world**: A tensor of shape `(minibatch, ..., 3)`
|
||||
containing the 3D coordinates of a set of points that will
|
||||
be converted from the local volume coordinates (ranging
|
||||
within [-1, 1]) to the world coordinates
|
||||
defined by the `self.center` and `self.voxel_size` parameters.
|
||||
|
||||
Returns:
|
||||
**points_3d_local**: `points_3d_world` converted to the local
|
||||
volume coordinates of shape `(minibatch, ..., 3)`.
|
||||
"""
|
||||
pts_shape = points_3d_world.shape
|
||||
return (
|
||||
self.get_world_to_local_coords_transform()
|
||||
.transform_points(points_3d_world.view(pts_shape[0], -1, 3))
|
||||
.view(pts_shape)
|
||||
)
|
||||
|
||||
def local_to_world_coords(self, points_3d_local: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert a batch of 3D point coordinates `points_3d_local` of shape
|
||||
(minibatch, ..., dim) in the local coordinate frame of the volume
|
||||
to the world coordinates.
|
||||
|
||||
Args:
|
||||
**points_3d_local**: A tensor of shape `(minibatch, ..., 3)`
|
||||
containing the 3D coordinates of a set of points that will
|
||||
be converted from the local volume coordinates (ranging
|
||||
within [-1, 1]) to the world coordinates
|
||||
defined by the `self.center` and `self.voxel_size` parameters.
|
||||
|
||||
Returns:
|
||||
**points_3d_world**: `points_3d_local` converted to the world
|
||||
coordinates of the volume of shape `(minibatch, ..., 3)`.
|
||||
"""
|
||||
pts_shape = points_3d_local.shape
|
||||
return (
|
||||
self.get_local_to_world_coords_transform()
|
||||
.transform_points(points_3d_local.view(pts_shape[0], -1, 3))
|
||||
.view(pts_shape)
|
||||
)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self._densities.shape[0]
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[int, List[int], Tuple[int], slice, torch.Tensor]
|
||||
) -> "Volumes":
|
||||
"""
|
||||
Args:
|
||||
index: Specifying the index of the volume to retrieve.
|
||||
Can be an int, slice, list of ints or a boolean or a long tensor.
|
||||
|
||||
Returns:
|
||||
Volumes object with selected volumes. The tensors are not cloned.
|
||||
"""
|
||||
if isinstance(index, int):
|
||||
index = torch.LongTensor([index])
|
||||
elif isinstance(index, (slice, list, tuple)):
|
||||
pass
|
||||
elif torch.is_tensor(index):
|
||||
if index.dim() != 1 or index.dtype.is_floating_point:
|
||||
raise IndexError(index)
|
||||
else:
|
||||
raise IndexError(index)
|
||||
|
||||
new = self.__class__(
|
||||
features=self.features()[index] if self._features is not None else None,
|
||||
densities=self.densities()[index],
|
||||
)
|
||||
# dont forget to update grid_sizes!
|
||||
new._grid_sizes = self.get_grid_sizes()[index]
|
||||
new._local_to_world_transform = self._local_to_world_transform[index]
|
||||
return new
|
||||
|
||||
def features(self) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Returns the features of the volume.
|
||||
|
||||
Returns:
|
||||
**features**: The tensor of volume features.
|
||||
"""
|
||||
return self._features
|
||||
|
||||
def densities(self) -> torch.Tensor:
|
||||
"""
|
||||
Returns the densities of the volume.
|
||||
|
||||
Returns:
|
||||
**densities**: The tensor of volume densities.
|
||||
"""
|
||||
return self._densities
|
||||
|
||||
def densities_list(self) -> List[torch.Tensor]:
|
||||
"""
|
||||
Get the list representation of the densities.
|
||||
|
||||
Returns:
|
||||
list of tensors of densities of shape (dim_i, D_i, H_i, W_i).
|
||||
"""
|
||||
return self._features_densities_list(self.densities())
|
||||
|
||||
def features_list(self) -> List[torch.Tensor]:
|
||||
"""
|
||||
Get the list representation of the features.
|
||||
|
||||
Returns:
|
||||
list of tensors of features of shape (dim_i, D_i, H_i, W_i)
|
||||
or `None` for feature-less volumes.
|
||||
"""
|
||||
if self._features is None:
|
||||
# No features provided so return None
|
||||
return None
|
||||
return self._features_densities_list(self.features())
|
||||
|
||||
def _features_densities_list(self, x) -> List[torch.Tensor]:
|
||||
"""
|
||||
Retrieve the list representation of features/densities.
|
||||
|
||||
Args:
|
||||
x: self.features() or self.densities()
|
||||
|
||||
Returns:
|
||||
list of tensors of features/densities of shape (dim_i, D_i, H_i, W_i).
|
||||
"""
|
||||
x_dim = x.shape[1]
|
||||
pad_sizes = torch.nn.functional.pad(
|
||||
self.get_grid_sizes(), [1, 0], mode="constant", value=x_dim
|
||||
)
|
||||
x_list = struct_utils.padded_to_list(x, pad_sizes.tolist())
|
||||
return x_list
|
||||
|
||||
def get_grid_sizes(self) -> torch.LongTensor:
|
||||
"""
|
||||
Returns the sizes of individual volumetric grids in the structure.
|
||||
|
||||
Returns:
|
||||
**grid_sizes**: Tensor of spatial sizes of each of the volumes
|
||||
of size (batchsize, 3), where i-th row holds (D_i, H_i, W_i).
|
||||
"""
|
||||
return self._grid_sizes
|
||||
|
||||
def update_padded(self, new_densities, new_features=None) -> "Volumes":
|
||||
"""
|
||||
Returns a Volumes structure with updated padded tensors and copies of
|
||||
the auxiliary tensors `self._local_to_world_transform`,
|
||||
`device` and `self._grid_sizes`. This function allows for an update of
|
||||
densities (and features) without having to explicitly
|
||||
convert it to the list representation for heterogeneous batches.
|
||||
|
||||
Args:
|
||||
new_densities: FloatTensor of shape (N, dim_density, D, H, W)
|
||||
new_features: (optional) FloatTensor of shape (N, dim_feature, D, H, W)
|
||||
|
||||
Returns:
|
||||
Volumes with updated features and densities
|
||||
"""
|
||||
new = copy.copy(self)
|
||||
new._set_densities(new_densities)
|
||||
if new_features is None:
|
||||
new._features = None
|
||||
else:
|
||||
new._set_features(new_features)
|
||||
return new
|
||||
|
||||
def _set_features(self, features):
|
||||
self._set_densities_features(features, "features")
|
||||
|
||||
def _set_densities(self, densities):
|
||||
self._set_densities_features(densities, "densities")
|
||||
|
||||
def _set_densities_features(self, x, var_name):
|
||||
x_tensor, grid_sizes = self._convert_densities_features_to_tensor(x, var_name)
|
||||
if x_tensor.device != self.device:
|
||||
raise ValueError(
|
||||
f"`{var_name}` have to be on the same device as `self.densities`."
|
||||
)
|
||||
if len(x_tensor.shape) != 5:
|
||||
raise ValueError(
|
||||
f"{var_name} has to be a 5-dim tensor of shape: "
|
||||
f"(minibatch, {var_name}_dim, height, width, depth)"
|
||||
)
|
||||
|
||||
if not (
|
||||
(self.get_grid_sizes().shape == grid_sizes.shape)
|
||||
and torch.allclose(self.get_grid_sizes(), grid_sizes)
|
||||
):
|
||||
raise ValueError(
|
||||
f"The size of every grid in `{var_name}` has to match the size of"
|
||||
"the corresponding `densities` grid."
|
||||
)
|
||||
setattr(self, "_" + var_name, x_tensor)
|
||||
|
||||
def _set_local_to_world_transform(
|
||||
self,
|
||||
voxel_size: Union[float, torch.Tensor, Tuple, List] = 1.0,
|
||||
volume_translation: Union[torch.Tensor, Tuple, List] = (0.0, 0.0, 0.0),
|
||||
):
|
||||
"""
|
||||
Sets the internal representation of the transformation between the
|
||||
world and local volume coordinates by specifying
|
||||
`voxel_size` and `volume_translation`
|
||||
|
||||
Args:
|
||||
**voxel_size**: Denotes the size of input voxels. Has to be one of:
|
||||
a) A scalar (square voxels)
|
||||
b) 3-tuple or a 3-list of scalars
|
||||
c) a Tensor of shape (3,)
|
||||
d) a Tensor of shape (minibatch, 3)
|
||||
e) a Tensor of shape (1,) (square voxels)
|
||||
**volume_translation**: Denotes the 3D translation of the center
|
||||
of the volume in world units. Has to be one of:
|
||||
a) 3-tuple or a 3-list of scalars
|
||||
b) a Tensor of shape (3,)
|
||||
c) a Tensor of shape (minibatch, 3)
|
||||
d) a Tensor of shape (1,) (square voxels)
|
||||
"""
|
||||
# handle voxel size and center
|
||||
# here we force the tensors to lie on self.device
|
||||
voxel_size = self._handle_voxel_size(voxel_size, len(self))
|
||||
volume_translation = self._handle_volume_translation(
|
||||
volume_translation, len(self)
|
||||
)
|
||||
self._local_to_world_transform = self._vsize_translation_to_transform(
|
||||
voxel_size, volume_translation, len(self)
|
||||
)
|
||||
|
||||
def clone(self) -> "Volumes":
|
||||
"""
|
||||
Deep copy of Volumes object. All internal tensors are cloned
|
||||
individually.
|
||||
|
||||
Returns:
|
||||
new Volumes object.
|
||||
"""
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def to(self, device, copy: bool = False) -> "Volumes":
|
||||
"""
|
||||
Match the functionality of torch.Tensor.to()
|
||||
If copy = True or the self Tensor is on a different device, the
|
||||
returned tensor is a copy of self with the desired torch.device.
|
||||
If copy = False and the self Tensor already has the correct torch.device,
|
||||
then self is returned.
|
||||
|
||||
Args:
|
||||
**device**: Device id for the new tensor.
|
||||
**copy**: Boolean indicator whether or not to clone self. Default False.
|
||||
|
||||
Returns:
|
||||
Volumes object.
|
||||
"""
|
||||
if not copy and self.device == device:
|
||||
return self
|
||||
other = self.clone()
|
||||
if self.device != device:
|
||||
other.device = device
|
||||
other._densities = self._densities.to(device)
|
||||
if self._features is not None:
|
||||
other._features = self.features().to(device)
|
||||
other._local_to_world_transform = (
|
||||
self.get_local_to_world_coords_transform().to(device)
|
||||
)
|
||||
other._grid_sizes = self._grid_sizes.to(device)
|
||||
return other
|
||||
|
||||
def cpu(self) -> "Volumes":
|
||||
return self.to(torch.device("cpu"))
|
||||
|
||||
def cuda(self) -> "Volumes":
|
||||
return self.to(torch.device("cuda"))
|
849
tests/test_volumes.py
Normal file
849
tests/test_volumes.py
Normal file
@ -0,0 +1,849 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
import copy
|
||||
import itertools
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.structures.volumes import Volumes
|
||||
from pytorch3d.transforms import Scale
|
||||
|
||||
|
||||
class TestVolumes(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
np.random.seed(42)
|
||||
torch.manual_seed(42)
|
||||
random.seed(42)
|
||||
|
||||
@staticmethod
|
||||
def _random_volume_list(
|
||||
num_volumes, min_size, max_size, num_channels, device, rand_sizes=None
|
||||
):
|
||||
"""
|
||||
Init a list of `num_volumes` random tensors of size [num_channels, *rand_size].
|
||||
If `rand_sizes` is None, rand_size is a 3D long vector sampled
|
||||
from [min_size, max_size]. Otherwise, rand_size should be a list
|
||||
[rand_size_1, rand_size_2, ..., rand_size_num_volumes] where each
|
||||
`rand_size_i` denotes the size of the corresponding `i`-th tensor.
|
||||
"""
|
||||
if rand_sizes is None:
|
||||
rand_sizes = [
|
||||
[random.randint(min_size, vs) for vs in max_size]
|
||||
for _ in range(num_volumes)
|
||||
]
|
||||
|
||||
volume_list = [
|
||||
torch.randn(
|
||||
size=[num_channels, *rand_size], device=device, dtype=torch.float32
|
||||
)
|
||||
for rand_size in rand_sizes
|
||||
]
|
||||
|
||||
return volume_list, rand_sizes
|
||||
|
||||
def _check_indexed_volumes(self, v, selected, indices):
|
||||
for selectedIdx, index in indices:
|
||||
self.assertClose(selected.densities()[selectedIdx], v.densities()[index])
|
||||
self.assertClose(
|
||||
v._local_to_world_transform.get_matrix()[index],
|
||||
selected._local_to_world_transform.get_matrix()[selectedIdx],
|
||||
)
|
||||
if selected.features() is not None:
|
||||
self.assertClose(selected.features()[selectedIdx], v.features()[index])
|
||||
|
||||
def test_get_item(
|
||||
self,
|
||||
num_volumes=5,
|
||||
num_channels=4,
|
||||
volume_size=(10, 13, 8),
|
||||
dtype=torch.float32,
|
||||
):
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# make sure we have at least 3 volumes to prevent indexing crash
|
||||
num_volumes = max(num_volumes, 3)
|
||||
|
||||
features = torch.randn(
|
||||
size=[num_volumes, num_channels, *volume_size],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
densities = torch.randn(
|
||||
size=[num_volumes, 1, *volume_size], device=device, dtype=torch.float32
|
||||
)
|
||||
|
||||
features_list, rand_sizes = TestVolumes._random_volume_list(
|
||||
num_volumes, 3, volume_size, num_channels, device
|
||||
)
|
||||
densities_list, _ = TestVolumes._random_volume_list(
|
||||
num_volumes, 3, volume_size, 1, device, rand_sizes=rand_sizes
|
||||
)
|
||||
|
||||
volume_translation = -torch.randn(num_volumes, 3).type_as(features)
|
||||
voxel_size = torch.rand(num_volumes, 1).type_as(features) + 0.5
|
||||
|
||||
for features_, densities_ in zip(
|
||||
(None, features, features_list), (densities, densities, densities_list)
|
||||
):
|
||||
|
||||
# init the volume structure
|
||||
v = Volumes(
|
||||
features=features_,
|
||||
densities=densities_,
|
||||
volume_translation=volume_translation,
|
||||
voxel_size=voxel_size,
|
||||
)
|
||||
|
||||
# int index
|
||||
index = 1
|
||||
v_selected = v[index]
|
||||
self.assertEqual(len(v_selected), 1)
|
||||
self._check_indexed_volumes(v, v_selected, [(0, 1)])
|
||||
|
||||
# list index
|
||||
index = [1, 2]
|
||||
v_selected = v[index]
|
||||
self.assertEqual(len(v_selected), len(index))
|
||||
self._check_indexed_volumes(v, v_selected, enumerate(index))
|
||||
|
||||
# slice index
|
||||
index = slice(0, 2, 1)
|
||||
v_selected = v[0:2]
|
||||
self.assertEqual(len(v_selected), 2)
|
||||
self._check_indexed_volumes(v, v_selected, [(0, 0), (1, 1)])
|
||||
|
||||
# bool tensor
|
||||
index = (torch.rand(num_volumes) > 0.5).to(device)
|
||||
index[:2] = True # make sure smth is selected
|
||||
v_selected = v[index]
|
||||
self.assertEqual(len(v_selected), index.sum())
|
||||
self._check_indexed_volumes(
|
||||
v,
|
||||
v_selected,
|
||||
zip(
|
||||
torch.arange(index.sum()),
|
||||
torch.nonzero(index, as_tuple=False).squeeze(),
|
||||
),
|
||||
)
|
||||
|
||||
# int tensor
|
||||
index = torch.tensor([1, 2], dtype=torch.int64, device=device)
|
||||
v_selected = v[index]
|
||||
self.assertEqual(len(v_selected), index.numel())
|
||||
self._check_indexed_volumes(v, v_selected, enumerate(index.tolist()))
|
||||
|
||||
# invalid index
|
||||
index = torch.tensor([1, 0, 1], dtype=torch.float32, device=device)
|
||||
with self.assertRaises(IndexError):
|
||||
v_selected = v[index]
|
||||
index = 1.2 # floating point index
|
||||
with self.assertRaises(IndexError):
|
||||
v_selected = v[index]
|
||||
|
||||
def test_coord_transforms(self, num_volumes=3, num_channels=4, dtype=torch.float32):
|
||||
"""
|
||||
Test the correctness of the conversion between the internal
|
||||
Transform3D Volumes._local_to_world_transform and the initialization
|
||||
from the translation and voxel_size.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# try for 10 sets of different random sizes/centers/voxel_sizes
|
||||
for _ in range(10):
|
||||
|
||||
size = torch.randint(high=10, size=(3,), low=3).tolist()
|
||||
|
||||
densities = torch.randn(
|
||||
size=[num_volumes, num_channels, *size],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
# init the transformation params
|
||||
volume_translation = torch.randn(num_volumes, 3)
|
||||
voxel_size = torch.rand(num_volumes, 3) * 3.0 + 0.5
|
||||
|
||||
# get the corresponding Transform3d object
|
||||
local_offset = torch.tensor(list(size), dtype=torch.float32, device=device)[
|
||||
[2, 1, 0]
|
||||
][None].repeat(num_volumes, 1)
|
||||
local_to_world_transform = (
|
||||
Scale(0.5 * local_offset - 0.5, device=device)
|
||||
.scale(voxel_size)
|
||||
.translate(-volume_translation)
|
||||
)
|
||||
|
||||
# init the volume structures with the scale and translation,
|
||||
# then get the coord grid in world coords
|
||||
v_trans_vs = Volumes(
|
||||
densities=densities,
|
||||
voxel_size=voxel_size,
|
||||
volume_translation=volume_translation,
|
||||
)
|
||||
grid_rot_trans_vs = v_trans_vs.get_coord_grid(world_coordinates=True)
|
||||
|
||||
# map the default local coords to the world coords
|
||||
# with local_to_world_transform
|
||||
v_default = Volumes(densities=densities)
|
||||
grid_default_local = v_default.get_coord_grid(world_coordinates=False)
|
||||
grid_default_world = local_to_world_transform.transform_points(
|
||||
grid_default_local.view(num_volumes, -1, 3)
|
||||
).view(num_volumes, *size, 3)
|
||||
|
||||
# check that both grids are the same
|
||||
self.assertClose(grid_rot_trans_vs, grid_default_world, atol=1e-5)
|
||||
|
||||
# check that the transformations are the same
|
||||
self.assertClose(
|
||||
v_trans_vs.get_local_to_world_coords_transform().get_matrix(),
|
||||
local_to_world_transform.get_matrix(),
|
||||
atol=1e-5,
|
||||
)
|
||||
|
||||
def test_coord_grid_convention(
|
||||
self, num_volumes=3, num_channels=4, dtype=torch.float32
|
||||
):
|
||||
"""
|
||||
Check that for a trivial volume with spatial size DxHxW=5x7x5:
|
||||
1) xyz_world=(0, 0, 0) lands right in the middle of the volume
|
||||
with xyz_local=(0, 0, 0).
|
||||
2) xyz_world=(-2, 3, 2) results in xyz_local=(-1, 1, -1).
|
||||
3) The centeral voxel of the volume coordinate grid
|
||||
has coords x_world=(0, 0, 0) and x_local=(0, 0, 0)
|
||||
4) grid_sampler(world_coordinate_grid, local_coordinate_grid)
|
||||
is the same as world_coordinate_grid itself. I.e. the local coordinate
|
||||
grid matches the grid_sampler coordinate convention.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
densities = torch.randn(
|
||||
size=[num_volumes, num_channels, 5, 7, 5],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
v_trivial = Volumes(densities=densities)
|
||||
|
||||
# check the case with x_world=(0,0,0)
|
||||
pts_world = torch.zeros(num_volumes, 1, 3, device=device, dtype=torch.float32)
|
||||
pts_local = v_trivial.world_to_local_coords(pts_world)
|
||||
pts_local_expected = torch.zeros_like(pts_local)
|
||||
self.assertClose(pts_local, pts_local_expected)
|
||||
|
||||
# check the case with x_world=(-2, 3, -2)
|
||||
pts_world = torch.tensor([-2, 3, -2], device=device, dtype=torch.float32)[
|
||||
None, None
|
||||
].repeat(num_volumes, 1, 1)
|
||||
pts_local = v_trivial.world_to_local_coords(pts_world)
|
||||
pts_local_expected = torch.tensor(
|
||||
[-1, 1, -1], device=device, dtype=torch.float32
|
||||
)[None, None].repeat(num_volumes, 1, 1)
|
||||
self.assertClose(pts_local, pts_local_expected)
|
||||
|
||||
# check that the central voxel has coords x_world=(0, 0, 0) and x_local(0, 0, 0)
|
||||
grid_world = v_trivial.get_coord_grid(world_coordinates=True)
|
||||
grid_local = v_trivial.get_coord_grid(world_coordinates=False)
|
||||
for grid in (grid_world, grid_local):
|
||||
x0 = grid[0, :, :, 2, 0]
|
||||
y0 = grid[0, :, 3, :, 1]
|
||||
z0 = grid[0, 2, :, :, 2]
|
||||
for coord_line in (x0, y0, z0):
|
||||
self.assertClose(coord_line, torch.zeros_like(coord_line), atol=1e-7)
|
||||
|
||||
# resample grid_world using grid_sampler with local coords
|
||||
# -> make sure the resampled version is the same as original
|
||||
grid_world_resampled = torch.nn.functional.grid_sample(
|
||||
grid_world.permute(0, 4, 1, 2, 3), grid_local, align_corners=True
|
||||
).permute(0, 2, 3, 4, 1)
|
||||
self.assertClose(grid_world_resampled, grid_world, atol=1e-7)
|
||||
|
||||
def test_coord_grid_convention_heterogeneous(
|
||||
self, num_channels=4, dtype=torch.float32
|
||||
):
|
||||
"""
|
||||
Check that for a list of 2 trivial volumes with
|
||||
spatial sizes DxHxW=(5x7x5, 3x5x5):
|
||||
1) xyz_world=(0, 0, 0) lands right in the middle of the volume
|
||||
with xyz_local=(0, 0, 0).
|
||||
2) xyz_world=((-2, 3, -2), (-2, -2, 1)) results
|
||||
in xyz_local=((-1, 1, -1), (-1, -1, 1)).
|
||||
3) The centeral voxel of the volume coordinate grid
|
||||
has coords x_world=(0, 0, 0) and x_local=(0, 0, 0)
|
||||
4) grid_sampler(world_coordinate_grid, local_coordinate_grid)
|
||||
is the same as world_coordinate_grid itself. I.e. the local coordinate
|
||||
grid matches the grid_sampler coordinate convention.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
sizes = [(5, 7, 5), (3, 5, 5)]
|
||||
|
||||
densities_list = [
|
||||
torch.randn(size=[num_channels, *size], device=device, dtype=torch.float32)
|
||||
for size in sizes
|
||||
]
|
||||
|
||||
# init the volume
|
||||
v_trivial = Volumes(densities=densities_list)
|
||||
|
||||
# check the border point locations
|
||||
pts_world = torch.tensor(
|
||||
[[-2.0, 3.0, -2.0], [-2.0, -2.0, 1.0]], device=device, dtype=torch.float32
|
||||
)[:, None]
|
||||
pts_local = v_trivial.world_to_local_coords(pts_world)
|
||||
pts_local_expected = torch.tensor(
|
||||
[[-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0]], device=device, dtype=torch.float32
|
||||
)[:, None]
|
||||
self.assertClose(pts_local, pts_local_expected)
|
||||
|
||||
# check that the central voxel has coords x_world=(0, 0, 0) and x_local(0, 0, 0)
|
||||
grid_world = v_trivial.get_coord_grid(world_coordinates=True)
|
||||
grid_local = v_trivial.get_coord_grid(world_coordinates=False)
|
||||
for grid in (grid_world, grid_local):
|
||||
x0 = grid[0, :, :, 2, 0]
|
||||
y0 = grid[0, :, 3, :, 1]
|
||||
z0 = grid[0, 2, :, :, 2]
|
||||
for coord_line in (x0, y0, z0):
|
||||
self.assertClose(coord_line, torch.zeros_like(coord_line), atol=1e-7)
|
||||
x0 = grid[1, :, :, 2, 0]
|
||||
y0 = grid[1, :, 2, :, 1]
|
||||
z0 = grid[1, 1, :, :, 2]
|
||||
for coord_line in (x0, y0, z0):
|
||||
self.assertClose(coord_line, torch.zeros_like(coord_line), atol=1e-7)
|
||||
|
||||
# resample grid_world using grid_sampler with local coords
|
||||
# -> make sure the resampled version is the same as original
|
||||
for grid_world_, grid_local_, size in zip(grid_world, grid_local, sizes):
|
||||
grid_world_crop = grid_world_[: size[0], : size[1], : size[2], :][None]
|
||||
grid_local_crop = grid_local_[: size[0], : size[1], : size[2], :][None]
|
||||
grid_world_crop_resampled = torch.nn.functional.grid_sample(
|
||||
grid_world_crop.permute(0, 4, 1, 2, 3),
|
||||
grid_local_crop,
|
||||
align_corners=True,
|
||||
).permute(0, 2, 3, 4, 1)
|
||||
self.assertClose(grid_world_crop_resampled, grid_world_crop, atol=1e-7)
|
||||
|
||||
def test_coord_grid_transforms(
|
||||
self, num_volumes=3, num_channels=4, dtype=torch.float32
|
||||
):
|
||||
"""
|
||||
Test whether conversion between local-world coordinates of the
|
||||
volume returns correct results.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# try for 10 sets of different random sizes/centers/voxel_sizes
|
||||
for _ in range(10):
|
||||
|
||||
size = torch.randint(high=10, size=(3,), low=3).tolist()
|
||||
|
||||
center = torch.randn(num_volumes, 3, dtype=torch.float32, device=device)
|
||||
voxel_size = torch.rand(1, dtype=torch.float32, device=device) * 5.0 + 0.5
|
||||
|
||||
for densities in (
|
||||
torch.randn(
|
||||
size=[num_volumes, num_channels, *size],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
TestVolumes._random_volume_list(
|
||||
num_volumes, 3, size, num_channels, device, rand_sizes=None
|
||||
)[0],
|
||||
):
|
||||
|
||||
# init the volume structure
|
||||
v = Volumes(
|
||||
densities=densities,
|
||||
voxel_size=voxel_size,
|
||||
volume_translation=-center,
|
||||
)
|
||||
|
||||
# get local coord grid
|
||||
grid_local = v.get_coord_grid(world_coordinates=False)
|
||||
|
||||
# convert from world to local to world
|
||||
grid_world = v.get_coord_grid(world_coordinates=True)
|
||||
grid_local_2 = v.world_to_local_coords(grid_world)
|
||||
grid_world_2 = v.local_to_world_coords(grid_local_2)
|
||||
|
||||
# assertions on shape and values of grid_world and grid_local
|
||||
self.assertClose(grid_world, grid_world_2, atol=1e-5)
|
||||
self.assertClose(grid_local, grid_local_2, atol=1e-5)
|
||||
|
||||
# check that the individual slices of the location grid have
|
||||
# constant values along expected dimensions
|
||||
for plane_dim in (1, 2, 3):
|
||||
for grid_plane in grid_world.split(1, dim=plane_dim):
|
||||
grid_coord_dim = {1: 2, 2: 1, 3: 0}[plane_dim]
|
||||
grid_coord_plane = grid_plane.squeeze()[..., grid_coord_dim]
|
||||
# check that all elements of grid_coord_plane are
|
||||
# the same for each batch element
|
||||
self.assertClose(
|
||||
grid_coord_plane.reshape(num_volumes, -1).max(dim=1).values,
|
||||
grid_coord_plane.reshape(num_volumes, -1).min(dim=1).values,
|
||||
)
|
||||
|
||||
def test_clone(
|
||||
self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32
|
||||
):
|
||||
"""
|
||||
Test cloning of a `Volumes` object
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
features = torch.randn(
|
||||
size=[num_volumes, num_channels, *size], device=device, dtype=torch.float32
|
||||
)
|
||||
densities = torch.rand(
|
||||
size=[num_volumes, 1, *size], device=device, dtype=torch.float32
|
||||
)
|
||||
|
||||
for has_features in (True, False):
|
||||
v = Volumes(
|
||||
densities=densities, features=features if has_features else None
|
||||
)
|
||||
vnew = v.clone()
|
||||
vnew._densities.data[0, 0, 0, 0, 0] += 1.0
|
||||
self.assertNotAlmostEqual(
|
||||
float(
|
||||
(vnew.densities()[0, 0, 0, 0, 0] - v.densities()[0, 0, 0, 0, 0])
|
||||
.abs()
|
||||
.max()
|
||||
),
|
||||
0.0,
|
||||
)
|
||||
|
||||
if has_features:
|
||||
vnew._features.data[0, 0, 0, 0, 0] += 1.0
|
||||
self.assertNotAlmostEqual(
|
||||
float(
|
||||
(vnew.features()[0, 0, 0, 0, 0] - v.features()[0, 0, 0, 0, 0])
|
||||
.abs()
|
||||
.max()
|
||||
),
|
||||
0.0,
|
||||
)
|
||||
|
||||
def _check_vars_on_device(self, v, desired_device):
|
||||
for var_name, var in vars(v).items():
|
||||
if var_name != "device":
|
||||
if var is not None:
|
||||
self.assertTrue(var.device.type == desired_device.type)
|
||||
else:
|
||||
self.assertTrue(var.type == desired_device.type)
|
||||
|
||||
def test_to(
|
||||
self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32
|
||||
):
|
||||
"""
|
||||
Test the moving of the volumes from/to gpu and cpu
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
device_cpu = torch.device("cpu")
|
||||
|
||||
features = torch.randn(
|
||||
size=[num_volumes, num_channels, *size], device=device, dtype=torch.float32
|
||||
)
|
||||
densities = torch.rand(size=[num_volumes, 1, *size], device=device, dtype=dtype)
|
||||
|
||||
for features_ in (features, None):
|
||||
v = Volumes(densities=densities, features=features_)
|
||||
|
||||
v_cpu = v.cpu()
|
||||
v_cuda = v_cpu.cuda()
|
||||
v_cuda_2 = v_cuda.cuda()
|
||||
v_cpu_2 = v_cuda_2.cpu()
|
||||
|
||||
for v1, v2 in itertools.combinations(
|
||||
(v, v_cpu, v_cpu_2, v_cuda, v_cuda_2), 2
|
||||
):
|
||||
if v1 is v_cuda and v2 is v_cuda_2:
|
||||
# checks that we do not copy if the devices stay the same
|
||||
assert_fun = self.assertIs
|
||||
else:
|
||||
assert_fun = self.assertSeparate
|
||||
assert_fun(v1._densities, v2._densities)
|
||||
if features_ is not None:
|
||||
assert_fun(v1._features, v2._features)
|
||||
for v_ in (v1, v2):
|
||||
if v_ in (v_cpu, v_cpu_2):
|
||||
self._check_vars_on_device(v_, device_cpu)
|
||||
else:
|
||||
self._check_vars_on_device(v_, device)
|
||||
|
||||
def _check_padded(self, x_pad, x_list, grid_sizes):
|
||||
"""
|
||||
Check that padded tensors x_pad are the same as x_list tensors.
|
||||
"""
|
||||
num_volumes = len(x_list)
|
||||
for i in range(num_volumes):
|
||||
self.assertClose(
|
||||
x_pad[i][:, : grid_sizes[i][0], : grid_sizes[i][1], : grid_sizes[i][2]],
|
||||
x_list[i],
|
||||
)
|
||||
|
||||
def test_feature_density_setters(self):
|
||||
"""
|
||||
Tests getters and setters for padded/list representations.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
diff_device = torch.device("cpu")
|
||||
|
||||
num_volumes = 30
|
||||
num_channels = 4
|
||||
K = 20
|
||||
|
||||
densities = []
|
||||
features = []
|
||||
grid_sizes = []
|
||||
diff_grid_sizes = []
|
||||
|
||||
for _ in range(num_volumes):
|
||||
grid_size = torch.randint(K - 1, size=(3,)).long() + 1
|
||||
densities.append(
|
||||
torch.rand((1, *grid_size), device=device, dtype=torch.float32)
|
||||
)
|
||||
features.append(
|
||||
torch.rand(
|
||||
(num_channels, *grid_size), device=device, dtype=torch.float32
|
||||
)
|
||||
)
|
||||
grid_sizes.append(grid_size)
|
||||
|
||||
diff_grid_size = (
|
||||
copy.deepcopy(grid_size) + torch.randint(2, size=(3,)).long() + 1
|
||||
)
|
||||
diff_grid_sizes.append(diff_grid_size)
|
||||
grid_sizes = torch.stack(grid_sizes).to(device)
|
||||
diff_grid_sizes = torch.stack(diff_grid_sizes).to(device)
|
||||
|
||||
volumes = Volumes(densities=densities, features=features)
|
||||
self.assertClose(volumes.get_grid_sizes(), grid_sizes)
|
||||
|
||||
# test the getters
|
||||
features_padded = volumes.features()
|
||||
densities_padded = volumes.densities()
|
||||
features_list = volumes.features_list()
|
||||
densities_list = volumes.densities_list()
|
||||
for x_pad, x_list in zip(
|
||||
(densities_padded, features_padded, densities_padded, features_padded),
|
||||
(densities_list, features_list, densities, features),
|
||||
):
|
||||
self._check_padded(x_pad, x_list, grid_sizes)
|
||||
|
||||
# test feature setters
|
||||
features_new = [
|
||||
torch.rand((num_channels, *grid_size), device=device, dtype=torch.float32)
|
||||
for grid_size in grid_sizes
|
||||
]
|
||||
volumes._set_features(features_new)
|
||||
features_new_list = volumes.features_list()
|
||||
features_new_padded = volumes.features()
|
||||
for x_pad, x_list in zip(
|
||||
(features_new_padded, features_new_padded),
|
||||
(features_new, features_new_list),
|
||||
):
|
||||
self._check_padded(x_pad, x_list, grid_sizes)
|
||||
|
||||
# wrong features to update
|
||||
bad_features_new = [
|
||||
[
|
||||
torch.rand(
|
||||
(num_channels, *grid_size), device=diff_device, dtype=torch.float32
|
||||
)
|
||||
for grid_size in diff_grid_sizes
|
||||
],
|
||||
torch.rand(
|
||||
(num_volumes, num_channels, K + 1, K, K),
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
None,
|
||||
]
|
||||
for bad_features_new_ in bad_features_new:
|
||||
with self.assertRaises(ValueError):
|
||||
volumes._set_densities(bad_features_new_)
|
||||
|
||||
# test density setters
|
||||
densities_new = [
|
||||
torch.rand((1, *grid_size), device=device, dtype=torch.float32)
|
||||
for grid_size in grid_sizes
|
||||
]
|
||||
volumes._set_densities(densities_new)
|
||||
densities_new_list = volumes.densities_list()
|
||||
densities_new_padded = volumes.densities()
|
||||
for x_pad, x_list in zip(
|
||||
(densities_new_padded, densities_new_padded),
|
||||
(densities_new, densities_new_list),
|
||||
):
|
||||
self._check_padded(x_pad, x_list, grid_sizes)
|
||||
|
||||
# wrong densities to update
|
||||
bad_densities_new = [
|
||||
[
|
||||
torch.rand((1, *grid_size), device=diff_device, dtype=torch.float32)
|
||||
for grid_size in diff_grid_sizes
|
||||
],
|
||||
torch.rand(
|
||||
(num_volumes, 1, K + 1, K, K), device=device, dtype=torch.float32
|
||||
),
|
||||
None,
|
||||
]
|
||||
for bad_densities_new_ in bad_densities_new:
|
||||
with self.assertRaises(ValueError):
|
||||
volumes._set_densities(bad_densities_new_)
|
||||
|
||||
# test update_padded
|
||||
volumes = Volumes(densities=densities, features=features)
|
||||
volumes_updated = volumes.update_padded(
|
||||
densities_new, new_features=features_new
|
||||
)
|
||||
densities_new_list = volumes_updated.densities_list()
|
||||
densities_new_padded = volumes_updated.densities()
|
||||
features_new_list = volumes_updated.features_list()
|
||||
features_new_padded = volumes_updated.features()
|
||||
for x_pad, x_list in zip(
|
||||
(
|
||||
densities_new_padded,
|
||||
densities_new_padded,
|
||||
features_new_padded,
|
||||
features_new_padded,
|
||||
),
|
||||
(densities_new, densities_new_list, features_new, features_new_list),
|
||||
):
|
||||
self._check_padded(x_pad, x_list, grid_sizes)
|
||||
self.assertIs(volumes.get_grid_sizes(), volumes_updated.get_grid_sizes())
|
||||
self.assertIs(
|
||||
volumes.get_local_to_world_coords_transform(),
|
||||
volumes_updated.get_local_to_world_coords_transform(),
|
||||
)
|
||||
self.assertIs(volumes.device, volumes_updated.device)
|
||||
|
||||
def test_constructor_for_padded_lists(self):
|
||||
"""
|
||||
Tests constructor for padded/list representations.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
diff_device = torch.device("cpu")
|
||||
|
||||
num_volumes = 3
|
||||
num_channels = 4
|
||||
size = (6, 8, 10)
|
||||
diff_size = (6, 8, 11)
|
||||
|
||||
# good ways to define densities
|
||||
ok_densities = [
|
||||
torch.randn(
|
||||
size=[num_volumes, 1, *size], device=device, dtype=torch.float32
|
||||
).unbind(0),
|
||||
torch.randn(
|
||||
size=[num_volumes, 1, *size], device=device, dtype=torch.float32
|
||||
),
|
||||
]
|
||||
|
||||
# bad ways to define features
|
||||
bad_features = [
|
||||
torch.randn(
|
||||
size=[num_volumes + 1, num_channels, *size],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
).unbind(
|
||||
0
|
||||
), # list with diff batch size
|
||||
torch.randn(
|
||||
size=[num_volumes + 1, num_channels, *size],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
), # diff batch size
|
||||
torch.randn(
|
||||
size=[num_volumes, num_channels, *diff_size],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
).unbind(
|
||||
0
|
||||
), # list with different size
|
||||
torch.randn(
|
||||
size=[num_volumes, num_channels, *diff_size],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
), # different size
|
||||
torch.randn(
|
||||
size=[num_volumes, num_channels, *size],
|
||||
device=diff_device,
|
||||
dtype=torch.float32,
|
||||
), # different device
|
||||
torch.randn(
|
||||
size=[num_volumes, num_channels, *size],
|
||||
device=diff_device,
|
||||
dtype=torch.float32,
|
||||
).unbind(
|
||||
0
|
||||
), # list with different device
|
||||
]
|
||||
|
||||
# good ways to define features
|
||||
ok_features = [
|
||||
torch.randn(
|
||||
size=[num_volumes, num_channels, *size],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
).unbind(
|
||||
0
|
||||
), # list of features of correct size
|
||||
torch.randn(
|
||||
size=[num_volumes, num_channels, *size],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
),
|
||||
]
|
||||
|
||||
for densities in ok_densities:
|
||||
for features in bad_features:
|
||||
self.assertRaises(
|
||||
ValueError, Volumes, densities=densities, features=features
|
||||
)
|
||||
for features in ok_features:
|
||||
Volumes(densities=densities, features=features)
|
||||
|
||||
def test_constructor(
|
||||
self, num_volumes=3, num_channels=4, size=(6, 8, 10), dtype=torch.float32
|
||||
):
|
||||
"""
|
||||
Test different ways of calling the `Volumes` constructor
|
||||
"""
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# all ways to define features
|
||||
features = [
|
||||
torch.randn(
|
||||
size=[num_volumes, num_channels, *size],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
), # padded tensor
|
||||
torch.randn(
|
||||
size=[num_volumes, num_channels, *size],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
).unbind(
|
||||
0
|
||||
), # list of features
|
||||
None, # no features
|
||||
]
|
||||
|
||||
# bad ways to define features
|
||||
bad_features = [
|
||||
torch.randn(
|
||||
size=[num_volumes, num_channels, 2, *size],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
), # 6 dims
|
||||
torch.randn(
|
||||
size=[num_volumes, *size], device=device, dtype=torch.float32
|
||||
), # 4 dims
|
||||
torch.randn(
|
||||
size=[num_volumes, *size], device=device, dtype=torch.float32
|
||||
).unbind(
|
||||
0
|
||||
), # list of 4 dim tensors
|
||||
]
|
||||
|
||||
# all ways to define densities
|
||||
densities = [
|
||||
torch.randn(
|
||||
size=[num_volumes, 1, *size], device=device, dtype=torch.float32
|
||||
), # padded tensor
|
||||
torch.randn(
|
||||
size=[num_volumes, 1, *size], device=device, dtype=torch.float32
|
||||
).unbind(
|
||||
0
|
||||
), # list of densities
|
||||
]
|
||||
|
||||
# bad ways to define densities
|
||||
bad_densities = [
|
||||
None, # omitted
|
||||
torch.randn(
|
||||
size=[num_volumes, 1, 1, *size], device=device, dtype=torch.float32
|
||||
), # 6-dim tensor
|
||||
torch.randn(
|
||||
size=[num_volumes, 1, 1, *size], device=device, dtype=torch.float32
|
||||
).unbind(
|
||||
0
|
||||
), # list of 5-dim densities
|
||||
]
|
||||
|
||||
# all possible ways to define the voxels sizes
|
||||
vox_sizes = [
|
||||
torch.Tensor([1.0, 1.0, 1.0]),
|
||||
[1.0, 1.0, 1.0],
|
||||
torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes, 1),
|
||||
torch.Tensor([1.0])[None].repeat(num_volumes, 1),
|
||||
1.0,
|
||||
torch.Tensor([1.0]),
|
||||
]
|
||||
|
||||
# all possible ways to define the volume translations
|
||||
vol_translations = [
|
||||
torch.Tensor([1.0, 1.0, 1.0]),
|
||||
[1.0, 1.0, 1.0],
|
||||
torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes, 1),
|
||||
]
|
||||
|
||||
# wrong ways to define voxel sizes
|
||||
bad_vox_sizes = [
|
||||
torch.Tensor([1.0, 1.0, 1.0, 1.0]),
|
||||
[1.0, 1.0, 1.0, 1.0],
|
||||
torch.Tensor([]),
|
||||
None,
|
||||
]
|
||||
|
||||
# wrong ways to define the volume translations
|
||||
bad_vol_translations = [
|
||||
torch.Tensor([1.0, 1.0]),
|
||||
[1.0, 1.0],
|
||||
1.0,
|
||||
torch.Tensor([1.0, 1.0, 1.0])[None].repeat(num_volumes + 1, 1),
|
||||
]
|
||||
|
||||
def zip_with_ok_indicator(good, bad):
|
||||
return zip([*good, *bad], [*([True] * len(good)), *([False] * len(bad))])
|
||||
|
||||
for features_, features_ok in zip_with_ok_indicator(features, bad_features):
|
||||
for densities_, densities_ok in zip_with_ok_indicator(
|
||||
densities, bad_densities
|
||||
):
|
||||
for vox_size, size_ok in zip_with_ok_indicator(
|
||||
vox_sizes, bad_vox_sizes
|
||||
):
|
||||
for vol_translation, trans_ok in zip_with_ok_indicator(
|
||||
vol_translations, bad_vol_translations
|
||||
):
|
||||
if (
|
||||
size_ok and trans_ok and features_ok and densities_ok
|
||||
): # if all entries are good we check that this doesnt throw
|
||||
Volumes(
|
||||
features=features_,
|
||||
densities=densities_,
|
||||
voxel_size=vox_size,
|
||||
volume_translation=vol_translation,
|
||||
)
|
||||
|
||||
else: # otherwise we check for ValueError
|
||||
self.assertRaises(
|
||||
ValueError,
|
||||
Volumes,
|
||||
features=features_,
|
||||
densities=densities_,
|
||||
voxel_size=vox_size,
|
||||
volume_translation=vol_translation,
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user