Improve volumes type annotations

Summary: Improve type annotations for volumes and remove a few pyre fixmes

Reviewed By: nikhilaravi

Differential Revision: D28943371

fbshipit-source-id: ca2b7a50d72a392910e65cee5e564f34523414d2
This commit is contained in:
Patrick Labatut 2021-06-09 15:48:56 -07:00 committed by Facebook GitHub Bot
parent a15c33a3cc
commit 02650672f6
2 changed files with 53 additions and 32 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import List, Sequence, Union from typing import List, Sequence, Tuple, Union
import torch import torch
@ -11,7 +11,7 @@ Util functions for points/verts/faces/volumes.
def list_to_padded( def list_to_padded(
x: List[torch.Tensor], x: Union[List[torch.Tensor], Tuple[torch.Tensor]],
pad_size: Union[Sequence[int], None] = None, pad_size: Union[Sequence[int], None] = None,
pad_value: float = 0.0, pad_value: float = 0.0,
equisized: bool = False, equisized: bool = False,
@ -66,7 +66,6 @@ def list_to_padded(
pad_dims = pad_size pad_dims = pad_size
N = len(x) N = len(x)
# pyre-fixme[16]: `Tensor` has no attribute `new_full`.
x_padded = x[0].new_full((N, *pad_dims), pad_value) x_padded = x[0].new_full((N, *pad_dims), pad_value)
for i, y in enumerate(x): for i, y in enumerate(x):
if len(y) > 0: if len(y) > 0:

View File

@ -1,4 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import copy import copy
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
@ -8,6 +9,16 @@ from ..transforms import Scale, Transform3d
from . import utils as struct_utils from . import utils as struct_utils
_Scalar = Union[int, float]
_Vector = Union[torch.Tensor, Tuple[_Scalar, ...], List[_Scalar]]
_ScalarOrVector = Union[_Scalar, _Vector]
_VoxelSize = _ScalarOrVector
_Translation = _Vector
_TensorBatch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
class Volumes: class Volumes:
""" """
This class provides functions for working with batches of volumetric grids This class provides functions for working with batches of volumetric grids
@ -135,10 +146,10 @@ class Volumes:
def __init__( def __init__(
self, self,
densities: Union[List[torch.Tensor], torch.Tensor], densities: _TensorBatch,
features: Optional[Union[List[torch.Tensor], torch.Tensor]] = None, features: Optional[_TensorBatch] = None,
voxel_size: Union[float, torch.Tensor, Tuple, List] = 1.0, voxel_size: _VoxelSize = 1.0,
volume_translation: Union[torch.Tensor, Tuple, List] = (0.0, 0.0, 0.0), volume_translation: _Translation = (0.0, 0.0, 0.0),
): ):
""" """
Args: Args:
@ -193,8 +204,8 @@ class Volumes:
) )
def _convert_densities_features_to_tensor( def _convert_densities_features_to_tensor(
self, x: Union[List[torch.Tensor], torch.Tensor], var_name: str self, x: _TensorBatch, var_name: str
): ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Handle the `densities` or `features` arguments to the constructor. Handle the `densities` or `features` arguments to the constructor.
""" """
@ -236,9 +247,12 @@ class Volumes:
) )
return x_tensor, x_shapes return x_tensor, x_shapes
def _vsize_translation_to_transform( def _voxel_size_translation_to_transform(
self, voxel_size, volume_translation, batch_size self,
): voxel_size: torch.Tensor,
volume_translation: torch.Tensor,
batch_size: int,
) -> Transform3d:
""" """
Converts the `voxel_size` and `volume_translation` constructor arguments Converts the `voxel_size` and `volume_translation` constructor arguments
to the internal `Transform3D` object `local_to_world_transform`. to the internal `Transform3D` object `local_to_world_transform`.
@ -260,7 +274,9 @@ class Volumes:
return local_to_world_transform return local_to_world_transform
def _handle_voxel_size(self, voxel_size, batch_size): def _handle_voxel_size(
self, voxel_size: _VoxelSize, batch_size: int
) -> torch.Tensor:
""" """
Handle the `voxel_size` argument to the `Volumes` constructor. Handle the `voxel_size` argument to the `Volumes` constructor.
""" """
@ -270,8 +286,10 @@ class Volumes:
) )
if isinstance(voxel_size, (float, int)): if isinstance(voxel_size, (float, int)):
# convert a scalar to a 3-element tensor # convert a scalar to a 3-element tensor
voxel_size = (voxel_size,) * 3 voxel_size = torch.full(
if torch.is_tensor(voxel_size): (1, 3), voxel_size, device=self.device, dtype=torch.float32
)
elif isinstance(voxel_size, torch.Tensor):
if voxel_size.numel() == 1: if voxel_size.numel() == 1:
# convert a single-element tensor to a 3-element one # convert a single-element tensor to a 3-element one
voxel_size = voxel_size.view(-1).repeat(3) voxel_size = voxel_size.view(-1).repeat(3)
@ -281,7 +299,9 @@ class Volumes:
voxel_size = voxel_size.repeat(1, 3) voxel_size = voxel_size.repeat(1, 3)
return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg) return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg)
def _handle_volume_translation(self, translation, batch_size): def _handle_volume_translation(
self, translation: _Translation, batch_size: int
) -> torch.Tensor:
""" """
Handle the `volume_translation` argument to the `Volumes` constructor. Handle the `volume_translation` argument to the `Volumes` constructor.
""" """
@ -292,20 +312,19 @@ class Volumes:
return self._convert_volume_property_to_tensor(translation, batch_size, err_msg) return self._convert_volume_property_to_tensor(translation, batch_size, err_msg)
def _convert_volume_property_to_tensor( def _convert_volume_property_to_tensor(
self, x, batch_size, err_msg self, x: _Vector, batch_size: int, err_msg: str
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Handle the `volume_translation` or `voxel_size` argument to Handle the `volume_translation` or `voxel_size` argument to
the Volumes constructor. the Volumes constructor.
Return a tensor of shape (N, 3) where N is the batch_size or 1 Return a tensor of shape (N, 3) where N is the batch_size.
if batch_size is None.
""" """
if isinstance(x, (list, tuple)): if isinstance(x, (list, tuple)):
if len(x) != 3: if len(x) != 3:
raise ValueError(err_msg) raise ValueError(err_msg)
x = torch.tensor(x, device=self.device, dtype=torch.float32)[None] x = torch.tensor(x, device=self.device, dtype=torch.float32)[None]
x = x.repeat((batch_size, 1)) x = x.repeat((batch_size, 1))
elif torch.is_tensor(x): elif isinstance(x, torch.Tensor):
ok = ( ok = (
(x.shape[0] == 1 and x.shape[1] == 3) (x.shape[0] == 1 and x.shape[1] == 3)
or (x.shape[0] == 3 and len(x.shape) == 1) or (x.shape[0] == 3 and len(x.shape) == 1)
@ -544,13 +563,14 @@ class Volumes:
list of tensors of features of shape (dim_i, D_i, H_i, W_i) list of tensors of features of shape (dim_i, D_i, H_i, W_i)
or `None` for feature-less volumes. or `None` for feature-less volumes.
""" """
if self._features is None: features_ = self.features()
if features_ is None:
# No features provided so return None # No features provided so return None
# pyre-fixme[7]: Expected `List[torch.Tensor]` but got `None`. # pyre-fixme[7]: Expected `List[torch.Tensor]` but got `None`.
return None return None
return self._features_densities_list(self.features()) return self._features_densities_list(features_)
def _features_densities_list(self, x) -> List[torch.Tensor]: def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]:
""" """
Retrieve the list representation of features/densities. Retrieve the list representation of features/densities.
@ -577,7 +597,9 @@ class Volumes:
""" """
return self._grid_sizes return self._grid_sizes
def update_padded(self, new_densities, new_features=None) -> "Volumes": def update_padded(
self, new_densities: torch.Tensor, new_features: Optional[torch.Tensor] = None
) -> "Volumes":
""" """
Returns a Volumes structure with updated padded tensors and copies of Returns a Volumes structure with updated padded tensors and copies of
the auxiliary tensors `self._local_to_world_transform`, the auxiliary tensors `self._local_to_world_transform`,
@ -600,13 +622,13 @@ class Volumes:
new._set_features(new_features) new._set_features(new_features)
return new return new
def _set_features(self, features): def _set_features(self, features: _TensorBatch) -> None:
self._set_densities_features(features, "features") self._set_densities_features("features", features)
def _set_densities(self, densities): def _set_densities(self, densities: _TensorBatch) -> None:
self._set_densities_features(densities, "densities") self._set_densities_features("densities", densities)
def _set_densities_features(self, x, var_name): def _set_densities_features(self, var_name: str, x: _TensorBatch) -> None:
x_tensor, grid_sizes = self._convert_densities_features_to_tensor(x, var_name) x_tensor, grid_sizes = self._convert_densities_features_to_tensor(x, var_name)
if x_tensor.device != self.device: if x_tensor.device != self.device:
raise ValueError( raise ValueError(
@ -630,8 +652,8 @@ class Volumes:
def _set_local_to_world_transform( def _set_local_to_world_transform(
self, self,
voxel_size: Union[float, torch.Tensor, Tuple, List] = 1.0, voxel_size: _VoxelSize = 1.0,
volume_translation: Union[torch.Tensor, Tuple, List] = (0.0, 0.0, 0.0), volume_translation: _Translation = (0.0, 0.0, 0.0),
): ):
""" """
Sets the internal representation of the transformation between the Sets the internal representation of the transformation between the
@ -658,7 +680,7 @@ class Volumes:
volume_translation = self._handle_volume_translation( volume_translation = self._handle_volume_translation(
volume_translation, len(self) volume_translation, len(self)
) )
self._local_to_world_transform = self._vsize_translation_to_transform( self._local_to_world_transform = self._voxel_size_translation_to_transform(
voxel_size, volume_translation, len(self) voxel_size, volume_translation, len(self)
) )