mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-04 04:42:49 +08:00
Improve volumes type annotations
Summary: Improve type annotations for volumes and remove a few pyre fixmes Reviewed By: nikhilaravi Differential Revision: D28943371 fbshipit-source-id: ca2b7a50d72a392910e65cee5e564f34523414d2
This commit is contained in:
parent
a15c33a3cc
commit
02650672f6
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
from typing import List, Sequence, Union
|
from typing import List, Sequence, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -11,7 +11,7 @@ Util functions for points/verts/faces/volumes.
|
|||||||
|
|
||||||
|
|
||||||
def list_to_padded(
|
def list_to_padded(
|
||||||
x: List[torch.Tensor],
|
x: Union[List[torch.Tensor], Tuple[torch.Tensor]],
|
||||||
pad_size: Union[Sequence[int], None] = None,
|
pad_size: Union[Sequence[int], None] = None,
|
||||||
pad_value: float = 0.0,
|
pad_value: float = 0.0,
|
||||||
equisized: bool = False,
|
equisized: bool = False,
|
||||||
@ -66,7 +66,6 @@ def list_to_padded(
|
|||||||
pad_dims = pad_size
|
pad_dims = pad_size
|
||||||
|
|
||||||
N = len(x)
|
N = len(x)
|
||||||
# pyre-fixme[16]: `Tensor` has no attribute `new_full`.
|
|
||||||
x_padded = x[0].new_full((N, *pad_dims), pad_value)
|
x_padded = x[0].new_full((N, *pad_dims), pad_value)
|
||||||
for i, y in enumerate(x):
|
for i, y in enumerate(x):
|
||||||
if len(y) > 0:
|
if len(y) > 0:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
@ -8,6 +9,16 @@ from ..transforms import Scale, Transform3d
|
|||||||
from . import utils as struct_utils
|
from . import utils as struct_utils
|
||||||
|
|
||||||
|
|
||||||
|
_Scalar = Union[int, float]
|
||||||
|
_Vector = Union[torch.Tensor, Tuple[_Scalar, ...], List[_Scalar]]
|
||||||
|
_ScalarOrVector = Union[_Scalar, _Vector]
|
||||||
|
|
||||||
|
_VoxelSize = _ScalarOrVector
|
||||||
|
_Translation = _Vector
|
||||||
|
|
||||||
|
_TensorBatch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
|
||||||
|
|
||||||
|
|
||||||
class Volumes:
|
class Volumes:
|
||||||
"""
|
"""
|
||||||
This class provides functions for working with batches of volumetric grids
|
This class provides functions for working with batches of volumetric grids
|
||||||
@ -135,10 +146,10 @@ class Volumes:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
densities: Union[List[torch.Tensor], torch.Tensor],
|
densities: _TensorBatch,
|
||||||
features: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
|
features: Optional[_TensorBatch] = None,
|
||||||
voxel_size: Union[float, torch.Tensor, Tuple, List] = 1.0,
|
voxel_size: _VoxelSize = 1.0,
|
||||||
volume_translation: Union[torch.Tensor, Tuple, List] = (0.0, 0.0, 0.0),
|
volume_translation: _Translation = (0.0, 0.0, 0.0),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -193,8 +204,8 @@ class Volumes:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _convert_densities_features_to_tensor(
|
def _convert_densities_features_to_tensor(
|
||||||
self, x: Union[List[torch.Tensor], torch.Tensor], var_name: str
|
self, x: _TensorBatch, var_name: str
|
||||||
):
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Handle the `densities` or `features` arguments to the constructor.
|
Handle the `densities` or `features` arguments to the constructor.
|
||||||
"""
|
"""
|
||||||
@ -236,9 +247,12 @@ class Volumes:
|
|||||||
)
|
)
|
||||||
return x_tensor, x_shapes
|
return x_tensor, x_shapes
|
||||||
|
|
||||||
def _vsize_translation_to_transform(
|
def _voxel_size_translation_to_transform(
|
||||||
self, voxel_size, volume_translation, batch_size
|
self,
|
||||||
):
|
voxel_size: torch.Tensor,
|
||||||
|
volume_translation: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
) -> Transform3d:
|
||||||
"""
|
"""
|
||||||
Converts the `voxel_size` and `volume_translation` constructor arguments
|
Converts the `voxel_size` and `volume_translation` constructor arguments
|
||||||
to the internal `Transform3D` object `local_to_world_transform`.
|
to the internal `Transform3D` object `local_to_world_transform`.
|
||||||
@ -260,7 +274,9 @@ class Volumes:
|
|||||||
|
|
||||||
return local_to_world_transform
|
return local_to_world_transform
|
||||||
|
|
||||||
def _handle_voxel_size(self, voxel_size, batch_size):
|
def _handle_voxel_size(
|
||||||
|
self, voxel_size: _VoxelSize, batch_size: int
|
||||||
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Handle the `voxel_size` argument to the `Volumes` constructor.
|
Handle the `voxel_size` argument to the `Volumes` constructor.
|
||||||
"""
|
"""
|
||||||
@ -270,8 +286,10 @@ class Volumes:
|
|||||||
)
|
)
|
||||||
if isinstance(voxel_size, (float, int)):
|
if isinstance(voxel_size, (float, int)):
|
||||||
# convert a scalar to a 3-element tensor
|
# convert a scalar to a 3-element tensor
|
||||||
voxel_size = (voxel_size,) * 3
|
voxel_size = torch.full(
|
||||||
if torch.is_tensor(voxel_size):
|
(1, 3), voxel_size, device=self.device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
elif isinstance(voxel_size, torch.Tensor):
|
||||||
if voxel_size.numel() == 1:
|
if voxel_size.numel() == 1:
|
||||||
# convert a single-element tensor to a 3-element one
|
# convert a single-element tensor to a 3-element one
|
||||||
voxel_size = voxel_size.view(-1).repeat(3)
|
voxel_size = voxel_size.view(-1).repeat(3)
|
||||||
@ -281,7 +299,9 @@ class Volumes:
|
|||||||
voxel_size = voxel_size.repeat(1, 3)
|
voxel_size = voxel_size.repeat(1, 3)
|
||||||
return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg)
|
return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg)
|
||||||
|
|
||||||
def _handle_volume_translation(self, translation, batch_size):
|
def _handle_volume_translation(
|
||||||
|
self, translation: _Translation, batch_size: int
|
||||||
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Handle the `volume_translation` argument to the `Volumes` constructor.
|
Handle the `volume_translation` argument to the `Volumes` constructor.
|
||||||
"""
|
"""
|
||||||
@ -292,20 +312,19 @@ class Volumes:
|
|||||||
return self._convert_volume_property_to_tensor(translation, batch_size, err_msg)
|
return self._convert_volume_property_to_tensor(translation, batch_size, err_msg)
|
||||||
|
|
||||||
def _convert_volume_property_to_tensor(
|
def _convert_volume_property_to_tensor(
|
||||||
self, x, batch_size, err_msg
|
self, x: _Vector, batch_size: int, err_msg: str
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Handle the `volume_translation` or `voxel_size` argument to
|
Handle the `volume_translation` or `voxel_size` argument to
|
||||||
the Volumes constructor.
|
the Volumes constructor.
|
||||||
Return a tensor of shape (N, 3) where N is the batch_size or 1
|
Return a tensor of shape (N, 3) where N is the batch_size.
|
||||||
if batch_size is None.
|
|
||||||
"""
|
"""
|
||||||
if isinstance(x, (list, tuple)):
|
if isinstance(x, (list, tuple)):
|
||||||
if len(x) != 3:
|
if len(x) != 3:
|
||||||
raise ValueError(err_msg)
|
raise ValueError(err_msg)
|
||||||
x = torch.tensor(x, device=self.device, dtype=torch.float32)[None]
|
x = torch.tensor(x, device=self.device, dtype=torch.float32)[None]
|
||||||
x = x.repeat((batch_size, 1))
|
x = x.repeat((batch_size, 1))
|
||||||
elif torch.is_tensor(x):
|
elif isinstance(x, torch.Tensor):
|
||||||
ok = (
|
ok = (
|
||||||
(x.shape[0] == 1 and x.shape[1] == 3)
|
(x.shape[0] == 1 and x.shape[1] == 3)
|
||||||
or (x.shape[0] == 3 and len(x.shape) == 1)
|
or (x.shape[0] == 3 and len(x.shape) == 1)
|
||||||
@ -544,13 +563,14 @@ class Volumes:
|
|||||||
list of tensors of features of shape (dim_i, D_i, H_i, W_i)
|
list of tensors of features of shape (dim_i, D_i, H_i, W_i)
|
||||||
or `None` for feature-less volumes.
|
or `None` for feature-less volumes.
|
||||||
"""
|
"""
|
||||||
if self._features is None:
|
features_ = self.features()
|
||||||
|
if features_ is None:
|
||||||
# No features provided so return None
|
# No features provided so return None
|
||||||
# pyre-fixme[7]: Expected `List[torch.Tensor]` but got `None`.
|
# pyre-fixme[7]: Expected `List[torch.Tensor]` but got `None`.
|
||||||
return None
|
return None
|
||||||
return self._features_densities_list(self.features())
|
return self._features_densities_list(features_)
|
||||||
|
|
||||||
def _features_densities_list(self, x) -> List[torch.Tensor]:
|
def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Retrieve the list representation of features/densities.
|
Retrieve the list representation of features/densities.
|
||||||
|
|
||||||
@ -577,7 +597,9 @@ class Volumes:
|
|||||||
"""
|
"""
|
||||||
return self._grid_sizes
|
return self._grid_sizes
|
||||||
|
|
||||||
def update_padded(self, new_densities, new_features=None) -> "Volumes":
|
def update_padded(
|
||||||
|
self, new_densities: torch.Tensor, new_features: Optional[torch.Tensor] = None
|
||||||
|
) -> "Volumes":
|
||||||
"""
|
"""
|
||||||
Returns a Volumes structure with updated padded tensors and copies of
|
Returns a Volumes structure with updated padded tensors and copies of
|
||||||
the auxiliary tensors `self._local_to_world_transform`,
|
the auxiliary tensors `self._local_to_world_transform`,
|
||||||
@ -600,13 +622,13 @@ class Volumes:
|
|||||||
new._set_features(new_features)
|
new._set_features(new_features)
|
||||||
return new
|
return new
|
||||||
|
|
||||||
def _set_features(self, features):
|
def _set_features(self, features: _TensorBatch) -> None:
|
||||||
self._set_densities_features(features, "features")
|
self._set_densities_features("features", features)
|
||||||
|
|
||||||
def _set_densities(self, densities):
|
def _set_densities(self, densities: _TensorBatch) -> None:
|
||||||
self._set_densities_features(densities, "densities")
|
self._set_densities_features("densities", densities)
|
||||||
|
|
||||||
def _set_densities_features(self, x, var_name):
|
def _set_densities_features(self, var_name: str, x: _TensorBatch) -> None:
|
||||||
x_tensor, grid_sizes = self._convert_densities_features_to_tensor(x, var_name)
|
x_tensor, grid_sizes = self._convert_densities_features_to_tensor(x, var_name)
|
||||||
if x_tensor.device != self.device:
|
if x_tensor.device != self.device:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -630,8 +652,8 @@ class Volumes:
|
|||||||
|
|
||||||
def _set_local_to_world_transform(
|
def _set_local_to_world_transform(
|
||||||
self,
|
self,
|
||||||
voxel_size: Union[float, torch.Tensor, Tuple, List] = 1.0,
|
voxel_size: _VoxelSize = 1.0,
|
||||||
volume_translation: Union[torch.Tensor, Tuple, List] = (0.0, 0.0, 0.0),
|
volume_translation: _Translation = (0.0, 0.0, 0.0),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Sets the internal representation of the transformation between the
|
Sets the internal representation of the transformation between the
|
||||||
@ -658,7 +680,7 @@ class Volumes:
|
|||||||
volume_translation = self._handle_volume_translation(
|
volume_translation = self._handle_volume_translation(
|
||||||
volume_translation, len(self)
|
volume_translation, len(self)
|
||||||
)
|
)
|
||||||
self._local_to_world_transform = self._vsize_translation_to_transform(
|
self._local_to_world_transform = self._voxel_size_translation_to_transform(
|
||||||
voxel_size, volume_translation, len(self)
|
voxel_size, volume_translation, len(self)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user