mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Split Volumes class to data and location part
Summary: Split Volumes class to data and location part so that location part can be reused in planned VoxelGrid classes. Reviewed By: bottler Differential Revision: D38782015 fbshipit-source-id: 489da09c5c236f3b81961ce9b09edbd97afaa7c8
This commit is contained in:
		
							parent
							
								
									fdaaa299a7
								
							
						
					
					
						commit
						f825f7e42c
					
				@ -23,6 +23,7 @@ _VoxelSize = _ScalarOrVector
 | 
			
		||||
_Translation = _Vector
 | 
			
		||||
 | 
			
		||||
_TensorBatch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
 | 
			
		||||
_ALL_CONTENT: slice = slice(0, None)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Volumes:
 | 
			
		||||
@ -65,9 +66,9 @@ class Volumes:
 | 
			
		||||
 | 
			
		||||
    VOLUME COORDINATES
 | 
			
		||||
 | 
			
		||||
    Additionally, the `Volumes` class keeps track of the locations of the
 | 
			
		||||
    centers of the volume cells in the local volume coordinates as well as in
 | 
			
		||||
    the world coordinates.
 | 
			
		||||
    Additionally, using the `VolumeLocator` class the `Volumes` class keeps track
 | 
			
		||||
    of the locations of the centers of the volume cells in the local volume
 | 
			
		||||
    coordinates as well as in the world coordinates.
 | 
			
		||||
 | 
			
		||||
        Local coordinates:
 | 
			
		||||
            - Represent the locations of the volume cells in the local coordinate
 | 
			
		||||
@ -125,7 +126,7 @@ class Volumes:
 | 
			
		||||
    appropriate `world_coordinates` argument.
 | 
			
		||||
 | 
			
		||||
    Internally, the mapping between `x_local` and `x_world` is represented
 | 
			
		||||
    as a `Transform3d` object `Volumes._local_to_world_transform`.
 | 
			
		||||
    as a `Transform3d` object `Volumes.VolumeLocator._local_to_world_transform`.
 | 
			
		||||
    Users can access the relevant transformations with the
 | 
			
		||||
    `Volumes.get_world_to_local_coords_transform()` and
 | 
			
		||||
    `Volumes.get_local_to_world_coords_transform()`
 | 
			
		||||
