Split Volumes class into data and location parts

Summary: Split the Volumes class into a data part and a location part so that the location part can be reused in the planned VoxelGrid classes.

Reviewed By: bottler

Differential Revision: D38782015

fbshipit-source-id: 489da09c5c236f3b81961ce9b09edbd97afaa7c8
This commit is contained in:
Darijan Gudelj 2022-08-18 08:12:33 -07:00 committed by Facebook GitHub Bot
parent fdaaa299a7
commit f825f7e42c
2 changed files with 721 additions and 297 deletions
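
A minimal usage sketch of the resulting split (illustrative only, not part of the commit; the class and method names come from the diff below, the concrete tensor sizes are made up):

import torch
from pytorch3d.structures.volumes import VolumeLocator, Volumes

# `Volumes` keeps the data (densities and optional features) and delegates all
# location bookkeeping to an internal `VolumeLocator` stored as `self.locator`.
densities = torch.rand(2, 1, 5, 5, 5)  # (minibatch, density_dim, D, H, W)
v = Volumes(densities=densities, voxel_size=1.0, volume_translation=(0.0, 0.0, 0.0))

# The locator can also be constructed on its own, which is what the planned
# VoxelGrid classes are meant to reuse.
locator = VolumeLocator(
    batch_size=2,
    grid_sizes=(5, 5, 5),
    voxel_size=1.0,
    volume_translation=(0.0, 0.0, 0.0),
    device=torch.device("cpu"),
)

# Both expose the same coordinate API; `Volumes` simply forwards to its locator.
pts_world = torch.zeros(2, 1, 3)
assert torch.allclose(
    v.world_to_local_coords(pts_world),
    locator.world_to_local_coords(pts_world),
)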

@@ -23,6 +23,7 @@ _VoxelSize = _ScalarOrVector
_Translation = _Vector
_TensorBatch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
_ALL_CONTENT: slice = slice(0, None)
class Volumes:
@@ -65,9 +66,9 @@ class Volumes:
VOLUME COORDINATES
Additionally, the `Volumes` class keeps track of the locations of the
centers of the volume cells in the local volume coordinates as well as in
the world coordinates.
Additionally, using the `VolumeLocator` class, the `Volumes` class keeps track
of the locations of the centers of the volume cells in the local volume
coordinates as well as in the world coordinates.
Local coordinates:
- Represent the locations of the volume cells in the local coordinate
@@ -125,7 +126,7 @@ class Volumes:
appropriate `world_coordinates` argument.
Internally, the mapping between `x_local` and `x_world` is represented
as a `Transform3d` object `Volumes._local_to_world_transform`.
as a `Transform3d` object `Volumes.locator._local_to_world_transform`.
Users can access the relevant transformations with the
`Volumes.get_world_to_local_coords_transform()` and
`Volumes.get_local_to_world_coords_transform()`
@@ -197,21 +198,24 @@ class Volumes:
# assign to the internal buffers
self._densities = densities_
self._grid_sizes = grid_sizes
# assign a coordinate transformation member
self.locator = VolumeLocator(
batch_size=len(self),
grid_sizes=grid_sizes,
voxel_size=voxel_size,
volume_translation=volume_translation,
device=self.device,
)
# handle features
self._features = None
if features is not None:
self._set_features(features)
# set the local_to_world transform
self._set_local_to_world_transform(
voxel_size=voxel_size, volume_translation=volume_translation
)
def _convert_densities_features_to_tensor(
self, x: _TensorBatch, var_name: str
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.LongTensor]:
"""
Handle the `densities` or `features` arguments to the constructor.
"""
@@ -251,8 +255,492 @@ class Volumes:
f"{var_name} must be either a list or a tensor with "
f"shape (batch_size, {var_name}_dim, H, W, D)."
)
# pyre-ignore[7]
return x_tensor, x_shapes
def __len__(self) -> int:
return self._densities.shape[0]
def __getitem__(
self,
index: Union[
int, List[int], Tuple[int], slice, torch.BoolTensor, torch.LongTensor
],
) -> "Volumes":
"""
Args:
index: Specifying the index of the volume to retrieve.
Can be an int, slice, list of ints or a boolean or a long tensor.
Returns:
Volumes object with selected volumes. The tensors are not cloned.
"""
if isinstance(index, int):
index = torch.LongTensor([index])
elif isinstance(index, (slice, list, tuple)):
pass
elif torch.is_tensor(index):
if index.dim() != 1 or index.dtype.is_floating_point:
raise IndexError(index)
else:
raise IndexError(index)
new = self.__class__(
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
features=self.features()[index] if self._features is not None else None,
densities=self.densities()[index],
)
# don't forget to update grid_sizes!
self.locator._copy_transform_and_sizes(new.locator, index=index)
return new
def features(self) -> Optional[torch.Tensor]:
"""
Returns the features of the volume.
Returns:
**features**: The tensor of volume features.
"""
return self._features
def densities(self) -> torch.Tensor:
"""
Returns the densities of the volume.
Returns:
**densities**: The tensor of volume densities.
"""
return self._densities
def densities_list(self) -> List[torch.Tensor]:
"""
Get the list representation of the densities.
Returns:
list of tensors of densities of shape (dim_i, D_i, H_i, W_i).
"""
return self._features_densities_list(self.densities())
def features_list(self) -> List[torch.Tensor]:
"""
Get the list representation of the features.
Returns:
list of tensors of features of shape (dim_i, D_i, H_i, W_i)
or `None` for feature-less volumes.
"""
features_ = self.features()
if features_ is None:
# No features provided so return None
# pyre-fixme[7]: Expected `List[torch.Tensor]` but got `None`.
return None
return self._features_densities_list(features_)
def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]:
"""
Retrieve the list representation of features/densities.
Args:
x: self.features() or self.densities()
Returns:
list of tensors of features/densities of shape (dim_i, D_i, H_i, W_i).
"""
x_dim = x.shape[1]
pad_sizes = torch.nn.functional.pad(
self.get_grid_sizes(), [1, 0], mode="constant", value=x_dim
)
x_list = struct_utils.padded_to_list(x, pad_sizes.tolist())
return x_list
def update_padded(
self, new_densities: torch.Tensor, new_features: Optional[torch.Tensor] = None
) -> "Volumes":
"""
Returns a Volumes structure with updated padded tensors and copies of
the auxiliary members `self.locator` (which holds the local-to-world
transform and the grid sizes) and `device`. This function allows for an update of
densities (and features) without having to explicitly
convert it to the list representation for heterogeneous batches.
Args:
new_densities: FloatTensor of shape (N, dim_density, D, H, W)
new_features: (optional) FloatTensor of shape (N, dim_feature, D, H, W)
Returns:
Volumes with updated features and densities
"""
new = copy.copy(self)
new._set_densities(new_densities)
if new_features is None:
new._features = None
else:
new._set_features(new_features)
return new
def _set_features(self, features: _TensorBatch) -> None:
self._set_densities_features("features", features)
def _set_densities(self, densities: _TensorBatch) -> None:
self._set_densities_features("densities", densities)
def _set_densities_features(self, var_name: str, x: _TensorBatch) -> None:
x_tensor, grid_sizes = self._convert_densities_features_to_tensor(x, var_name)
if x_tensor.device != self.device:
raise ValueError(
f"`{var_name}` have to be on the same device as `self.densities`."
)
if len(x_tensor.shape) != 5:
raise ValueError(
f"{var_name} has to be a 5-dim tensor of shape: "
f"(minibatch, {var_name}_dim, height, width, depth)"
)
if not (
(self.get_grid_sizes().shape == grid_sizes.shape)
and torch.allclose(self.get_grid_sizes(), grid_sizes)
):
raise ValueError(
f"The size of every grid in `{var_name}` has to match the size of"
"the corresponding `densities` grid."
)
setattr(self, "_" + var_name, x_tensor)
def clone(self) -> "Volumes":
"""
Deep copy of Volumes object. All internal tensors are cloned
individually.
Returns:
new Volumes object.
"""
return copy.deepcopy(self)
def to(self, device: Device, copy: bool = False) -> "Volumes":
"""
Match the functionality of torch.Tensor.to()
If copy = True or the self Tensor is on a different device, the
returned tensor is a copy of self with the desired torch.device.
If copy = False and the self Tensor already has the correct torch.device,
then self is returned.
Args:
device: Device (as str or torch.device) for the new tensor.
copy: Boolean indicator whether or not to clone self. Default False.
Returns:
Volumes object.
"""
device_ = make_device(device)
if not copy and self.device == device_:
return self
other = self.clone()
if self.device == device_:
return other
other.device = device_
other._densities = self._densities.to(device_)
if self._features is not None:
# pyre-fixme[16]: `Optional` has no attribute `to`.
other._features = self.features().to(device_)
self.locator._copy_transform_and_sizes(other.locator, device=device_)
other.locator = other.locator.to(device, copy)
return other
def cpu(self) -> "Volumes":
return self.to("cpu")
def cuda(self) -> "Volumes":
return self.to("cuda")
def get_grid_sizes(self) -> torch.LongTensor:
"""
Returns the sizes of individual volumetric grids in the structure.
Returns:
**grid_sizes**: Tensor of spatial sizes of each of the volumes
of size (batchsize, 3), where i-th row holds (D_i, H_i, W_i).
"""
return self.locator.get_grid_sizes()
def get_local_to_world_coords_transform(self) -> Transform3d:
"""
Return a Transform3d object that converts points in the
the local coordinate frame of the volume to world coordinates.
Local volume coordinates are scaled s.t. the coordinates along one
side of the volume are in range [-1, 1].
Returns:
**local_to_world_transform**: A Transform3d object converting
points from local coordinates to the world coordinates.
"""
return self.locator.get_local_to_world_coords_transform()
def get_world_to_local_coords_transform(self) -> Transform3d:
"""
Return a Transform3d object that converts points in the
world coordinates to the local coordinate frame of the volume.
Local volume coordinates are scaled s.t. the coordinates along one
side of the volume are in range [-1, 1].
Returns:
**world_to_local_transform**: A Transform3d object converting
points from world coordinates to local coordinates.
"""
return self.get_local_to_world_coords_transform().inverse()
def world_to_local_coords(self, points_3d_world: torch.Tensor) -> torch.Tensor:
"""
Convert a batch of 3D point coordinates `points_3d_world` of shape
(minibatch, ..., dim) in the world coordinates to
the local coordinate frame of the volume. Local volume
coordinates are scaled s.t. the coordinates along one side of the volume
are in range [-1, 1].
Args:
**points_3d_world**: A tensor of shape `(minibatch, ..., 3)`
containing the 3D coordinates of a set of points that will
be converted from the world coordinates (defined by the volume's
`voxel_size` and `volume_translation` parameters) to the local
volume coordinates ranging within [-1, 1].
Returns:
**points_3d_local**: `points_3d_world` converted to the local
volume coordinates of shape `(minibatch, ..., 3)`.
"""
return self.locator.world_to_local_coords(points_3d_world)
def local_to_world_coords(self, points_3d_local: torch.Tensor) -> torch.Tensor:
"""
Convert a batch of 3D point coordinates `points_3d_local` of shape
(minibatch, ..., dim) in the local coordinate frame of the volume
to the world coordinates.
Args:
**points_3d_local**: A tensor of shape `(minibatch, ..., 3)`
containing the 3D coordinates of a set of points that will
be converted from the local volume coordinates (ranging
within [-1, 1]) to the world coordinates
defined by the `self.center` and `self.voxel_size` parameters.
Returns:
**points_3d_world**: `points_3d_local` converted to the world
coordinates of the volume of shape `(minibatch, ..., 3)`.
"""
return self.locator.local_to_world_coords(points_3d_local)
def get_coord_grid(self, world_coordinates: bool = True) -> torch.Tensor:
"""
Return the 3D coordinate grid of the volumetric grid
in local (`world_coordinates=False`) or world coordinates
(`world_coordinates=True`).
The grid records location of each center of the corresponding volume voxel.
Local coordinates are scaled s.t. the values along one side of the
volume are in range [-1, 1].
Args:
**world_coordinates**: if `True`, the method
returns the grid in the world coordinates,
otherwise, in local coordinates.
Returns:
**coordinate_grid**: The grid of coordinates of shape
`(minibatch, depth, height, width, 3)`, where `minibatch`,
`height`, `width` and `depth` are the batch size, height, width
and depth of the volume `features` or `densities`.
"""
return self.locator.get_coord_grid(world_coordinates)
class VolumeLocator:
"""
The `VolumeLocator` class keeps track of the locations of the
centers of the volume cells in the local volume coordinates as well as in
the world coordinates for a voxel grid structure in 3D.
Local coordinates:
- Represent the locations of the volume cells in the local coordinate
frame of the volume.
- The center of the voxel indexed with `[·, ·, 0, 0, 0]` in the volume
has its 3D local coordinate set to `[-1, -1, -1]`, while the voxel
at index `[·, ·, depth_i-1, height_i-1, width_i-1]` has its
3D local coordinate set to `[1, 1, 1]`.
- The first/second/third coordinate of each of the 3D per-voxel
XYZ vector denotes the horizontal/vertical/depth-wise position
respectively. I.e the order of the coordinate dimensions in the
volume is reversed w.r.t. the order of the 3D coordinate vectors.
- The intermediate coordinates between `[-1, -1, -1]` and `[1, 1, 1]`
are linearly interpolated over the spatial dimensions of the volume.
- Note that the convention is the same as for the 5D version of the
`torch.nn.functional.grid_sample` function called with
`align_corners==True`.
- Note that the local coordinate convention of `VolumeLocator`
(+X = left to right, +Y = top to bottom, +Z = away from the user)
is *different* from the world coordinate convention of the
renderer for `Meshes` or `Pointclouds`
(+X = right to left, +Y = bottom to top, +Z = away from the user).
World coordinates:
- These define the locations of the centers of the volume cells
in the world coordinates.
- They are specified with the following mapping that converts
points `x_local` in the local coordinates to points `x_world`
in the world coordinates:
```
x_world = (
x_local * (volume_size - 1) * 0.5 * voxel_size
) - volume_translation,
```
here `voxel_size` specifies the size of each voxel of the volume,
and `volume_translation` is the 3D offset of the central voxel of
the volume w.r.t. the origin of the world coordinate frame.
Both `voxel_size` and `volume_translation` are specified in
the world coordinate units. `volume_size` is the spatial size of
the volume in form of a 3D vector `[width, height, depth]`.
- Given the above definition of `x_world`, one can derive the
inverse mapping from `x_world` to `x_local` as follows:
```
x_local = (
(x_world + volume_translation) / (0.5 * voxel_size)
) / (volume_size - 1)
```
- For a trivial volume with `volume_translation==[0, 0, 0]`
with `voxel_size=1`, `x_world` would range
from `-(volume_size-1)/2` to `+(volume_size-1)/2`.
Coordinate tensors that denote the locations of each of the volume cells in
local / world coordinates (with shape `(depth x height x width x 3)`)
can be retrieved by calling the `VolumeLocator.get_coord_grid()` getter with the
appropriate `world_coordinates` argument.
Internally, the mapping between `x_local` and `x_world` is represented
as a `Transform3d` object `VolumeLocator._local_to_world_transform`.
Users can access the relevant transformations with the
`VolumeLocator.get_world_to_local_coords_transform()` and
`VolumeLocator.get_local_to_world_coords_transform()`
functions.
Example coordinate conversion:
- For a "trivial" volume with `voxel_size = 1.`,
`volume_translation=[0., 0., 0.]`, and the spatial size of
`DxHxW = 5x5x5`, the point `x_world = (-2, 0, 2)` gets mapped
to `x_local=(-1, 0, 1)`.
- For a "trivial" volume `v` with `voxel_size = 1.`,
`volume_translation=[0., 0., 0.]`, the following holds:
```
torch.nn.functional.grid_sample(
v.densities(),
v.get_coord_grid(world_coordinates=False),
align_corners=True,
) == v.densities(),
```
i.e. sampling the volume at trivial local coordinates
(no scaling with `voxel_size` or shift with `volume_translation`)
results in the same volume.
"""
def __init__(
self,
batch_size: int,
grid_sizes: Union[
torch.LongTensor, Tuple[int, int, int], List[torch.LongTensor]
],
device: torch.device,
voxel_size: _VoxelSize = 1.0,
volume_translation: _Translation = (0.0, 0.0, 0.0),
):
"""
**batch_size** : Batch size of the underlying grids
**grid_sizes** : Represents the resolutions of different grids in the batch. Can be
a) tuple of form (H, W, D)
b) list/tuple of length batch_size of lists/tuples of form (H, W, D)
c) torch.Tensor of shape (batch_size, 3), where each row is (H, W, D)
H, W, D are height, width, depth respectively. If `grid_sizes` is a tuple then
all the grids in the batch have the same resolution.
**voxel_size**: Denotes the size of each volume voxel in world units.
Has to be one of:
a) A scalar (square voxels)
b) 3-tuple or a 3-list of scalars
c) a Tensor of shape (3,)
d) a Tensor of shape (minibatch, 3)
e) a Tensor of shape (minibatch, 1)
f) a Tensor of shape (1,) (square voxels)
**volume_translation**: Denotes the 3D translation of the center
of the volume in world units. Has to be one of:
a) 3-tuple or a 3-list of scalars
b) a Tensor of shape (3,)
c) a Tensor of shape (minibatch, 3)
d) a Tensor of shape (1,)
"""
self.device = device
self._batch_size = batch_size
self._grid_sizes = self._convert_grid_sizes2tensor(grid_sizes)
self._resolution = tuple(torch.max(self._grid_sizes.cpu(), dim=0).values)
# set the local_to_world transform
self._set_local_to_world_transform(
voxel_size=voxel_size, volume_translation=volume_translation
)
def _convert_grid_sizes2tensor(
self, x: Union[torch.LongTensor, List[torch.LongTensor], Tuple[int, int, int]]
) -> torch.LongTensor:
"""
Handle the grid_sizes argument to the constructor.
"""
if isinstance(x, (list, tuple)):
if isinstance(x[0], (torch.LongTensor, list, tuple)):
if self._batch_size != len(x):
raise ValueError("x should have a batch size of 'batch_size'")
# pyre-ignore[6]
if any(len(x_) != 3 for x_ in x):
raise ValueError(
"`grid_sizes` has to be a list of 3-dim tensors of shape: "
"(height, width, depth)"
)
x_shapes = torch.stack(
[
torch.tensor(
# pyre-ignore[6]
list(x_),
dtype=torch.long,
device=self.device,
)
for x_ in x
],
dim=0,
)
elif isinstance(x[0], int):
x_shapes = torch.stack(
[
torch.tensor(list(x), dtype=torch.long, device=self.device)
for _ in range(self._batch_size)
],
dim=0,
)
else:
raise ValueError(
"`grid_sizes` can be a list/tuple of int or torch.Tensor not of "
+ "{type(x[0])}."
)
elif torch.is_tensor(x):
if x.ndim != 2:
raise ValueError(
"`grid_sizes` has to be a 2-dim tensor of shape: (minibatch, 3)"
)
x_shapes = x.to(self.device)
else:
raise ValueError(
"grid_sizes must be either a list of tensors with shape (H, W, D), tensor with"
"shape (batch_size, H, W, D) or a tuple of (H, W, D)."
)
# pyre-ignore[7]
return x_shapes
def _voxel_size_translation_to_transform(
self,
voxel_size: torch.Tensor,
@@ -280,75 +768,6 @@ class Volumes:
return local_to_world_transform
def _handle_voxel_size(
self, voxel_size: _VoxelSize, batch_size: int
) -> torch.Tensor:
"""
Handle the `voxel_size` argument to the `Volumes` constructor.
"""
err_msg = (
"voxel_size has to be either a 3-tuple of scalars, or a scalar, or"
" a torch.Tensor of shape (3,) or (1,) or (minibatch, 3) or (minibatch, 1)."
)
if isinstance(voxel_size, (float, int)):
# convert a scalar to a 3-element tensor
voxel_size = torch.full(
(1, 3), voxel_size, device=self.device, dtype=torch.float32
)
elif isinstance(voxel_size, torch.Tensor):
if voxel_size.numel() == 1:
# convert a single-element tensor to a 3-element one
voxel_size = voxel_size.view(-1).repeat(3)
elif len(voxel_size.shape) == 2 and (
voxel_size.shape[0] == batch_size and voxel_size.shape[1] == 1
):
voxel_size = voxel_size.repeat(1, 3)
return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg)
def _handle_volume_translation(
self, translation: _Translation, batch_size: int
) -> torch.Tensor:
"""
Handle the `volume_translation` argument to the `Volumes` constructor.
"""
err_msg = (
"`volume_translation` has to be either a 3-tuple of scalars, or"
" a Tensor of shape (1,3) or (minibatch, 3) or (3,)`."
)
return self._convert_volume_property_to_tensor(translation, batch_size, err_msg)
def _convert_volume_property_to_tensor(
self, x: _Vector, batch_size: int, err_msg: str
) -> torch.Tensor:
"""
Handle the `volume_translation` or `voxel_size` argument to
the Volumes constructor.
Return a tensor of shape (N, 3) where N is the batch_size.
"""
if isinstance(x, (list, tuple)):
if len(x) != 3:
raise ValueError(err_msg)
x = torch.tensor(x, device=self.device, dtype=torch.float32)[None]
x = x.repeat((batch_size, 1))
elif isinstance(x, torch.Tensor):
ok = (
(x.shape[0] == 1 and x.shape[1] == 3)
or (x.shape[0] == 3 and len(x.shape) == 1)
or (x.shape[0] == batch_size and x.shape[1] == 3)
)
if not ok:
raise ValueError(err_msg)
if x.device != self.device:
x = x.to(self.device)
if x.shape[0] == 3 and len(x.shape) == 1:
x = x[None]
if x.shape[0] == 1:
x = x.repeat((batch_size, 1))
else:
raise ValueError(err_msg)
return x
def get_coord_grid(self, world_coordinates: bool = True) -> torch.Tensor:
"""
Return the 3D coordinate grid of the volumetric grid
@@ -378,13 +797,12 @@ class Volumes:
self, world_coordinates: bool = True
) -> torch.Tensor:
"""
Calculate the 3D coordinate grid of the volumetric grid either in
Calculate the 3D coordinate grid of the volumetric grid either
in local (`world_coordinates=False`) or
world coordinates (`world_coordinates=True`) .
"""
densities = self.densities()
ba, _, de, he, wi = densities.shape
ba, (de, he, wi) = self._batch_size, self._resolution
grid_sizes = self.get_grid_sizes()
# generate coordinate axes
@@ -497,102 +915,6 @@ class Volumes:
.view(pts_shape)
)
def __len__(self) -> int:
return self._densities.shape[0]
def __getitem__(
self,
index: Union[
int, List[int], Tuple[int], slice, torch.BoolTensor, torch.LongTensor
],
) -> "Volumes":
"""
Args:
index: Specifying the index of the volume to retrieve.
Can be an int, slice, list of ints or a boolean or a long tensor.
Returns:
Volumes object with selected volumes. The tensors are not cloned.
"""
if isinstance(index, int):
index = torch.LongTensor([index])
elif isinstance(index, (slice, list, tuple)):
pass
elif torch.is_tensor(index):
if index.dim() != 1 or index.dtype.is_floating_point:
raise IndexError(index)
else:
raise IndexError(index)
new = self.__class__(
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
features=self.features()[index] if self._features is not None else None,
densities=self.densities()[index],
)
# dont forget to update grid_sizes!
new._grid_sizes = self.get_grid_sizes()[index]
new._local_to_world_transform = self._local_to_world_transform[index]
return new
def features(self) -> Optional[torch.Tensor]:
"""
Returns the features of the volume.
Returns:
**features**: The tensor of volume features.
"""
return self._features
def densities(self) -> torch.Tensor:
"""
Returns the densities of the volume.
Returns:
**densities**: The tensor of volume densities.
"""
return self._densities
def densities_list(self) -> List[torch.Tensor]:
"""
Get the list representation of the densities.
Returns:
list of tensors of densities of shape (dim_i, D_i, H_i, W_i).
"""
return self._features_densities_list(self.densities())
def features_list(self) -> List[torch.Tensor]:
"""
Get the list representation of the features.
Returns:
list of tensors of features of shape (dim_i, D_i, H_i, W_i)
or `None` for feature-less volumes.
"""
features_ = self.features()
if features_ is None:
# No features provided so return None
# pyre-fixme[7]: Expected `List[torch.Tensor]` but got `None`.
return None
return self._features_densities_list(features_)
def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]:
"""
Retrieve the list representation of features/densities.
Args:
x: self.features() or self.densities()
Returns:
list of tensors of features/densities of shape (dim_i, D_i, H_i, W_i).
"""
x_dim = x.shape[1]
pad_sizes = torch.nn.functional.pad(
self.get_grid_sizes(), [1, 0], mode="constant", value=x_dim
)
x_list = struct_utils.padded_to_list(x, pad_sizes.tolist())
return x_list
def get_grid_sizes(self) -> torch.LongTensor:
"""
Returns the sizes of individual volumetric grids in the structure.
@@ -603,59 +925,6 @@ class Volumes:
"""
return self._grid_sizes
def update_padded(
self, new_densities: torch.Tensor, new_features: Optional[torch.Tensor] = None
) -> "Volumes":
"""
Returns a Volumes structure with updated padded tensors and copies of
the auxiliary tensors `self._local_to_world_transform`,
`device` and `self._grid_sizes`. This function allows for an update of
densities (and features) without having to explicitly
convert it to the list representation for heterogeneous batches.
Args:
new_densities: FloatTensor of shape (N, dim_density, D, H, W)
new_features: (optional) FloatTensor of shape (N, dim_feature, D, H, W)
Returns:
Volumes with updated features and densities
"""
new = copy.copy(self)
new._set_densities(new_densities)
if new_features is None:
new._features = None
else:
new._set_features(new_features)
return new
def _set_features(self, features: _TensorBatch) -> None:
self._set_densities_features("features", features)
def _set_densities(self, densities: _TensorBatch) -> None:
self._set_densities_features("densities", densities)
def _set_densities_features(self, var_name: str, x: _TensorBatch) -> None:
x_tensor, grid_sizes = self._convert_densities_features_to_tensor(x, var_name)
if x_tensor.device != self.device:
raise ValueError(
f"`{var_name}` have to be on the same device as `self.densities`."
)
if len(x_tensor.shape) != 5:
raise ValueError(
f"{var_name} has to be a 5-dim tensor of shape: "
f"(minibatch, {var_name}_dim, height, width, depth)"
)
if not (
(self.get_grid_sizes().shape == grid_sizes.shape)
and torch.allclose(self.get_grid_sizes(), grid_sizes)
):
raise ValueError(
f"The size of every grid in `{var_name}` has to match the size of"
"the corresponding `densities` grid."
)
setattr(self, "_" + var_name, x_tensor)
def _set_local_to_world_transform(
self,
voxel_size: _VoxelSize = 1.0,
@@ -690,17 +959,104 @@ class Volumes:
voxel_size, volume_translation, len(self)
)
def clone(self) -> "Volumes":
def _copy_transform_and_sizes(
self,
other: "VolumeLocator",
device: Optional[torch.device] = None,
index: Optional[
Union[int, List[int], Tuple[int], slice, torch.Tensor]
] = _ALL_CONTENT,
) -> None:
"""
Deep copy of Volumes object. All internal tensors are cloned
individually.
Copies the local-to-world transform and grid sizes to the other VolumeLocator
object and moves them to the specified device. Operates in place on `other`.
Returns:
new Volumes object.
Args:
other: VolumeLocator object to which to copy
device: torch.device on which to put the result, defaults to self.device
index: Specifies which parts to copy.
Can be an int, slice, list of ints or a boolean or a long tensor.
Defaults to all items (`:`).
"""
return copy.deepcopy(self)
device = device if device is not None else self.device
other._grid_sizes = self._grid_sizes[index].to(device)
other._local_to_world_transform = self.get_local_to_world_coords_transform()[
index
].to(device)
def to(self, device: Device, copy: bool = False) -> "Volumes":
def _handle_voxel_size(
self, voxel_size: _VoxelSize, batch_size: int
) -> torch.Tensor:
"""
Handle the `voxel_size` argument to the `VolumeLocator` constructor.
"""
err_msg = (
"voxel_size has to be either a 3-tuple of scalars, or a scalar, or"
" a torch.Tensor of shape (3,) or (1,) or (minibatch, 3) or (minibatch, 1)."
)
if isinstance(voxel_size, (float, int)):
# convert a scalar to a 3-element tensor
voxel_size = torch.full(
(1, 3), voxel_size, device=self.device, dtype=torch.float32
)
elif isinstance(voxel_size, torch.Tensor):
if voxel_size.numel() == 1:
# convert a single-element tensor to a 3-element one
voxel_size = voxel_size.view(-1).repeat(3)
elif len(voxel_size.shape) == 2 and (
voxel_size.shape[0] == batch_size and voxel_size.shape[1] == 1
):
voxel_size = voxel_size.repeat(1, 3)
return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg)
def _handle_volume_translation(
self, translation: _Translation, batch_size: int
) -> torch.Tensor:
"""
Handle the `volume_translation` argument to the `VolumeLocator` constructor.
"""
err_msg = (
"`volume_translation` has to be either a 3-tuple of scalars, or"
" a Tensor of shape (1,3) or (minibatch, 3) or (3,)`."
)
return self._convert_volume_property_to_tensor(translation, batch_size, err_msg)
def __len__(self) -> int:
return self._batch_size
def _convert_volume_property_to_tensor(
self, x: _Vector, batch_size: int, err_msg: str
) -> torch.Tensor:
"""
Handle the `volume_translation` or `voxel_size` argument to
the VolumeLocator constructor.
Return a tensor of shape (N, 3) where N is the batch_size.
"""
if isinstance(x, (list, tuple)):
if len(x) != 3:
raise ValueError(err_msg)
x = torch.tensor(x, device=self.device, dtype=torch.float32)[None]
x = x.repeat((batch_size, 1))
elif isinstance(x, torch.Tensor):
ok = (
(x.shape[0] == 1 and x.shape[1] == 3)
or (x.shape[0] == 3 and len(x.shape) == 1)
or (x.shape[0] == batch_size and x.shape[1] == 3)
)
if not ok:
raise ValueError(err_msg)
if x.device != self.device:
x = x.to(self.device)
if x.shape[0] == 3 and len(x.shape) == 1:
x = x[None]
if x.shape[0] == 1:
x = x.repeat((batch_size, 1))
else:
raise ValueError(err_msg)
return x
def to(self, device: Device, copy: bool = False) -> "VolumeLocator":
"""
Match the functionality of torch.Tensor.to()
If copy = True or the self Tensor is on a different device, the
@@ -713,7 +1069,7 @@ class Volumes:
copy: Boolean indicator whether or not to clone self. Default False.
Returns:
Volumes object.
VolumeLocator object.
"""
device_ = make_device(device)
if not copy and self.device == device_:
@@ -724,18 +1080,24 @@ class Volumes:
return other
other.device = device_
other._densities = self._densities.to(device_)
if self._features is not None:
# pyre-fixme[16]: `Optional` has no attribute `to`.
other._features = self.features().to(device_)
other._local_to_world_transform = self.get_local_to_world_coords_transform().to(
device_
)
other._grid_sizes = self._grid_sizes.to(device_)
other._local_to_world_transform = self.get_local_to_world_coords_transform().to(
device
)
return other
def cpu(self) -> "Volumes":
def clone(self) -> "VolumeLocator":
"""
Deep copy of VolumeLocator object. All internal tensors are cloned
individually.
Returns:
new VolumeLocator object.
"""
return copy.deepcopy(self)
def cpu(self) -> "VolumeLocator":
return self.to("cpu")
def cuda(self) -> "Volumes":
def cuda(self) -> "VolumeLocator":
return self.to("cuda")

@@ -11,7 +11,7 @@ import unittest
import numpy as np
import torch
from pytorch3d.structures.volumes import Volumes
from pytorch3d.structures.volumes import VolumeLocator, Volumes
from pytorch3d.transforms import Scale
from .common_testing import TestCaseMixin
@@ -53,8 +53,8 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
for selectedIdx, index in indices:
self.assertClose(selected.densities()[selectedIdx], v.densities()[index])
self.assertClose(
v._local_to_world_transform.get_matrix()[index],
selected._local_to_world_transform.get_matrix()[selectedIdx],
v.locator._local_to_world_transform.get_matrix()[index],
selected.locator._local_to_world_transform.get_matrix()[selectedIdx],
)
if selected.features() is not None:
self.assertClose(selected.features()[selectedIdx], v.features()[index])
@@ -149,10 +149,55 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
with self.assertRaises(IndexError):
v_selected = v[index]
def test_locator_init(self, batch_size=9, resolution=(3, 5, 7)):
with self.subTest("VolumeLocator init with all sizes equal"):
grid_sizes = [resolution for _ in range(batch_size)]
locator_tuple = VolumeLocator(
batch_size=batch_size, grid_sizes=resolution, device=torch.device("cpu")
)
locator_list = VolumeLocator(
batch_size=batch_size, grid_sizes=grid_sizes, device=torch.device("cpu")
)
locator_tensor = VolumeLocator(
batch_size=batch_size,
grid_sizes=torch.tensor(grid_sizes),
device=torch.device("cpu"),
)
expected_grid_sizes = torch.tensor(grid_sizes)
expected_resolution = resolution
assert torch.allclose(expected_grid_sizes, locator_tuple._grid_sizes)
assert torch.allclose(expected_grid_sizes, locator_list._grid_sizes)
assert torch.allclose(expected_grid_sizes, locator_tensor._grid_sizes)
self.assertEqual(expected_resolution, locator_tuple._resolution)
self.assertEqual(expected_resolution, locator_list._resolution)
self.assertEqual(expected_resolution, locator_tensor._resolution)
with self.subTest("VolumeLocator with different sizes in different grids"):
grid_sizes_list = [
torch.randint(low=1, high=42, size=(3,)) for _ in range(batch_size)
]
grid_sizes_tensor = torch.cat([el[None] for el in grid_sizes_list])
locator_list = VolumeLocator(
batch_size=batch_size,
grid_sizes=grid_sizes_list,
device=torch.device("cpu"),
)
locator_tensor = VolumeLocator(
batch_size=batch_size,
grid_sizes=grid_sizes_tensor,
device=torch.device("cpu"),
)
expected_grid_sizes = grid_sizes_tensor
expected_resolution = tuple(torch.max(expected_grid_sizes, dim=0).values)
assert torch.allclose(expected_grid_sizes, locator_list._grid_sizes)
assert torch.allclose(expected_grid_sizes, locator_tensor._grid_sizes)
self.assertEqual(expected_resolution, locator_list._resolution)
self.assertEqual(expected_resolution, locator_tensor._resolution)
def test_coord_transforms(self, num_volumes=3, num_channels=4, dtype=torch.float32):
"""
Test the correctness of the conversion between the internal
Transform3D Volumes._local_to_world_transform and the initialization
Transform3d Volumes.locator._local_to_world_transform and the initialization
from the translation and voxel_size.
"""
@@ -440,7 +485,10 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
for var_name, var in vars(v).items():
if var_name != "device":
if var is not None:
self.assertTrue(var.device.type == desired_device.type)
self.assertTrue(
var.device.type == desired_device.type,
(var_name, var.device, desired_device),
)
else:
self.assertTrue(var.type == desired_device.type)
@@ -456,60 +504,74 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
)
densities = torch.rand(size=[num_volumes, 1, *size], dtype=dtype)
volumes = Volumes(densities=densities, features=features)
locator = VolumeLocator(
batch_size=5, grid_sizes=(3, 5, 7), device=volumes.device
)
# Test support for str and torch.device
cpu_device = torch.device("cpu")
for name, obj in (("VolumeLocator", locator), ("Volumes", volumes)):
with self.subTest(f"Moving {name} from/to gpu and cpu"):
# Test support for str and torch.device
cpu_device = torch.device("cpu")
converted_volumes = volumes.to("cpu")
self.assertEqual(cpu_device, converted_volumes.device)
self.assertEqual(cpu_device, volumes.device)
self.assertIs(volumes, converted_volumes)
converted_obj = obj.to("cpu")
self.assertEqual(cpu_device, converted_obj.device)
self.assertEqual(cpu_device, obj.device)
self.assertIs(obj, converted_obj)
converted_volumes = volumes.to(cpu_device)
self.assertEqual(cpu_device, converted_volumes.device)
self.assertEqual(cpu_device, volumes.device)
self.assertIs(volumes, converted_volumes)
converted_obj = obj.to(cpu_device)
self.assertEqual(cpu_device, converted_obj.device)
self.assertEqual(cpu_device, obj.device)
self.assertIs(obj, converted_obj)
cuda_device = torch.device("cuda:0")
cuda_device = torch.device("cuda:0")
converted_volumes = volumes.to("cuda:0")
self.assertEqual(cuda_device, converted_volumes.device)
self.assertEqual(cpu_device, volumes.device)
self.assertIsNot(volumes, converted_volumes)
converted_obj = obj.to("cuda:0")
self.assertEqual(cuda_device, converted_obj.device)
self.assertEqual(cpu_device, obj.device)
self.assertIsNot(obj, converted_obj)
converted_volumes = volumes.to(cuda_device)
self.assertEqual(cuda_device, converted_volumes.device)
self.assertEqual(cpu_device, volumes.device)
self.assertIsNot(volumes, converted_volumes)
converted_obj = obj.to(cuda_device)
self.assertEqual(cuda_device, converted_obj.device)
self.assertEqual(cpu_device, obj.device)
self.assertIsNot(obj, converted_obj)
# Test device placement of internal tensors
features = features.to(cuda_device)
densities = features.to(cuda_device)
with self.subTest("Test device placement of internal tensors of Volumes"):
features = features.to(cuda_device)
densities = features.to(cuda_device)
for features_ in (features, None):
volumes = Volumes(densities=densities, features=features_)
for features_ in (features, None):
volumes = Volumes(densities=densities, features=features_)
cpu_volumes = volumes.cpu()
cuda_volumes = cpu_volumes.cuda()
cuda_volumes2 = cuda_volumes.cuda()
cpu_volumes2 = cuda_volumes2.cpu()
cpu_volumes = volumes.cpu()
cuda_volumes = cpu_volumes.cuda()
cuda_volumes2 = cuda_volumes.cuda()
cpu_volumes2 = cuda_volumes2.cpu()
for volumes1, volumes2 in itertools.combinations(
(volumes, cpu_volumes, cpu_volumes2, cuda_volumes, cuda_volumes2), 2
):
if volumes1 is cuda_volumes and volumes2 is cuda_volumes2:
# checks that we do not copy if the devices stay the same
assert_fun = self.assertIs
else:
assert_fun = self.assertSeparate
assert_fun(volumes1._densities, volumes2._densities)
if features_ is not None:
assert_fun(volumes1._features, volumes2._features)
for volumes_ in (volumes1, volumes2):
if volumes_ in (cpu_volumes, cpu_volumes2):
self._check_vars_on_device(volumes_, cpu_device)
for volumes1, volumes2 in itertools.combinations(
(volumes, cpu_volumes, cpu_volumes2, cuda_volumes, cuda_volumes2), 2
):
if volumes1 is cuda_volumes and volumes2 is cuda_volumes2:
# checks that we do not copy if the devices stay the same
assert_fun = self.assertIs
else:
self._check_vars_on_device(volumes_, cuda_device)
assert_fun = self.assertSeparate
assert_fun(volumes1._densities, volumes2._densities)
if features_ is not None:
assert_fun(volumes1._features, volumes2._features)
for volumes_ in (volumes1, volumes2):
if volumes_ in (cpu_volumes, cpu_volumes2):
self._check_vars_on_device(volumes_, cpu_device)
else:
self._check_vars_on_device(volumes_, cuda_device)
with self.subTest("Test device placement of internal tensors of VolumeLocator"):
for device1, device2 in itertools.combinations(
(torch.device("cpu"), torch.device("cuda:0")), 2
):
locator = locator.to(device1)
locator = locator.to(device2)
self.assertEqual(locator._grid_sizes.device, device2)
self.assertEqual(locator._local_to_world_transform.device, device2)
def _check_padded(self, x_pad, x_list, grid_sizes):
"""