mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Improve volumes type annotations
Summary: Improve type annotations for volumes and remove a few pyre fixmes Reviewed By: nikhilaravi Differential Revision: D28943371 fbshipit-source-id: ca2b7a50d72a392910e65cee5e564f34523414d2
This commit is contained in:
parent
a15c33a3cc
commit
02650672f6
@ -1,6 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import List, Sequence, Union
|
||||
from typing import List, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -11,7 +11,7 @@ Util functions for points/verts/faces/volumes.
|
||||
|
||||
|
||||
def list_to_padded(
|
||||
x: List[torch.Tensor],
|
||||
x: Union[List[torch.Tensor], Tuple[torch.Tensor]],
|
||||
pad_size: Union[Sequence[int], None] = None,
|
||||
pad_value: float = 0.0,
|
||||
equisized: bool = False,
|
||||
@ -66,7 +66,6 @@ def list_to_padded(
|
||||
pad_dims = pad_size
|
||||
|
||||
N = len(x)
|
||||
# pyre-fixme[16]: `Tensor` has no attribute `new_full`.
|
||||
x_padded = x[0].new_full((N, *pad_dims), pad_value)
|
||||
for i, y in enumerate(x):
|
||||
if len(y) > 0:
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import copy
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
@ -8,6 +9,16 @@ from ..transforms import Scale, Transform3d
|
||||
from . import utils as struct_utils
|
||||
|
||||
|
||||
_Scalar = Union[int, float]
|
||||
_Vector = Union[torch.Tensor, Tuple[_Scalar, ...], List[_Scalar]]
|
||||
_ScalarOrVector = Union[_Scalar, _Vector]
|
||||
|
||||
_VoxelSize = _ScalarOrVector
|
||||
_Translation = _Vector
|
||||
|
||||
_TensorBatch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
|
||||
|
||||
|
||||
class Volumes:
|
||||
"""
|
||||
This class provides functions for working with batches of volumetric grids
|
||||
@ -135,10 +146,10 @@ class Volumes:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
densities: Union[List[torch.Tensor], torch.Tensor],
|
||||
features: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
||||
voxel_size: Union[float, torch.Tensor, Tuple, List] = 1.0,
|
||||
volume_translation: Union[torch.Tensor, Tuple, List] = (0.0, 0.0, 0.0),
|
||||
densities: _TensorBatch,
|
||||
features: Optional[_TensorBatch] = None,
|
||||
voxel_size: _VoxelSize = 1.0,
|
||||
volume_translation: _Translation = (0.0, 0.0, 0.0),
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@ -193,8 +204,8 @@ class Volumes:
|
||||
)
|
||||
|
||||
def _convert_densities_features_to_tensor(
|
||||
self, x: Union[List[torch.Tensor], torch.Tensor], var_name: str
|
||||
):
|
||||
self, x: _TensorBatch, var_name: str
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Handle the `densities` or `features` arguments to the constructor.
|
||||
"""
|
||||
@ -236,9 +247,12 @@ class Volumes:
|
||||
)
|
||||
return x_tensor, x_shapes
|
||||
|
||||
def _vsize_translation_to_transform(
|
||||
self, voxel_size, volume_translation, batch_size
|
||||
):
|
||||
def _voxel_size_translation_to_transform(
|
||||
self,
|
||||
voxel_size: torch.Tensor,
|
||||
volume_translation: torch.Tensor,
|
||||
batch_size: int,
|
||||
) -> Transform3d:
|
||||
"""
|
||||
Converts the `voxel_size` and `volume_translation` constructor arguments
|
||||
to the internal `Transform3D` object `local_to_world_transform`.
|
||||
@ -260,7 +274,9 @@ class Volumes:
|
||||
|
||||
return local_to_world_transform
|
||||
|
||||
def _handle_voxel_size(self, voxel_size, batch_size):
|
||||
def _handle_voxel_size(
|
||||
self, voxel_size: _VoxelSize, batch_size: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Handle the `voxel_size` argument to the `Volumes` constructor.
|
||||
"""
|
||||
@ -270,8 +286,10 @@ class Volumes:
|
||||
)
|
||||
if isinstance(voxel_size, (float, int)):
|
||||
# convert a scalar to a 3-element tensor
|
||||
voxel_size = (voxel_size,) * 3
|
||||
if torch.is_tensor(voxel_size):
|
||||
voxel_size = torch.full(
|
||||
(1, 3), voxel_size, device=self.device, dtype=torch.float32
|
||||
)
|
||||
elif isinstance(voxel_size, torch.Tensor):
|
||||
if voxel_size.numel() == 1:
|
||||
# convert a single-element tensor to a 3-element one
|
||||
voxel_size = voxel_size.view(-1).repeat(3)
|
||||
@ -281,7 +299,9 @@ class Volumes:
|
||||
voxel_size = voxel_size.repeat(1, 3)
|
||||
return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg)
|
||||
|
||||
def _handle_volume_translation(self, translation, batch_size):
|
||||
def _handle_volume_translation(
|
||||
self, translation: _Translation, batch_size: int
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Handle the `volume_translation` argument to the `Volumes` constructor.
|
||||
"""
|
||||
@ -292,20 +312,19 @@ class Volumes:
|
||||
return self._convert_volume_property_to_tensor(translation, batch_size, err_msg)
|
||||
|
||||
def _convert_volume_property_to_tensor(
|
||||
self, x, batch_size, err_msg
|
||||
self, x: _Vector, batch_size: int, err_msg: str
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Handle the `volume_translation` or `voxel_size` argument to
|
||||
the Volumes constructor.
|
||||
Return a tensor of shape (N, 3) where N is the batch_size or 1
|
||||
if batch_size is None.
|
||||
Return a tensor of shape (N, 3) where N is the batch_size.
|
||||
"""
|
||||
if isinstance(x, (list, tuple)):
|
||||
if len(x) != 3:
|
||||
raise ValueError(err_msg)
|
||||
x = torch.tensor(x, device=self.device, dtype=torch.float32)[None]
|
||||
x = x.repeat((batch_size, 1))
|
||||
elif torch.is_tensor(x):
|
||||
elif isinstance(x, torch.Tensor):
|
||||
ok = (
|
||||
(x.shape[0] == 1 and x.shape[1] == 3)
|
||||
or (x.shape[0] == 3 and len(x.shape) == 1)
|
||||
@ -544,13 +563,14 @@ class Volumes:
|
||||
list of tensors of features of shape (dim_i, D_i, H_i, W_i)
|
||||
or `None` for feature-less volumes.
|
||||
"""
|
||||
if self._features is None:
|
||||
features_ = self.features()
|
||||
if features_ is None:
|
||||
# No features provided so return None
|
||||
# pyre-fixme[7]: Expected `List[torch.Tensor]` but got `None`.
|
||||
return None
|
||||
return self._features_densities_list(self.features())
|
||||
return self._features_densities_list(features_)
|
||||
|
||||
def _features_densities_list(self, x) -> List[torch.Tensor]:
|
||||
def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""
|
||||
Retrieve the list representation of features/densities.
|
||||
|
||||
@ -577,7 +597,9 @@ class Volumes:
|
||||
"""
|
||||
return self._grid_sizes
|
||||
|
||||
def update_padded(self, new_densities, new_features=None) -> "Volumes":
|
||||
def update_padded(
|
||||
self, new_densities: torch.Tensor, new_features: Optional[torch.Tensor] = None
|
||||
) -> "Volumes":
|
||||
"""
|
||||
Returns a Volumes structure with updated padded tensors and copies of
|
||||
the auxiliary tensors `self._local_to_world_transform`,
|
||||
@ -600,13 +622,13 @@ class Volumes:
|
||||
new._set_features(new_features)
|
||||
return new
|
||||
|
||||
def _set_features(self, features):
|
||||
self._set_densities_features(features, "features")
|
||||
def _set_features(self, features: _TensorBatch) -> None:
|
||||
self._set_densities_features("features", features)
|
||||
|
||||
def _set_densities(self, densities):
|
||||
self._set_densities_features(densities, "densities")
|
||||
def _set_densities(self, densities: _TensorBatch) -> None:
|
||||
self._set_densities_features("densities", densities)
|
||||
|
||||
def _set_densities_features(self, x, var_name):
|
||||
def _set_densities_features(self, var_name: str, x: _TensorBatch) -> None:
|
||||
x_tensor, grid_sizes = self._convert_densities_features_to_tensor(x, var_name)
|
||||
if x_tensor.device != self.device:
|
||||
raise ValueError(
|
||||
@ -630,8 +652,8 @@ class Volumes:
|
||||
|
||||
def _set_local_to_world_transform(
|
||||
self,
|
||||
voxel_size: Union[float, torch.Tensor, Tuple, List] = 1.0,
|
||||
volume_translation: Union[torch.Tensor, Tuple, List] = (0.0, 0.0, 0.0),
|
||||
voxel_size: _VoxelSize = 1.0,
|
||||
volume_translation: _Translation = (0.0, 0.0, 0.0),
|
||||
):
|
||||
"""
|
||||
Sets the internal representation of the transformation between the
|
||||
@ -658,7 +680,7 @@ class Volumes:
|
||||
volume_translation = self._handle_volume_translation(
|
||||
volume_translation, len(self)
|
||||
)
|
||||
self._local_to_world_transform = self._vsize_translation_to_transform(
|
||||
self._local_to_world_transform = self._voxel_size_translation_to_transform(
|
||||
voxel_size, volume_translation, len(self)
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user