@ -197,21 +198,24 @@ class Volumes:
 | 
			
		||||
 | 
			
		||||
        # assign to the internal buffers
 | 
			
		||||
        self._densities = densities_
 | 
			
		||||
        self._grid_sizes = grid_sizes
 | 
			
		||||
 | 
			
		||||
        # assign a coordinate transformation member
 | 
			
		||||
        self.locator = VolumeLocator(
 | 
			
		||||
            batch_size=len(self),
 | 
			
		||||
            grid_sizes=grid_sizes,
 | 
			
		||||
            voxel_size=voxel_size,
 | 
			
		||||
            volume_translation=volume_translation,
 | 
			
		||||
            device=self.device,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # handle features
 | 
			
		||||
        self._features = None
 | 
			
		||||
        if features is not None:
 | 
			
		||||
            self._set_features(features)
 | 
			
		||||
 | 
			
		||||
        # set the local_to_world transform
 | 
			
		||||
        self._set_local_to_world_transform(
 | 
			
		||||
            voxel_size=voxel_size, volume_translation=volume_translation
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _convert_densities_features_to_tensor(
 | 
			
		||||
        self, x: _TensorBatch, var_name: str
 | 
			
		||||
    ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Handle the `densities` or `features` arguments to the constructor.
 | 
			
		||||
        """
 | 
			
		||||
@ -251,8 +255,492 @@ class Volumes:
 | 
			
		||||
                f"{var_name} must be either a list or a tensor with "
 | 
			
		||||
                f"shape (batch_size, {var_name}_dim, H, W, D)."
 | 
			
		||||
            )
 | 
			
		||||
        # pyre-ignore[7]
 | 
			
		||||
        return x_tensor, x_shapes
 | 
			
		||||
 | 
			
		||||
    def __len__(self) -> int:
 | 
			
		||||
        return self._densities.shape[0]
 | 
			
		||||
 | 
			
		||||
    def __getitem__(
 | 
			
		||||
        self,
 | 
			
		||||
        index: Union[
 | 
			
		||||
            int, List[int], Tuple[int], slice, torch.BoolTensor, torch.LongTensor
 | 
			
		||||
        ],
 | 
			
		||||
    ) -> "Volumes":
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            index: Specifying the index of the volume to retrieve.
 | 
			
		||||
                Can be an int, slice, list of ints or a boolean or a long tensor.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Volumes object with selected volumes. The tensors are not cloned.
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(index, int):
 | 
			
		||||
            index = torch.LongTensor([index])
 | 
			
		||||
        elif isinstance(index, (slice, list, tuple)):
 | 
			
		||||
            pass
 | 
			
		||||
        elif torch.is_tensor(index):
 | 
			
		||||
            if index.dim() != 1 or index.dtype.is_floating_point:
 | 
			
		||||
                raise IndexError(index)
 | 
			
		||||
        else:
 | 
			
		||||
            raise IndexError(index)
 | 
			
		||||
 | 
			
		||||
        new = self.__class__(
 | 
			
		||||
            # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
 | 
			
		||||
            features=self.features()[index] if self._features is not None else None,
 | 
			
		||||
            densities=self.densities()[index],
 | 
			
		||||
        )
 | 
			
		||||
        # dont forget to update grid_sizes!
 | 
			
		||||
        self.locator._copy_transform_and_sizes(new.locator, index=index)
 | 
			
		||||
        return new
 | 
			
		||||
 | 
			
		||||
    def features(self) -> Optional[torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Returns the features of the volume.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            **features**: The tensor of volume features.
 | 
			
		||||
        """
 | 
			
		||||
        return self._features
 | 
			
		||||
 | 
			
		||||
    def densities(self) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Returns the densities of the volume.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            **densities**: The tensor of volume densities.
 | 
			
		||||
        """
 | 
			
		||||
        return self._densities
 | 
			
		||||
 | 
			
		||||
    def densities_list(self) -> List[torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Get the list representation of the densities.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list of tensors of densities of shape (dim_i, D_i, H_i, W_i).
 | 
			
		||||
        """
 | 
			
		||||
        return self._features_densities_list(self.densities())
 | 
			
		||||
 | 
			
		||||
    def features_list(self) -> List[torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Get the list representation of the features.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list of tensors of features of shape (dim_i, D_i, H_i, W_i)
 | 
			
		||||
            or `None` for feature-less volumes.
 | 
			
		||||
        """
 | 
			
		||||
        features_ = self.features()
 | 
			
		||||
        if features_ is None:
 | 
			
		||||
            # No features provided so return None
 | 
			
		||||
            # pyre-fixme[7]: Expected `List[torch.Tensor]` but got `None`.
 | 
			
		||||
            return None
 | 
			
		||||
        return self._features_densities_list(features_)
 | 
			
		||||
 | 
			
		||||
    def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Retrieve the list representation of features/densities.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            x: self.features() or self.densities()
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list of tensors of features/densities of shape (dim_i, D_i, H_i, W_i).
 | 
			
		||||
        """
 | 
			
		||||
        x_dim = x.shape[1]
 | 
			
		||||
        pad_sizes = torch.nn.functional.pad(
 | 
			
		||||
            self.get_grid_sizes(), [1, 0], mode="constant", value=x_dim
 | 
			
		||||
        )
 | 
			
		||||
        x_list = struct_utils.padded_to_list(x, pad_sizes.tolist())
 | 
			
		||||
        return x_list
 | 
			
		||||
 | 
			
		||||
    def update_padded(
 | 
			
		||||
        self, new_densities: torch.Tensor, new_features: Optional[torch.Tensor] = None
 | 
			
		||||
    ) -> "Volumes":
 | 
			
		||||
        """
 | 
			
		||||
        Returns a Volumes structure with updated padded tensors and copies of
 | 
			
		||||
        the auxiliary tensors `self._local_to_world_transform`,
 | 
			
		||||
        `device` and `self._grid_sizes`. This function allows for an update of
 | 
			
		||||
        densities (and features) without having to explicitly
 | 
			
		||||
        convert it to the list representation for heterogeneous batches.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            new_densities: FloatTensor of shape (N, dim_density, D, H, W)
 | 
			
		||||
            new_features: (optional) FloatTensor of shape (N, dim_feature, D, H, W)
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Volumes with updated features and densities
 | 
			
		||||
        """
 | 
			
		||||
        new = copy.copy(self)
 | 
			
		||||
        new._set_densities(new_densities)
 | 
			
		||||
        if new_features is None:
 | 
			
		||||
            new._features = None
 | 
			
		||||
        else:
 | 
			
		||||
            new._set_features(new_features)
 | 
			
		||||
        return new
 | 
			
		||||
 | 
			
		||||
    def _set_features(self, features: _TensorBatch) -> None:
 | 
			
		||||
        self._set_densities_features("features", features)
 | 
			
		||||
 | 
			
		||||
    def _set_densities(self, densities: _TensorBatch) -> None:
 | 
			
		||||
        self._set_densities_features("densities", densities)
 | 
			
		||||
 | 
			
		||||
    def _set_densities_features(self, var_name: str, x: _TensorBatch) -> None:
 | 
			
		||||
        x_tensor, grid_sizes = self._convert_densities_features_to_tensor(x, var_name)
 | 
			
		||||
        if x_tensor.device != self.device:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"`{var_name}` have to be on the same device as `self.densities`."
 | 
			
		||||
            )
 | 
			
		||||
        if len(x_tensor.shape) != 5:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"{var_name} has to be a 5-dim tensor of shape: "
 | 
			
		||||
                f"(minibatch, {var_name}_dim, height, width, depth)"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if not (
 | 
			
		||||
            (self.get_grid_sizes().shape == grid_sizes.shape)
 | 
			
		||||
            and torch.allclose(self.get_grid_sizes(), grid_sizes)
 | 
			
		||||
        ):
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"The size of every grid in `{var_name}` has to match the size of"
 | 
			
		||||
                "the corresponding `densities` grid."
 | 
			
		||||
            )
 | 
			
		||||
        setattr(self, "_" + var_name, x_tensor)
 | 
			
		||||
 | 
			
		||||
    def clone(self) -> "Volumes":
 | 
			
		||||
        """
 | 
			
		||||
        Deep copy of Volumes object. All internal tensors are cloned
 | 
			
		||||
        individually.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            new Volumes object.
 | 
			
		||||
        """
 | 
			
		||||
        return copy.deepcopy(self)
 | 
			
		||||
 | 
			
		||||
    def to(self, device: Device, copy: bool = False) -> "Volumes":
 | 
			
		||||
        """
 | 
			
		||||
        Match the functionality of torch.Tensor.to()
 | 
			
		||||
        If copy = True or the self Tensor is on a different device, the
 | 
			
		||||
        returned tensor is a copy of self with the desired torch.device.
 | 
			
		||||
        If copy = False and the self Tensor already has the correct torch.device,
 | 
			
		||||
        then self is returned.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            device: Device (as str or torch.device) for the new tensor.
 | 
			
		||||
            copy: Boolean indicator whether or not to clone self. Default False.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Volumes object.
 | 
			
		||||
        """
 | 
			
		||||
        device_ = make_device(device)
 | 
			
		||||
        if not copy and self.device == device_:
 | 
			
		||||
            return self
 | 
			
		||||
 | 
			
		||||
        other = self.clone()
 | 
			
		||||
        if self.device == device_:
 | 
			
		||||
            return other
 | 
			
		||||
 | 
			
		||||
        other.device = device_
 | 
			
		||||
        other._densities = self._densities.to(device_)
 | 
			
		||||
        if self._features is not None:
 | 
			
		||||
            # pyre-fixme[16]: `Optional` has no attribute `to`.
 | 
			
		||||
            other._features = self.features().to(device_)
 | 
			
		||||
        self.locator._copy_transform_and_sizes(other.locator, device=device_)
 | 
			
		||||
        other.locator = other.locator.to(device, copy)
 | 
			
		||||
        return other
 | 
			
		||||
 | 
			
		||||
    def cpu(self) -> "Volumes":
 | 
			
		||||
        return self.to("cpu")
 | 
			
		||||
 | 
			
		||||
    def cuda(self) -> "Volumes":
 | 
			
		||||
        return self.to("cuda")
 | 
			
		||||
 | 
			
		||||
    def get_grid_sizes(self) -> torch.LongTensor:
 | 
			
		||||
        """
 | 
			
		||||
        Returns the sizes of individual volumetric grids in the structure.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            **grid_sizes**: Tensor of spatial sizes of each of the volumes
 | 
			
		||||
                of size (batchsize, 3), where i-th row holds (D_i, H_i, W_i).
 | 
			
		||||
        """
 | 
			
		||||
        return self.locator.get_grid_sizes()
 | 
			
		||||
 | 
			
		||||
    def get_local_to_world_coords_transform(self) -> Transform3d:
 | 
			
		||||
        """
 | 
			
		||||
        Return a Transform3d object that converts points in the
 | 
			
		||||
        the local coordinate frame of the volume to world coordinates.
 | 
			
		||||
        Local volume coordinates are scaled s.t. the coordinates along one
 | 
			
		||||
        side of the volume are in range [-1, 1].
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            **local_to_world_transform**: A Transform3d object converting
 | 
			
		||||
                points from local coordinates to the world coordinates.
 | 
			
		||||
        """
 | 
			
		||||
        return self.locator.get_local_to_world_coords_transform()
 | 
			
		||||
 | 
			
		||||
    def get_world_to_local_coords_transform(self) -> Transform3d:
 | 
			
		||||
        """
 | 
			
		||||
        Return a Transform3d object that converts points in the
 | 
			
		||||
        world coordinates to the local coordinate frame of the volume.
 | 
			
		||||
        Local volume coordinates are scaled s.t. the coordinates along one
 | 
			
		||||
        side of the volume are in range [-1, 1].
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            **world_to_local_transform**: A Transform3d object converting
 | 
			
		||||
                points from world coordinates to local coordinates.
 | 
			
		||||
        """
 | 
			
		||||
        return self.get_local_to_world_coords_transform().inverse()
 | 
			
		||||
 | 
			
		||||
    def world_to_local_coords(self, points_3d_world: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Convert a batch of 3D point coordinates `points_3d_world` of shape
 | 
			
		||||
        (minibatch, ..., dim) in the world coordinates to
 | 
			
		||||
        the local coordinate frame of the volume. Local volume
 | 
			
		||||
        coordinates are scaled s.t. the coordinates along one side of the volume
 | 
			
		||||
        are in range [-1, 1].
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            **points_3d_world**: A tensor of shape `(minibatch, ..., 3)`
 | 
			
		||||
                containing the 3D coordinates of a set of points that will
 | 
			
		||||
                be converted from the local volume coordinates (ranging
 | 
			
		||||
                within [-1, 1]) to the world coordinates
 | 
			
		||||
                defined by the `self.center` and `self.voxel_size` parameters.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            **points_3d_local**: `points_3d_world` converted to the local
 | 
			
		||||
                volume coordinates of shape `(minibatch, ..., 3)`.
 | 
			
		||||
        """
 | 
			
		||||
        return self.locator.world_to_local_coords(points_3d_world)
 | 
			
		||||
 | 
			
		||||
    def local_to_world_coords(self, points_3d_local: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Convert a batch of 3D point coordinates `points_3d_local` of shape
 | 
			
		||||
        (minibatch, ..., dim) in the local coordinate frame of the volume
 | 
			
		||||
        to the world coordinates.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            **points_3d_local**: A tensor of shape `(minibatch, ..., 3)`
 | 
			
		||||
                containing the 3D coordinates of a set of points that will
 | 
			
		||||
                be converted from the local volume coordinates (ranging
 | 
			
		||||
                within [-1, 1]) to the world coordinates
 | 
			
		||||
                defined by the `self.center` and `self.voxel_size` parameters.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            **points_3d_world**: `points_3d_local` converted to the world
 | 
			
		||||
                coordinates of the volume of shape `(minibatch, ..., 3)`.
 | 
			
		||||
        """
 | 
			
		||||
        return self.locator.local_to_world_coords(points_3d_local)
 | 
			
		||||
 | 
			
		||||
    def get_coord_grid(self, world_coordinates: bool = True) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Return the 3D coordinate grid of the volumetric grid
 | 
			
		||||
        in local (`world_coordinates=False`) or world coordinates
 | 
			
		||||
        (`world_coordinates=True`).
 | 
			
		||||
 | 
			
		||||
        The grid records location of each center of the corresponding volume voxel.
 | 
			
		||||
 | 
			
		||||
        Local coordinates are scaled s.t. the values along one side of the
 | 
			
		||||
        volume are in range [-1, 1].
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            **world_coordinates**: if `True`, the method
 | 
			
		||||
                returns the grid in the world coordinates,
 | 
			
		||||
                otherwise, in local coordinates.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            **coordinate_grid**: The grid of coordinates of shape
 | 
			
		||||
                `(minibatch, depth, height, width, 3)`, where `minibatch`,
 | 
			
		||||
                `height`, `width` and `depth` are the batch size, height, width
 | 
			
		||||
                and depth of the volume `features` or `densities`.
 | 
			
		||||
        """
 | 
			
		||||
        return self.locator.get_coord_grid(world_coordinates)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class VolumeLocator:
 | 
			
		||||
    """
 | 
			
		||||
    The `VolumeLocator` class keeps track of the locations of the
 | 
			
		||||
    centers of the volume cells in the local volume coordinates as well as in
 | 
			
		||||
    the world coordinates for a voxel grid structure in 3D.
 | 
			
		||||
 | 
			
		||||
        Local coordinates:
 | 
			
		||||
            - Represent the locations of the volume cells in the local coordinate
 | 
			
		||||
              frame of the volume.
 | 
			
		||||
            - The center of the voxel indexed with `[·, ·, 0, 0, 0]` in the volume
 | 
			
		||||
              has its 3D local coordinate set to `[-1, -1, -1]`, while the voxel
 | 
			
		||||
              at index `[·, ·, depth_i-1, height_i-1, width_i-1]` has its
 | 
			
		||||
              3D local coordinate set to `[1, 1, 1]`.
 | 
			
		||||
            - The first/second/third coordinate of each of the 3D per-voxel
 | 
			
		||||
              XYZ vector denotes the horizontal/vertical/depth-wise position
 | 
			
		||||
              respectively. I.e the order of the coordinate dimensions in the
 | 
			
		||||
              volume is reversed w.r.t. the order of the 3D coordinate vectors.
 | 
			
		||||
            - The intermediate coordinates between `[-1, -1, -1]` and `[1, 1, 1]`.
 | 
			
		||||
              are linearly interpolated over the spatial dimensions of the volume.
 | 
			
		||||
            - Note that the convention is the same as for the 5D version of the
 | 
			
		||||
              `torch.nn.functional.grid_sample` function called with
 | 
			
		||||
              `align_corners==True`.
 | 
			
		||||
            - Note that the local coordinate convention of `VolumeLocator`
 | 
			
		||||
              (+X = left to right, +Y = top to bottom, +Z = away from the user)
 | 
			
		||||
              is *different* from the world coordinate convention of the
 | 
			
		||||
              renderer for `Meshes` or `Pointclouds`
 | 
			
		||||
              (+X = right to left, +Y = bottom to top, +Z = away from the user).
 | 
			
		||||
 | 
			
		||||
        World coordinates:
 | 
			
		||||
            - These define the locations of the centers of the volume cells
 | 
			
		||||
              in the world coordinates.
 | 
			
		||||
            - They are specified with the following mapping that converts
 | 
			
		||||
              points `x_local` in the local coordinates to points `x_world`
 | 
			
		||||
              in the world coordinates:
 | 
			
		||||
                ```
 | 
			
		||||
                x_world = (
 | 
			
		||||
                    x_local * (volume_size - 1) * 0.5 * voxel_size
 | 
			
		||||
                ) - volume_translation,
 | 
			
		||||
                ```
 | 
			
		||||
              here `voxel_size` specifies the size of each voxel of the volume,
 | 
			
		||||
              and `volume_translation` is the 3D offset of the central voxel of
 | 
			
		||||
              the volume w.r.t. the origin of the world coordinate frame.
 | 
			
		||||
              Both `voxel_size` and `volume_translation` are specified in
 | 
			
		||||
              the world coordinate units. `volume_size` is the spatial size of
 | 
			
		||||
              the volume in form of a 3D vector `[width, height, depth]`.
 | 
			
		||||
            - Given the above definition of `x_world`, one can derive the
 | 
			
		||||
              inverse mapping from `x_world` to `x_local` as follows:
 | 
			
		||||
                ```
 | 
			
		||||
                x_local = (
 | 
			
		||||
                    (x_world + volume_translation) / (0.5 * voxel_size)
 | 
			
		||||
                ) / (volume_size - 1)
 | 
			
		||||
                ```
 | 
			
		||||
            - For a trivial volume with `volume_translation==[0, 0, 0]`
 | 
			
		||||
              with `voxel_size=-1`, `x_world` would range
 | 
			
		||||
              from -(volume_size-1)/2` to `+(volume_size-1)/2`.
 | 
			
		||||
 | 
			
		||||
    Coordinate tensors that denote the locations of each of the volume cells in
 | 
			
		||||
    local / world coordinates (with shape `(depth x height x width x 3)`)
 | 
			
		||||
    can be retrieved by calling the `VolumeLocator.get_coord_grid()` getter with the
 | 
			
		||||
    appropriate `world_coordinates` argument.
 | 
			
		||||
 | 
			
		||||
    Internally, the mapping between `x_local` and `x_world` is represented
 | 
			
		||||
    as a `Transform3d` object `VolumeLocator._local_to_world_transform`.
 | 
			
		||||
    Users can access the relevant transformations with the
 | 
			
		||||
    `VolumeLocator.get_world_to_local_coords_transform()` and
 | 
			
		||||
    `VolumeLocator.get_local_to_world_coords_transform()`
 | 
			
		||||
    functions.
 | 
			
		||||
 | 
			
		||||
    Example coordinate conversion:
 | 
			
		||||
        - For a "trivial" volume with `voxel_size = 1.`,
 | 
			
		||||
          `volume_translation=[0., 0., 0.]`, and the spatial size of
 | 
			
		||||
          `DxHxW = 5x5x5`, the point `x_world = (-2, 0, 2)` gets mapped
 | 
			
		||||
          to `x_local=(-1, 0, 1)`.
 | 
			
		||||
        - For a "trivial" volume `v` with `voxel_size = 1.`,
 | 
			
		||||
          `volume_translation=[0., 0., 0.]`, the following holds:
 | 
			
		||||
            ```
 | 
			
		||||
            torch.nn.functional.grid_sample(
 | 
			
		||||
                v.densities(),
 | 
			
		||||
                v.get_coord_grid(world_coordinates=False),
 | 
			
		||||
                align_corners=True,
 | 
			
		||||
            ) == v.densities(),
 | 
			
		||||
            ```
 | 
			
		||||
            i.e. sampling the volume at trivial local coordinates
 | 
			
		||||
            (no scaling with `voxel_size`` or shift with `volume_translation`)
 | 
			
		||||
            results in the same volume.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        batch_size: int,
 | 
			
		||||
        grid_sizes: Union[
 | 
			
		||||
            torch.LongTensor, Tuple[int, int, int], List[torch.LongTensor]
 | 
			
		||||
        ],
 | 
			
		||||
        device: torch.device,
 | 
			
		||||
        voxel_size: _VoxelSize = 1.0,
 | 
			
		||||
        volume_translation: _Translation = (0.0, 0.0, 0.0),
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        **batch_size** : Batch size of the underlaying grids
 | 
			
		||||
        **grid_sizes** : Represents the resolutions of different grids in the batch. Can be
 | 
			
		||||
                a) tuple of form (H, W, D)
 | 
			
		||||
                b) list/tuple of length batch_size of lists/tuples of form (H, W, D)
 | 
			
		||||
                c) torch.Tensor of shape (batch_size, H, W, D)
 | 
			
		||||
            H, W, D are height, width, depth respectively.  If `grid_sizes` is a tuple than
 | 
			
		||||
            all the  grids in the batch have the same resolution.
 | 
			
		||||
        **voxel_size**: Denotes the size of each volume voxel in world units.
 | 
			
		||||
            Has to be one of:
 | 
			
		||||
            a) A scalar (square voxels)
 | 
			
		||||
            b) 3-tuple or a 3-list of scalars
 | 
			
		||||
            c) a Tensor of shape (3,)
 | 
			
		||||
            d) a Tensor of shape (minibatch, 3)
 | 
			
		||||
            e) a Tensor of shape (minibatch, 1)
 | 
			
		||||
            f) a Tensor of shape (1,) (square voxels)
 | 
			
		||||
        **volume_translation**: Denotes the 3D translation of the center
 | 
			
		||||
            of the volume in world units. Has to be one of:
 | 
			
		||||
            a) 3-tuple or a 3-list of scalars
 | 
			
		||||
            b) a Tensor of shape (3,)
 | 
			
		||||
            c) a Tensor of shape (minibatch, 3)
 | 
			
		||||
            d) a Tensor of shape (1,) (square voxels)
 | 
			
		||||
        """
 | 
			
		||||
        self.device = device
 | 
			
		||||
        self._batch_size = batch_size
 | 
			
		||||
        self._grid_sizes = self._convert_grid_sizes2tensor(grid_sizes)
 | 
			
		||||
        self._resolution = tuple(torch.max(self._grid_sizes.cpu(), dim=0).values)
 | 
			
		||||
 | 
			
		||||
        # set the local_to_world transform
 | 
			
		||||
        self._set_local_to_world_transform(
 | 
			
		||||
            voxel_size=voxel_size, volume_translation=volume_translation
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _convert_grid_sizes2tensor(
 | 
			
		||||
        self, x: Union[torch.LongTensor, List[torch.LongTensor], Tuple[int, int, int]]
 | 
			
		||||
    ) -> torch.LongTensor:
 | 
			
		||||
        """
 | 
			
		||||
        Handle the grid_sizes argument to the constructor.
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(x, (list, tuple)):
 | 
			
		||||
            if isinstance(x[0], (torch.LongTensor, list, tuple)):
 | 
			
		||||
                if self._batch_size != len(x):
 | 
			
		||||
                    raise ValueError("x should have a batch size of 'batch_size'")
 | 
			
		||||
                # pyre-ignore[6]
 | 
			
		||||
                if any(len(x_) != 3 for x_ in x):
 | 
			
		||||
                    raise ValueError(
 | 
			
		||||
                        "`grid_sizes` has to be a list of 3-dim tensors of shape: "
 | 
			
		||||
                        "(height, width, depth)"
 | 
			
		||||
                    )
 | 
			
		||||
                x_shapes = torch.stack(
 | 
			
		||||
                    [
 | 
			
		||||
                        torch.tensor(
 | 
			
		||||
                            # pyre-ignore[6]
 | 
			
		||||
                            list(x_),
 | 
			
		||||
                            dtype=torch.long,
 | 
			
		||||
                            device=self.device,
 | 
			
		||||
                        )
 | 
			
		||||
                        for x_ in x
 | 
			
		||||
                    ],
 | 
			
		||||
                    dim=0,
 | 
			
		||||
                )
 | 
			
		||||
            elif isinstance(x[0], int):
 | 
			
		||||
                x_shapes = torch.stack(
 | 
			
		||||
                    [
 | 
			
		||||
                        torch.tensor(list(x), dtype=torch.long, device=self.device)
 | 
			
		||||
                        for _ in range(self._batch_size)
 | 
			
		||||
                    ],
 | 
			
		||||
                    dim=0,
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "`grid_sizes` can be a list/tuple of int or torch.Tensor not of "
 | 
			
		||||
                    + "{type(x[0])}."
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        elif torch.is_tensor(x):
 | 
			
		||||
            if x.ndim != 2:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "`grid_sizes` has to be a 2-dim tensor of shape: (minibatch, 3)"
 | 
			
		||||
                )
 | 
			
		||||
            x_shapes = x.to(self.device)
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "grid_sizes must be either a list of tensors with shape (H, W, D), tensor with"
 | 
			
		||||
                "shape (batch_size, H, W, D) or a tuple of (H, W, D)."
 | 
			
		||||
            )
 | 
			
		||||
        # pyre-ignore[7]
 | 
			
		||||
        return x_shapes
 | 
			
		||||
 | 
			
		||||
    def _voxel_size_translation_to_transform(
 | 
			
		||||
        self,
 | 
			
		||||
        voxel_size: torch.Tensor,
 | 
			
		||||
@ -280,75 +768,6 @@ class Volumes:
 | 
			
		||||
 | 
			
		||||
        return local_to_world_transform
 | 
			
		||||
 | 
			
		||||
    def _handle_voxel_size(
 | 
			
		||||
        self, voxel_size: _VoxelSize, batch_size: int
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Handle the `voxel_size` argument to the `Volumes` constructor.
 | 
			
		||||
        """
 | 
			
		||||
        err_msg = (
 | 
			
		||||
            "voxel_size has to be either a 3-tuple of scalars, or a scalar, or"
 | 
			
		||||
            " a torch.Tensor of shape (3,) or (1,) or (minibatch, 3) or (minibatch, 1)."
 | 
			
		||||
        )
 | 
			
		||||
        if isinstance(voxel_size, (float, int)):
 | 
			
		||||
            # convert a scalar to a 3-element tensor
 | 
			
		||||
            voxel_size = torch.full(
 | 
			
		||||
                (1, 3), voxel_size, device=self.device, dtype=torch.float32
 | 
			
		||||
            )
 | 
			
		||||
        elif isinstance(voxel_size, torch.Tensor):
 | 
			
		||||
            if voxel_size.numel() == 1:
 | 
			
		||||
                # convert a single-element tensor to a 3-element one
 | 
			
		||||
                voxel_size = voxel_size.view(-1).repeat(3)
 | 
			
		||||
            elif len(voxel_size.shape) == 2 and (
 | 
			
		||||
                voxel_size.shape[0] == batch_size and voxel_size.shape[1] == 1
 | 
			
		||||
            ):
 | 
			
		||||
                voxel_size = voxel_size.repeat(1, 3)
 | 
			
		||||
        return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg)
 | 
			
		||||
 | 
			
		||||
    def _handle_volume_translation(
 | 
			
		||||
        self, translation: _Translation, batch_size: int
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Handle the `volume_translation` argument to the `Volumes` constructor.
 | 
			
		||||
        """
 | 
			
		||||
        err_msg = (
 | 
			
		||||
            "`volume_translation` has to be either a 3-tuple of scalars, or"
 | 
			
		||||
            " a Tensor of shape (1,3) or (minibatch, 3) or (3,)`."
 | 
			
		||||
        )
 | 
			
		||||
        return self._convert_volume_property_to_tensor(translation, batch_size, err_msg)
 | 
			
		||||
 | 
			
		||||
    def _convert_volume_property_to_tensor(
 | 
			
		||||
        self, x: _Vector, batch_size: int, err_msg: str
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Handle the `volume_translation` or `voxel_size` argument to
 | 
			
		||||
        the Volumes constructor.
 | 
			
		||||
        Return a tensor of shape (N, 3) where N is the batch_size.
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(x, (list, tuple)):
 | 
			
		||||
            if len(x) != 3:
 | 
			
		||||
                raise ValueError(err_msg)
 | 
			
		||||
            x = torch.tensor(x, device=self.device, dtype=torch.float32)[None]
 | 
			
		||||
            x = x.repeat((batch_size, 1))
 | 
			
		||||
        elif isinstance(x, torch.Tensor):
 | 
			
		||||
            ok = (
 | 
			
		||||
                (x.shape[0] == 1 and x.shape[1] == 3)
 | 
			
		||||
                or (x.shape[0] == 3 and len(x.shape) == 1)
 | 
			
		||||
                or (x.shape[0] == batch_size and x.shape[1] == 3)
 | 
			
		||||
            )
 | 
			
		||||
            if not ok:
 | 
			
		||||
                raise ValueError(err_msg)
 | 
			
		||||
            if x.device != self.device:
 | 
			
		||||
                x = x.to(self.device)
 | 
			
		||||
            if x.shape[0] == 3 and len(x.shape) == 1:
 | 
			
		||||
                x = x[None]
 | 
			
		||||
            if x.shape[0] == 1:
 | 
			
		||||
                x = x.repeat((batch_size, 1))
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(err_msg)
 | 
			
		||||
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def get_coord_grid(self, world_coordinates: bool = True) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Return the 3D coordinate grid of the volumetric grid
 | 
			
		||||
@ -378,13 +797,12 @@ class Volumes:
 | 
			
		||||
        self, world_coordinates: bool = True
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Calculate the 3D coordinate grid of the volumetric grid either in
 | 
			
		||||
        Calculate the 3D coordinate grid of the volumetric grid either
 | 
			
		||||
        in local (`world_coordinates=False`) or
 | 
			
		||||
        world coordinates (`world_coordinates=True`) .
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        densities = self.densities()
 | 
			
		||||
        ba, _, de, he, wi = densities.shape
 | 
			
		||||
        ba, (de, he, wi) = self._batch_size, self._resolution
 | 
			
		||||
        grid_sizes = self.get_grid_sizes()
 | 
			
		||||
 | 
			
		||||
        # generate coordinate axes
 | 
			
		||||
@ -497,102 +915,6 @@ class Volumes:
 | 
			
		||||
            .view(pts_shape)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def __len__(self) -> int:
 | 
			
		||||
        return self._densities.shape[0]
 | 
			
		||||
 | 
			
		||||
    def __getitem__(
 | 
			
		||||
        self,
 | 
			
		||||
        index: Union[
 | 
			
		||||
            int, List[int], Tuple[int], slice, torch.BoolTensor, torch.LongTensor
 | 
			
		||||
        ],
 | 
			
		||||
    ) -> "Volumes":
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            index: Specifying the index of the volume to retrieve.
 | 
			
		||||
                Can be an int, slice, list of ints or a boolean or a long tensor.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Volumes object with selected volumes. The tensors are not cloned.
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(index, int):
 | 
			
		||||
            index = torch.LongTensor([index])
 | 
			
		||||
        elif isinstance(index, (slice, list, tuple)):
 | 
			
		||||
            pass
 | 
			
		||||
        elif torch.is_tensor(index):
 | 
			
		||||
            if index.dim() != 1 or index.dtype.is_floating_point:
 | 
			
		||||
                raise IndexError(index)
 | 
			
		||||
        else:
 | 
			
		||||
            raise IndexError(index)
 | 
			
		||||
 | 
			
		||||
        new = self.__class__(
 | 
			
		||||
            # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
 | 
			
		||||
            features=self.features()[index] if self._features is not None else None,
 | 
			
		||||
            densities=self.densities()[index],
 | 
			
		||||
        )
 | 
			
		||||
        # dont forget to update grid_sizes!
 | 
			
		||||
        new._grid_sizes = self.get_grid_sizes()[index]
 | 
			
		||||
        new._local_to_world_transform = self._local_to_world_transform[index]
 | 
			
		||||
        return new
 | 
			
		||||
 | 
			
		||||
    def features(self) -> Optional[torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Returns the features of the volume.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            **features**: The tensor of volume features.
 | 
			
		||||
        """
 | 
			
		||||
        return self._features
 | 
			
		||||
 | 
			
		||||
    def densities(self) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Returns the densities of the volume.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            **densities**: The tensor of volume densities.
 | 
			
		||||
        """
 | 
			
		||||
        return self._densities
 | 
			
		||||
 | 
			
		||||
    def densities_list(self) -> List[torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Get the list representation of the densities.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list of tensors of densities of shape (dim_i, D_i, H_i, W_i).
 | 
			
		||||
        """
 | 
			
		||||
        return self._features_densities_list(self.densities())
 | 
			
		||||
 | 
			
		||||
    def features_list(self) -> List[torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Get the list representation of the features.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list of tensors of features of shape (dim_i, D_i, H_i, W_i)
 | 
			
		||||
            or `None` for feature-less volumes.
 | 
			
		||||
        """
 | 
			
		||||
        features_ = self.features()
 | 
			
		||||
        if features_ is None:
 | 
			
		||||
            # No features provided so return None
 | 
			
		||||
            # pyre-fixme[7]: Expected `List[torch.Tensor]` but got `None`.
 | 
			
		||||
            return None
 | 
			
		||||
        return self._features_densities_list(features_)
 | 
			
		||||
 | 
			
		||||
    def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]:
 | 
			
		||||
        """
 | 
			
		||||
        Retrieve the list representation of features/densities.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            x: self.features() or self.densities()
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list of tensors of features/densities of shape (dim_i, D_i, H_i, W_i).
 | 
			
		||||
        """
 | 
			
		||||
        x_dim = x.shape[1]
 | 
			
		||||
        pad_sizes = torch.nn.functional.pad(
 | 
			
		||||
            self.get_grid_sizes(), [1, 0], mode="constant", value=x_dim
 | 
			
		||||
        )
 | 
			
		||||
        x_list = struct_utils.padded_to_list(x, pad_sizes.tolist())
 | 
			
		||||
        return x_list
 | 
			
		||||
 | 
			
		||||
    def get_grid_sizes(self) -> torch.LongTensor:
 | 
			
		||||
        """
 | 
			
		||||
        Returns the sizes of individual volumetric grids in the structure.
 | 
			
		||||
@ -603,59 +925,6 @@ class Volumes:
 | 
			
		||||
        """
 | 
			
		||||
        return self._grid_sizes
 | 
			
		||||
 | 
			
		||||
    def update_padded(
 | 
			
		||||
        self, new_densities: torch.Tensor, new_features: Optional[torch.Tensor] = None
 | 
			
		||||
    ) -> "Volumes":
 | 
			
		||||
        """
 | 
			
		||||
        Returns a Volumes structure with updated padded tensors and copies of
 | 
			
		||||
        the auxiliary tensors `self._local_to_world_transform`,
 | 
			
		||||
        `device` and `self._grid_sizes`. This function allows for an update of
 | 
			
		||||
        densities (and features) without having to explicitly
 | 
			
		||||
        convert it to the list representation for heterogeneous batches.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            new_densities: FloatTensor of shape (N, dim_density, D, H, W)
 | 
			
		||||
            new_features: (optional) FloatTensor of shape (N, dim_feature, D, H, W)
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Volumes with updated features and densities
 | 
			
		||||
        """
 | 
			
		||||
        new = copy.copy(self)
 | 
			
		||||
        new._set_densities(new_densities)
 | 
			
		||||
        if new_features is None:
 | 
			
		||||
            new._features = None
 | 
			
		||||
        else:
 | 
			
		||||
            new._set_features(new_features)
 | 
			
		||||
        return new
 | 
			
		||||
 | 
			
		||||
    def _set_features(self, features: _TensorBatch) -> None:
 | 
			
		||||
        self._set_densities_features("features", features)
 | 
			
		||||
 | 
			
		||||
    def _set_densities(self, densities: _TensorBatch) -> None:
 | 
			
		||||
        self._set_densities_features("densities", densities)
 | 
			
		||||
 | 
			
		||||
    def _set_densities_features(self, var_name: str, x: _TensorBatch) -> None:
 | 
			
		||||
        x_tensor, grid_sizes = self._convert_densities_features_to_tensor(x, var_name)
 | 
			
		||||
        if x_tensor.device != self.device:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"`{var_name}` have to be on the same device as `self.densities`."
 | 
			
		||||
            )
 | 
			
		||||
        if len(x_tensor.shape) != 5:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"{var_name} has to be a 5-dim tensor of shape: "
 | 
			
		||||
                f"(minibatch, {var_name}_dim, height, width, depth)"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if not (
 | 
			
		||||
            (self.get_grid_sizes().shape == grid_sizes.shape)
 | 
			
		||||
            and torch.allclose(self.get_grid_sizes(), grid_sizes)
 | 
			
		||||
        ):
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"The size of every grid in `{var_name}` has to match the size of"
 | 
			
		||||
                "the corresponding `densities` grid."
 | 
			
		||||
            )
 | 
			
		||||
        setattr(self, "_" + var_name, x_tensor)
 | 
			
		||||
 | 
			
		||||
    def _set_local_to_world_transform(
 | 
			
		||||
        self,
 | 
			
		||||
        voxel_size: _VoxelSize = 1.0,
 | 
			
		||||
@ -690,17 +959,104 @@ class Volumes:
 | 
			
		||||
            voxel_size, volume_translation, len(self)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def clone(self) -> "Volumes":
 | 
			
		||||
    def _copy_transform_and_sizes(
 | 
			
		||||
        self,
 | 
			
		||||
        other: "VolumeLocator",
 | 
			
		||||
        device: Optional[torch.device] = None,
 | 
			
		||||
        index: Optional[
 | 
			
		||||
            Union[int, List[int], Tuple[int], slice, torch.Tensor]
 | 
			
		||||
        ] = _ALL_CONTENT,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Deep copy of Volumes object. All internal tensors are cloned
 | 
			
		||||
        individually.
 | 
			
		||||
        Copies the local to world transform and grid sizes to other VolumeLocator object
 | 
			
		||||
        and moves it to specified device. Operates in place on other.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            new Volumes object.
 | 
			
		||||
        Args:
 | 
			
		||||
            other: VolumeLocator object to which to copy
 | 
			
		||||
            device: torch.device on which to put the result, defatults to self.device
 | 
			
		||||
            index: Specifies which parts to copy.
 | 
			
		||||
                Can be an int, slice, list of ints or a boolean or a long tensor.
 | 
			
		||||
                Defaults to all items (`:`).
 | 
			
		||||
        """
 | 
			
		||||
        return copy.deepcopy(self)
 | 
			
		||||
        device = device if device is not None else self.device
 | 
			
		||||
        other._grid_sizes = self._grid_sizes[index].to(device)
 | 
			
		||||
        other._local_to_world_transform = self.get_local_to_world_coords_transform()[
 | 
			
		||||
            index
 | 
			
		||||
        ].to(device)
 | 
			
		||||
 | 
			
		||||
    def to(self, device: Device, copy: bool = False) -> "Volumes":
 | 
			
		||||
    def _handle_voxel_size(
 | 
			
		||||
        self, voxel_size: _VoxelSize, batch_size: int
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Handle the `voxel_size` argument to the `VolumeLocator` constructor.
 | 
			
		||||
        """
 | 
			
		||||
        err_msg = (
 | 
			
		||||
            "voxel_size has to be either a 3-tuple of scalars, or a scalar, or"
 | 
			
		||||
            " a torch.Tensor of shape (3,) or (1,) or (minibatch, 3) or (minibatch, 1)."
 | 
			
		||||
        )
 | 
			
		||||
        if isinstance(voxel_size, (float, int)):
 | 
			
		||||
            # convert a scalar to a 3-element tensor
 | 
			
		||||
            voxel_size = torch.full(
 | 
			
		||||
                (1, 3), voxel_size, device=self.device, dtype=torch.float32
 | 
			
		||||
            )
 | 
			
		||||
        elif isinstance(voxel_size, torch.Tensor):
 | 
			
		||||
            if voxel_size.numel() == 1:
 | 
			
		||||
                # convert a single-element tensor to a 3-element one
 | 
			
		||||
                voxel_size = voxel_size.view(-1).repeat(3)
 | 
			
		||||
            elif len(voxel_size.shape) == 2 and (
 | 
			
		||||
                voxel_size.shape[0] == batch_size and voxel_size.shape[1] == 1
 | 
			
		||||
            ):
 | 
			
		||||
                voxel_size = voxel_size.repeat(1, 3)
 | 
			
		||||
        return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg)
 | 
			
		||||
 | 
			
		||||
    def _handle_volume_translation(
 | 
			
		||||
        self, translation: _Translation, batch_size: int
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Handle the `volume_translation` argument to the `VolumeLocator` constructor.
 | 
			
		||||
        """
 | 
			
		||||
        err_msg = (
 | 
			
		||||
            "`volume_translation` has to be either a 3-tuple of scalars, or"
 | 
			
		||||
            " a Tensor of shape (1,3) or (minibatch, 3) or (3,)`."
 | 
			
		||||
        )
 | 
			
		||||
        return self._convert_volume_property_to_tensor(translation, batch_size, err_msg)
 | 
			
		||||
 | 
			
		||||
    def __len__(self) -> int:
 | 
			
		||||
        return self._batch_size
 | 
			
		||||
 | 
			
		||||
    def _convert_volume_property_to_tensor(
 | 
			
		||||
        self, x: _Vector, batch_size: int, err_msg: str
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Handle the `volume_translation` or `voxel_size` argument to
 | 
			
		||||
        the VolumeLocator constructor.
 | 
			
		||||
        Return a tensor of shape (N, 3) where N is the batch_size.
 | 
			
		||||
        """
 | 
			
		||||
        if isinstance(x, (list, tuple)):
 | 
			
		||||
            if len(x) != 3:
 | 
			
		||||
                raise ValueError(err_msg)
 | 
			
		||||
            x = torch.tensor(x, device=self.device, dtype=torch.float32)[None]
 | 
			
		||||
            x = x.repeat((batch_size, 1))
 | 
			
		||||
        elif isinstance(x, torch.Tensor):
 | 
			
		||||
            ok = (
 | 
			
		||||
                (x.shape[0] == 1 and x.shape[1] == 3)
 | 
			
		||||
                or (x.shape[0] == 3 and len(x.shape) == 1)
 | 
			
		||||
                or (x.shape[0] == batch_size and x.shape[1] == 3)
 | 
			
		||||
            )
 | 
			
		||||
            if not ok:
 | 
			
		||||
                raise ValueError(err_msg)
 | 
			
		||||
            if x.device != self.device:
 | 
			
		||||
                x = x.to(self.device)
 | 
			
		||||
            if x.shape[0] == 3 and len(x.shape) == 1:
 | 
			
		||||
                x = x[None]
 | 
			
		||||
            if x.shape[0] == 1:
 | 
			
		||||
                x = x.repeat((batch_size, 1))
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(err_msg)
 | 
			
		||||
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
    def to(self, device: Device, copy: bool = False) -> "VolumeLocator":
 | 
			
		||||
        """
 | 
			
		||||
        Match the functionality of torch.Tensor.to()
 | 
			
		||||
        If copy = True or the self Tensor is on a different device, the
 | 
			
		||||
@ -713,7 +1069,7 @@ class Volumes:
 | 
			
		||||
            copy: Boolean indicator whether or not to clone self. Default False.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Volumes object.
 | 
			
		||||
            VolumeLocator object.
 | 
			
		||||
        """
 | 
			
		||||
        device_ = make_device(device)
 | 
			
		||||
        if not copy and self.device == device_:
 | 
			
		||||
@ -724,18 +1080,24 @@ class Volumes:
 | 
			
		||||
            return other
 | 
			
		||||
 | 
			
		||||
        other.device = device_
 | 
			
		||||
        other._densities = self._densities.to(device_)
 | 
			
		||||
        if self._features is not None:
 | 
			
		||||
            # pyre-fixme[16]: `Optional` has no attribute `to`.
 | 
			
		||||
            other._features = self.features().to(device_)
 | 
			
		||||
        other._local_to_world_transform = self.get_local_to_world_coords_transform().to(
 | 
			
		||||
            device_
 | 
			
		||||
        )
 | 
			
		||||
        other._grid_sizes = self._grid_sizes.to(device_)
 | 
			
		||||
        other._local_to_world_transform = self.get_local_to_world_coords_transform().to(
 | 
			
		||||
            device
 | 
			
		||||
        )
 | 
			
		||||
        return other
 | 
			
		||||
 | 
			
		||||
    def cpu(self) -> "Volumes":
 | 
			
		||||
    def clone(self) -> "VolumeLocator":
 | 
			
		||||
        """
 | 
			
		||||
        Deep copy of VoluVolumeLocatormes object. All internal tensors are cloned
 | 
			
		||||
        individually.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            new VolumeLocator object.
 | 
			
		||||
        """
 | 
			
		||||
        return copy.deepcopy(self)
 | 
			
		||||
 | 
			
		||||
    def cpu(self) -> "VolumeLocator":
 | 
			
		||||
        return self.to("cpu")
 | 
			
		||||
 | 
			
		||||
    def cuda(self) -> "Volumes":
 | 
			
		||||
    def cuda(self) -> "VolumeLocator":
 | 
			
		||||
        return self.to("cuda")
 | 
			
		||||
 | 
			
		||||
@ -11,7 +11,7 @@ import unittest
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.structures.volumes import Volumes
 | 
			
		||||
from pytorch3d.structures.volumes import VolumeLocator, Volumes
 | 
			
		||||
from pytorch3d.transforms import Scale
 | 
			
		||||
 | 
			
		||||
from .common_testing import TestCaseMixin
 | 
			
		||||
@ -53,8 +53,8 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        for selectedIdx, index in indices:
 | 
			
		||||
            self.assertClose(selected.densities()[selectedIdx], v.densities()[index])
 | 
			
		||||
            self.assertClose(
 | 
			
		||||
                v._local_to_world_transform.get_matrix()[index],
 | 
			
		||||
                selected._local_to_world_transform.get_matrix()[selectedIdx],
 | 
			
		||||
                v.locator._local_to_world_transform.get_matrix()[index],
 | 
			
		||||
                selected.locator._local_to_world_transform.get_matrix()[selectedIdx],
 | 
			
		||||
            )
 | 
			
		||||
            if selected.features() is not None:
 | 
			
		||||
                self.assertClose(selected.features()[selectedIdx], v.features()[index])
 | 
			
		||||
@ -149,10 +149,55 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            with self.assertRaises(IndexError):
 | 
			
		||||
                v_selected = v[index]
 | 
			
		||||
 | 
			
		||||
    def test_locator_init(self, batch_size=9, resolution=(3, 5, 7)):
 | 
			
		||||
        with self.subTest("VolumeLocator init with all sizes equal"):
 | 
			
		||||
            grid_sizes = [resolution for _ in range(batch_size)]
 | 
			
		||||
            locator_tuple = VolumeLocator(
 | 
			
		||||
                batch_size=batch_size, grid_sizes=resolution, device=torch.device("cpu")
 | 
			
		||||
            )
 | 
			
		||||
            locator_list = VolumeLocator(
 | 
			
		||||
                batch_size=batch_size, grid_sizes=grid_sizes, device=torch.device("cpu")
 | 
			
		||||
            )
 | 
			
		||||
            locator_tensor = VolumeLocator(
 | 
			
		||||
                batch_size=batch_size,
 | 
			
		||||
                grid_sizes=torch.tensor(grid_sizes),
 | 
			
		||||
                device=torch.device("cpu"),
 | 
			
		||||
            )
 | 
			
		||||
            expected_grid_sizes = torch.tensor(grid_sizes)
 | 
			
		||||
            expected_resolution = resolution
 | 
			
		||||
            assert torch.allclose(expected_grid_sizes, locator_tuple._grid_sizes)
 | 
			
		||||
            assert torch.allclose(expected_grid_sizes, locator_list._grid_sizes)
 | 
			
		||||
            assert torch.allclose(expected_grid_sizes, locator_tensor._grid_sizes)
 | 
			
		||||
            self.assertEqual(expected_resolution, locator_tuple._resolution)
 | 
			
		||||
            self.assertEqual(expected_resolution, locator_list._resolution)
 | 
			
		||||
            self.assertEqual(expected_resolution, locator_tensor._resolution)
 | 
			
		||||
 | 
			
		||||
        with self.subTest("VolumeLocator with different sizes in different grids"):
 | 
			
		||||
            grid_sizes_list = [
 | 
			
		||||
                torch.randint(low=1, high=42, size=(3,)) for _ in range(batch_size)
 | 
			
		||||
            ]
 | 
			
		||||
            grid_sizes_tensor = torch.cat([el[None] for el in grid_sizes_list])
 | 
			
		||||
            locator_list = VolumeLocator(
 | 
			
		||||
                batch_size=batch_size,
 | 
			
		||||
                grid_sizes=grid_sizes_list,
 | 
			
		||||
                device=torch.device("cpu"),
 | 
			
		||||
            )
 | 
			
		||||
            locator_tensor = VolumeLocator(
 | 
			
		||||
                batch_size=batch_size,
 | 
			
		||||
                grid_sizes=grid_sizes_tensor,
 | 
			
		||||
                device=torch.device("cpu"),
 | 
			
		||||
            )
 | 
			
		||||
            expected_grid_sizes = grid_sizes_tensor
 | 
			
		||||
            expected_resolution = tuple(torch.max(expected_grid_sizes, dim=0).values)
 | 
			
		||||
            assert torch.allclose(expected_grid_sizes, locator_list._grid_sizes)
 | 
			
		||||
            assert torch.allclose(expected_grid_sizes, locator_tensor._grid_sizes)
 | 
			
		||||
            self.assertEqual(expected_resolution, locator_list._resolution)
 | 
			
		||||
            self.assertEqual(expected_resolution, locator_tensor._resolution)
 | 
			
		||||
 | 
			
		||||
    def test_coord_transforms(self, num_volumes=3, num_channels=4, dtype=torch.float32):
 | 
			
		||||
        """
 | 
			
		||||
        Test the correctness of the conversion between the internal
 | 
			
		||||
        Transform3D Volumes._local_to_world_transform and the initialization
 | 
			
		||||
        Transform3D Volumes.VolumeLocator._local_to_world_transform and the initialization
 | 
			
		||||
        from the translation and voxel_size.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
@ -440,7 +485,10 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        for var_name, var in vars(v).items():
 | 
			
		||||
            if var_name != "device":
 | 
			
		||||
                if var is not None:
 | 
			
		||||
                    self.assertTrue(var.device.type == desired_device.type)
 | 
			
		||||
                    self.assertTrue(
 | 
			
		||||
                        var.device.type == desired_device.type,
 | 
			
		||||
                        (var_name, var.device, desired_device),
 | 
			
		||||
                    )
 | 
			
		||||
            else:
 | 
			
		||||
                self.assertTrue(var.type == desired_device.type)
 | 
			
		||||
 | 
			
		||||
@ -456,60 +504,74 @@ class TestVolumes(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        )
 | 
			
		||||
        densities = torch.rand(size=[num_volumes, 1, *size], dtype=dtype)
 | 
			
		||||
        volumes = Volumes(densities=densities, features=features)
 | 
			
		||||
        locator = VolumeLocator(
 | 
			
		||||
            batch_size=5, grid_sizes=(3, 5, 7), device=volumes.device
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Test support for str and torch.device
 | 
			
		||||
        cpu_device = torch.device("cpu")
 | 
			
		||||
        for name, obj in (("VolumeLocator", locator), ("Volumes", volumes)):
 | 
			
		||||
            with self.subTest(f"Moving {name} from/to gpu and cpu"):
 | 
			
		||||
                # Test support for str and torch.device
 | 
			
		||||
                cpu_device = torch.device("cpu")
 | 
			
		||||
 | 
			
		||||
        converted_volumes = volumes.to("cpu")
 | 
			
		||||
        self.assertEqual(cpu_device, converted_volumes.device)
 | 
			
		||||
        self.assertEqual(cpu_device, volumes.device)
 | 
			
		||||
        self.assertIs(volumes, converted_volumes)
 | 
			
		||||
                converted_obj = obj.to("cpu")
 | 
			
		||||
                self.assertEqual(cpu_device, converted_obj.device)
 | 
			
		||||
                self.assertEqual(cpu_device, obj.device)
 | 
			
		||||
                self.assertIs(obj, converted_obj)
 | 
			
		||||
 | 
			
		||||
        converted_volumes = volumes.to(cpu_device)
 | 
			
		||||
        self.assertEqual(cpu_device, converted_volumes.device)
 | 
			
		||||
        self.assertEqual(cpu_device, volumes.device)
 | 
			
		||||
        self.assertIs(volumes, converted_volumes)
 | 
			
		||||
                converted_obj = obj.to(cpu_device)
 | 
			
		||||
                self.assertEqual(cpu_device, converted_obj.device)
 | 
			
		||||
                self.assertEqual(cpu_device, obj.device)
 | 
			
		||||
                self.assertIs(obj, converted_obj)
 | 
			
		||||
 | 
			
		||||
        cuda_device = torch.device("cuda:0")
 | 
			
		||||
                cuda_device = torch.device("cuda:0")
 | 
			
		||||
 | 
			
		||||
        converted_volumes = volumes.to("cuda:0")
 | 
			
		||||
        self.assertEqual(cuda_device, converted_volumes.device)
 | 
			
		||||
        self.assertEqual(cpu_device, volumes.device)
 | 
			
		||||
        self.assertIsNot(volumes, converted_volumes)
 | 
			
		||||
                converted_obj = obj.to("cuda:0")
 | 
			
		||||
                self.assertEqual(cuda_device, converted_obj.device)
 | 
			
		||||
                self.assertEqual(cpu_device, obj.device)
 | 
			
		||||
                self.assertIsNot(obj, converted_obj)
 | 
			
		||||
 | 
			
		||||
        converted_volumes = volumes.to(cuda_device)
 | 
			
		||||
        self.assertEqual(cuda_device, converted_volumes.device)
 | 
			
		||||
        self.assertEqual(cpu_device, volumes.device)
 | 
			
		||||
        self.assertIsNot(volumes, converted_volumes)
 | 
			
		||||
                converted_obj = obj.to(cuda_device)
 | 
			
		||||
                self.assertEqual(cuda_device, converted_obj.device)
 | 
			
		||||
                self.assertEqual(cpu_device, obj.device)
 | 
			
		||||
                self.assertIsNot(obj, converted_obj)
 | 
			
		||||
 | 
			
		||||
        # Test device placement of internal tensors
 | 
			
		||||
        features = features.to(cuda_device)
 | 
			
		||||
        densities = features.to(cuda_device)
 | 
			
		||||
        with self.subTest("Test device placement of internal tensors of Volumes"):
 | 
			
		||||
            features = features.to(cuda_device)
 | 
			
		||||
            densities = features.to(cuda_device)
 | 
			
		||||
 | 
			
		||||
        for features_ in (features, None):
 | 
			
		||||
            volumes = Volumes(densities=densities, features=features_)
 | 
			
		||||
            for features_ in (features, None):
 | 
			
		||||
                volumes = Volumes(densities=densities, features=features_)
 | 
			
		||||
 | 
			
		||||
            cpu_volumes = volumes.cpu()
 | 
			
		||||
            cuda_volumes = cpu_volumes.cuda()
 | 
			
		||||
            cuda_volumes2 = cuda_volumes.cuda()
 | 
			
		||||
            cpu_volumes2 = cuda_volumes2.cpu()
 | 
			
		||||
                cpu_volumes = volumes.cpu()
 | 
			
		||||
                cuda_volumes = cpu_volumes.cuda()
 | 
			
		||||
                cuda_volumes2 = cuda_volumes.cuda()
 | 
			
		||||
                cpu_volumes2 = cuda_volumes2.cpu()
 | 
			
		||||
 | 
			
		||||
            for volumes1, volumes2 in itertools.combinations(
 | 
			
		||||
                (volumes, cpu_volumes, cpu_volumes2, cuda_volumes, cuda_volumes2), 2
 | 
			
		||||
            ):
 | 
			
		||||
                if volumes1 is cuda_volumes and volumes2 is cuda_volumes2:
 | 
			
		||||
                    # checks that we do not copy if the devices stay the same
 | 
			
		||||
                    assert_fun = self.assertIs
 | 
			
		||||
                else:
 | 
			
		||||
                    assert_fun = self.assertSeparate
 | 
			
		||||
                assert_fun(volumes1._densities, volumes2._densities)
 | 
			
		||||
                if features_ is not None:
 | 
			
		||||
                    assert_fun(volumes1._features, volumes2._features)
 | 
			
		||||
                for volumes_ in (volumes1, volumes2):
 | 
			
		||||
                    if volumes_ in (cpu_volumes, cpu_volumes2):
 | 
			
		||||
                        self._check_vars_on_device(volumes_, cpu_device)
 | 
			
		||||
                for volumes1, volumes2 in itertools.combinations(
 | 
			
		||||
                    (volumes, cpu_volumes, cpu_volumes2, cuda_volumes, cuda_volumes2), 2
 | 
			
		||||
                ):
 | 
			
		||||
                    if volumes1 is cuda_volumes and volumes2 is cuda_volumes2:
 | 
			
		||||
                        # checks that we do not copy if the devices stay the same
 | 
			
		||||
                        assert_fun = self.assertIs
 | 
			
		||||
                    else:
 | 
			
		||||
                        self._check_vars_on_device(volumes_, cuda_device)
 | 
			
		||||
                        assert_fun = self.assertSeparate
 | 
			
		||||
                    assert_fun(volumes1._densities, volumes2._densities)
 | 
			
		||||
                    if features_ is not None:
 | 
			
		||||
                        assert_fun(volumes1._features, volumes2._features)
 | 
			
		||||
                    for volumes_ in (volumes1, volumes2):
 | 
			
		||||
                        if volumes_ in (cpu_volumes, cpu_volumes2):
 | 
			
		||||
                            self._check_vars_on_device(volumes_, cpu_device)
 | 
			
		||||
                        else:
 | 
			
		||||
                            self._check_vars_on_device(volumes_, cuda_device)
 | 
			
		||||
 | 
			
		||||
        with self.subTest("Test device placement of internal tensors of VolumeLocator"):
 | 
			
		||||
            for device1, device2 in itertools.combinations(
 | 
			
		||||
                (torch.device("cpu"), torch.device("cuda:0")), 2
 | 
			
		||||
            ):
 | 
			
		||||
                locator = locator.to(device1)
 | 
			
		||||
                locator = locator.to(device2)
 | 
			
		||||
                self.assertEqual(locator._grid_sizes.device, device2)
 | 
			
		||||
                self.assertEqual(locator._local_to_world_transform.device, device2)
 | 
			
		||||
 | 
			
		||||
    def _check_padded(self, x_pad, x_list, grid_sizes):
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user