Improve volumes type annotations

Summary: Improve type annotations for volumes and remove a few pyre fixmes

Reviewed By: nikhilaravi

Differential Revision: D28943371

fbshipit-source-id: ca2b7a50d72a392910e65cee5e564f34523414d2
This commit is contained in:
Patrick Labatut 2021-06-09 15:48:56 -07:00 committed by Facebook GitHub Bot
parent a15c33a3cc
commit 02650672f6
2 changed files with 53 additions and 32 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import List, Sequence, Union
from typing import List, Sequence, Tuple, Union
import torch
@ -11,7 +11,7 @@ Util functions for points/verts/faces/volumes.
def list_to_padded(
x: List[torch.Tensor],
x: Union[List[torch.Tensor], Tuple[torch.Tensor]],
pad_size: Union[Sequence[int], None] = None,
pad_value: float = 0.0,
equisized: bool = False,
@ -66,7 +66,6 @@ def list_to_padded(
pad_dims = pad_size
N = len(x)
# pyre-fixme[16]: `Tensor` has no attribute `new_full`.
x_padded = x[0].new_full((N, *pad_dims), pad_value)
for i, y in enumerate(x):
if len(y) > 0:

View File

@ -1,4 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import copy
from typing import List, Optional, Tuple, Union
@ -8,6 +9,16 @@ from ..transforms import Scale, Transform3d
from . import utils as struct_utils
_Scalar = Union[int, float]
_Vector = Union[torch.Tensor, Tuple[_Scalar, ...], List[_Scalar]]
_ScalarOrVector = Union[_Scalar, _Vector]
_VoxelSize = _ScalarOrVector
_Translation = _Vector
_TensorBatch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
class Volumes:
"""
This class provides functions for working with batches of volumetric grids
@ -135,10 +146,10 @@ class Volumes:
def __init__(
self,
densities: Union[List[torch.Tensor], torch.Tensor],
features: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
voxel_size: Union[float, torch.Tensor, Tuple, List] = 1.0,
volume_translation: Union[torch.Tensor, Tuple, List] = (0.0, 0.0, 0.0),
densities: _TensorBatch,
features: Optional[_TensorBatch] = None,
voxel_size: _VoxelSize = 1.0,
volume_translation: _Translation = (0.0, 0.0, 0.0),
):
"""
Args:
@ -193,8 +204,8 @@ class Volumes:
)
def _convert_densities_features_to_tensor(
self, x: Union[List[torch.Tensor], torch.Tensor], var_name: str
):
self, x: _TensorBatch, var_name: str
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Handle the `densities` or `features` arguments to the constructor.
"""
@ -236,9 +247,12 @@ class Volumes:
)
return x_tensor, x_shapes
def _vsize_translation_to_transform(
self, voxel_size, volume_translation, batch_size
):
def _voxel_size_translation_to_transform(
self,
voxel_size: torch.Tensor,
volume_translation: torch.Tensor,
batch_size: int,
) -> Transform3d:
"""
Converts the `voxel_size` and `volume_translation` constructor arguments
to the internal `Transform3D` object `local_to_world_transform`.
@ -260,7 +274,9 @@ class Volumes:
return local_to_world_transform
def _handle_voxel_size(self, voxel_size, batch_size):
def _handle_voxel_size(
self, voxel_size: _VoxelSize, batch_size: int
) -> torch.Tensor:
"""
Handle the `voxel_size` argument to the `Volumes` constructor.
"""
@ -270,8 +286,10 @@ class Volumes:
)
if isinstance(voxel_size, (float, int)):
# convert a scalar to a 3-element tensor
voxel_size = (voxel_size,) * 3
if torch.is_tensor(voxel_size):
voxel_size = torch.full(
(1, 3), voxel_size, device=self.device, dtype=torch.float32
)
elif isinstance(voxel_size, torch.Tensor):
if voxel_size.numel() == 1:
# convert a single-element tensor to a 3-element one
voxel_size = voxel_size.view(-1).repeat(3)
@ -281,7 +299,9 @@ class Volumes:
voxel_size = voxel_size.repeat(1, 3)
return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg)
def _handle_volume_translation(self, translation, batch_size):
def _handle_volume_translation(
self, translation: _Translation, batch_size: int
) -> torch.Tensor:
"""
Handle the `volume_translation` argument to the `Volumes` constructor.
"""
@ -292,20 +312,19 @@ class Volumes:
return self._convert_volume_property_to_tensor(translation, batch_size, err_msg)
def _convert_volume_property_to_tensor(
self, x, batch_size, err_msg
self, x: _Vector, batch_size: int, err_msg: str
) -> torch.Tensor:
"""
Handle the `volume_translation` or `voxel_size` argument to
the Volumes constructor.
Return a tensor of shape (N, 3) where N is the batch_size or 1
if batch_size is None.
Return a tensor of shape (N, 3) where N is the batch_size.
"""
if isinstance(x, (list, tuple)):
if len(x) != 3:
raise ValueError(err_msg)
x = torch.tensor(x, device=self.device, dtype=torch.float32)[None]
x = x.repeat((batch_size, 1))
elif torch.is_tensor(x):
elif isinstance(x, torch.Tensor):
ok = (
(x.shape[0] == 1 and x.shape[1] == 3)
or (x.shape[0] == 3 and len(x.shape) == 1)
@ -544,13 +563,14 @@ class Volumes:
list of tensors of features of shape (dim_i, D_i, H_i, W_i)
or `None` for feature-less volumes.
"""
if self._features is None:
features_ = self.features()
if features_ is None:
# No features provided so return None
# pyre-fixme[7]: Expected `List[torch.Tensor]` but got `None`.
return None
return self._features_densities_list(self.features())
return self._features_densities_list(features_)
def _features_densities_list(self, x) -> List[torch.Tensor]:
def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]:
"""
Retrieve the list representation of features/densities.
@ -577,7 +597,9 @@ class Volumes:
"""
return self._grid_sizes
def update_padded(self, new_densities, new_features=None) -> "Volumes":
def update_padded(
self, new_densities: torch.Tensor, new_features: Optional[torch.Tensor] = None
) -> "Volumes":
"""
Returns a Volumes structure with updated padded tensors and copies of
the auxiliary tensors `self._local_to_world_transform`,
@ -600,13 +622,13 @@ class Volumes:
new._set_features(new_features)
return new
def _set_features(self, features):
self._set_densities_features(features, "features")
def _set_features(self, features: _TensorBatch) -> None:
self._set_densities_features("features", features)
def _set_densities(self, densities):
self._set_densities_features(densities, "densities")
def _set_densities(self, densities: _TensorBatch) -> None:
self._set_densities_features("densities", densities)
def _set_densities_features(self, x, var_name):
def _set_densities_features(self, var_name: str, x: _TensorBatch) -> None:
x_tensor, grid_sizes = self._convert_densities_features_to_tensor(x, var_name)
if x_tensor.device != self.device:
raise ValueError(
@ -630,8 +652,8 @@ class Volumes:
def _set_local_to_world_transform(
self,
voxel_size: Union[float, torch.Tensor, Tuple, List] = 1.0,
volume_translation: Union[torch.Tensor, Tuple, List] = (0.0, 0.0, 0.0),
voxel_size: _VoxelSize = 1.0,
volume_translation: _Translation = (0.0, 0.0, 0.0),
):
"""
Sets the internal representation of the transformation between the
@ -658,7 +680,7 @@ class Volumes:
volume_translation = self._handle_volume_translation(
volume_translation, len(self)
)
self._local_to_world_transform = self._vsize_translation_to_transform(
self._local_to_world_transform = self._voxel_size_translation_to_transform(
voxel_size, volume_translation, len(self)
)