From 02650672f621f4827088164ee3e3ab1b13ab1159 Mon Sep 17 00:00:00 2001 From: Patrick Labatut Date: Wed, 9 Jun 2021 15:48:56 -0700 Subject: [PATCH] Improve volumes type annotations Summary: Improve type annotations for volumes and remove a few pyre fixmes Reviewed By: nikhilaravi Differential Revision: D28943371 fbshipit-source-id: ca2b7a50d72a392910e65cee5e564f34523414d2 --- pytorch3d/structures/utils.py | 5 +-- pytorch3d/structures/volumes.py | 80 +++++++++++++++++++++------------ 2 files changed, 53 insertions(+), 32 deletions(-) diff --git a/pytorch3d/structures/utils.py b/pytorch3d/structures/utils.py index 77c2c3f7..d8b888a3 100644 --- a/pytorch3d/structures/utils.py +++ b/pytorch3d/structures/utils.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from typing import List, Sequence, Union +from typing import List, Sequence, Tuple, Union import torch @@ -11,7 +11,7 @@ Util functions for points/verts/faces/volumes. def list_to_padded( - x: List[torch.Tensor], + x: Union[List[torch.Tensor], Tuple[torch.Tensor]], pad_size: Union[Sequence[int], None] = None, pad_value: float = 0.0, equisized: bool = False, @@ -66,7 +66,6 @@ def list_to_padded( pad_dims = pad_size N = len(x) - # pyre-fixme[16]: `Tensor` has no attribute `new_full`. x_padded = x[0].new_full((N, *pad_dims), pad_value) for i, y in enumerate(x): if len(y) > 0: diff --git a/pytorch3d/structures/volumes.py b/pytorch3d/structures/volumes.py index 010cb14b..cec18d58 100644 --- a/pytorch3d/structures/volumes.py +++ b/pytorch3d/structures/volumes.py @@ -1,4 +1,5 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + import copy from typing import List, Optional, Tuple, Union @@ -8,6 +9,16 @@ from ..transforms import Scale, Transform3d from . import utils as struct_utils +_Scalar = Union[int, float] +_Vector = Union[torch.Tensor, Tuple[_Scalar, ...], List[_Scalar]] +_ScalarOrVector = Union[_Scalar, _Vector] + +_VoxelSize = _ScalarOrVector +_Translation = _Vector + +_TensorBatch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] + + class Volumes: """ This class provides functions for working with batches of volumetric grids @@ -135,10 +146,10 @@ class Volumes: def __init__( self, - densities: Union[List[torch.Tensor], torch.Tensor], - features: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, - voxel_size: Union[float, torch.Tensor, Tuple, List] = 1.0, - volume_translation: Union[torch.Tensor, Tuple, List] = (0.0, 0.0, 0.0), + densities: _TensorBatch, + features: Optional[_TensorBatch] = None, + voxel_size: _VoxelSize = 1.0, + volume_translation: _Translation = (0.0, 0.0, 0.0), ): """ Args: @@ -193,8 +204,8 @@ class Volumes: ) def _convert_densities_features_to_tensor( - self, x: Union[List[torch.Tensor], torch.Tensor], var_name: str - ): + self, x: _TensorBatch, var_name: str + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Handle the `densities` or `features` arguments to the constructor. """ @@ -236,9 +247,12 @@ class Volumes: ) return x_tensor, x_shapes - def _vsize_translation_to_transform( - self, voxel_size, volume_translation, batch_size - ): + def _voxel_size_translation_to_transform( + self, + voxel_size: torch.Tensor, + volume_translation: torch.Tensor, + batch_size: int, + ) -> Transform3d: """ Converts the `voxel_size` and `volume_translation` constructor arguments to the internal `Transform3D` object `local_to_world_transform`. @@ -260,7 +274,9 @@ class Volumes: return local_to_world_transform - def _handle_voxel_size(self, voxel_size, batch_size): + def _handle_voxel_size( + self, voxel_size: _VoxelSize, batch_size: int + ) -> torch.Tensor: """ Handle the `voxel_size` argument to the `Volumes` constructor. """ @@ -270,8 +286,10 @@ class Volumes: ) if isinstance(voxel_size, (float, int)): # convert a scalar to a 3-element tensor - voxel_size = (voxel_size,) * 3 - if torch.is_tensor(voxel_size): + voxel_size = torch.full( + (1, 3), voxel_size, device=self.device, dtype=torch.float32 + ) + elif isinstance(voxel_size, torch.Tensor): if voxel_size.numel() == 1: # convert a single-element tensor to a 3-element one voxel_size = voxel_size.view(-1).repeat(3) @@ -281,7 +299,9 @@ class Volumes: voxel_size = voxel_size.repeat(1, 3) return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg) - def _handle_volume_translation(self, translation, batch_size): + def _handle_volume_translation( + self, translation: _Translation, batch_size: int + ) -> torch.Tensor: """ Handle the `volume_translation` argument to the `Volumes` constructor. """ @@ -292,20 +312,19 @@ class Volumes: return self._convert_volume_property_to_tensor(translation, batch_size, err_msg) def _convert_volume_property_to_tensor( - self, x, batch_size, err_msg + self, x: _Vector, batch_size: int, err_msg: str ) -> torch.Tensor: """ Handle the `volume_translation` or `voxel_size` argument to the Volumes constructor. - Return a tensor of shape (N, 3) where N is the batch_size or 1 - if batch_size is None. + Return a tensor of shape (N, 3) where N is the batch_size. """ if isinstance(x, (list, tuple)): if len(x) != 3: raise ValueError(err_msg) x = torch.tensor(x, device=self.device, dtype=torch.float32)[None] x = x.repeat((batch_size, 1)) - elif torch.is_tensor(x): + elif isinstance(x, torch.Tensor): ok = ( (x.shape[0] == 1 and x.shape[1] == 3) or (x.shape[0] == 3 and len(x.shape) == 1) @@ -544,13 +563,14 @@ class Volumes: list of tensors of features of shape (dim_i, D_i, H_i, W_i) or `None` for feature-less volumes. """ - if self._features is None: + features_ = self.features() + if features_ is None: # No features provided so return None # pyre-fixme[7]: Expected `List[torch.Tensor]` but got `None`. return None - return self._features_densities_list(self.features()) + return self._features_densities_list(features_) - def _features_densities_list(self, x) -> List[torch.Tensor]: + def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]: """ Retrieve the list representation of features/densities. @@ -577,7 +597,9 @@ class Volumes: """ return self._grid_sizes - def update_padded(self, new_densities, new_features=None) -> "Volumes": + def update_padded( + self, new_densities: torch.Tensor, new_features: Optional[torch.Tensor] = None + ) -> "Volumes": """ Returns a Volumes structure with updated padded tensors and copies of the auxiliary tensors `self._local_to_world_transform`, @@ -600,13 +622,13 @@ class Volumes: new._set_features(new_features) return new - def _set_features(self, features): - self._set_densities_features(features, "features") + def _set_features(self, features: _TensorBatch) -> None: + self._set_densities_features("features", features) - def _set_densities(self, densities): - self._set_densities_features(densities, "densities") + def _set_densities(self, densities: _TensorBatch) -> None: + self._set_densities_features("densities", densities) - def _set_densities_features(self, x, var_name): + def _set_densities_features(self, var_name: str, x: _TensorBatch) -> None: x_tensor, grid_sizes = self._convert_densities_features_to_tensor(x, var_name) if x_tensor.device != self.device: raise ValueError( @@ -630,8 +652,8 @@ class Volumes: def _set_local_to_world_transform( self, - voxel_size: Union[float, torch.Tensor, Tuple, List] = 1.0, - volume_translation: Union[torch.Tensor, Tuple, List] = (0.0, 0.0, 0.0), + voxel_size: _VoxelSize = 1.0, + volume_translation: _Translation = (0.0, 0.0, 0.0), ): """ Sets the internal representation of the transformation between the @@ -658,7 +680,7 @@ class Volumes: volume_translation = self._handle_volume_translation( volume_translation, len(self) ) - self._local_to_world_transform = self._vsize_translation_to_transform( + self._local_to_world_transform = self._voxel_size_translation_to_transform( voxel_size, volume_translation, len(self) )