some pointcloud typing

Summary: Make clear that features_padded() etc can return None

Reviewed By: patricklabatut

Differential Revision: D31795088

fbshipit-source-id: 7b0bbb6f3b7ad7f7b6e6a727129537af1d1873af
This commit is contained in:
Jeremy Reizenstein 2021-10-28 04:52:53 -07:00 committed by Facebook GitHub Bot
parent 73a14d7266
commit bfeb82efa3

View File

@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
from itertools import zip_longest
from typing import Sequence, Union
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
@ -240,7 +240,9 @@ class Pointclouds:
if features_C is not None:
self._C = features_C
def _parse_auxiliary_input(self, aux_input):
def _parse_auxiliary_input(
self, aux_input
) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor], Optional[int]]:
"""
Interpret the auxiliary inputs (normals, features) given to __init__.
@ -323,24 +325,26 @@ class Pointclouds:
Pointclouds object with selected clouds. The tensors are not cloned.
"""
normals, features = None, None
normals_list = self.normals_list()
features_list = self.features_list()
if isinstance(index, int):
points = [self.points_list()[index]]
if self.normals_list() is not None:
normals = [self.normals_list()[index]]
if self.features_list() is not None:
features = [self.features_list()[index]]
if normals_list is not None:
normals = [normals_list[index]]
if features_list is not None:
features = [features_list[index]]
elif isinstance(index, slice):
points = self.points_list()[index]
if self.normals_list() is not None:
normals = self.normals_list()[index]
if self.features_list() is not None:
features = self.features_list()[index]
if normals_list is not None:
normals = normals_list[index]
if features_list is not None:
features = features_list[index]
elif isinstance(index, list):
points = [self.points_list()[i] for i in index]
if self.normals_list() is not None:
normals = [self.normals_list()[i] for i in index]
if self.features_list() is not None:
features = [self.features_list()[i] for i in index]
if normals_list is not None:
normals = [normals_list[i] for i in index]
if features_list is not None:
features = [features_list[i] for i in index]
elif isinstance(index, torch.Tensor):
if index.dim() != 1 or index.dtype.is_floating_point:
raise IndexError(index)
@ -351,10 +355,10 @@ class Pointclouds:
index = index.squeeze(1) if index.numel() > 0 else index
index = index.tolist()
points = [self.points_list()[i] for i in index]
if self.normals_list() is not None:
normals = [self.normals_list()[i] for i in index]
if self.features_list() is not None:
features = [self.features_list()[i] for i in index]
if normals_list is not None:
normals = [normals_list[i] for i in index]
if features_list is not None:
features = [features_list[i] for i in index]
else:
raise IndexError(index)
@ -369,7 +373,7 @@ class Pointclouds:
"""
return self._N == 0 or self.valid.eq(False).all()
def points_list(self):
def points_list(self) -> List[torch.Tensor]:
"""
Get the list representation of the points.
@ -388,9 +392,10 @@ class Pointclouds:
self._points_list = points_list
return self._points_list
def normals_list(self):
def normals_list(self) -> Optional[List[torch.Tensor]]:
"""
Get the list representation of the normals.
Get the list representation of the normals,
or None if there are no normals.
Returns:
list of tensors of normals of shape (P_n, 3).
@ -404,9 +409,10 @@ class Pointclouds:
)
return self._normals_list
def features_list(self):
def features_list(self) -> Optional[List[torch.Tensor]]:
"""
Get the list representation of the features.
Get the list representation of the features,
or None if there are no features.
Returns:
list of tensors of features of shape (P_n, C).
@ -420,7 +426,7 @@ class Pointclouds:
)
return self._features_list
def points_packed(self):
def points_packed(self) -> torch.Tensor:
"""
Get the packed representation of the points.
@ -430,22 +436,24 @@ class Pointclouds:
self._compute_packed()
return self._points_packed
def normals_packed(self):
def normals_packed(self) -> Optional[torch.Tensor]:
"""
Get the packed representation of the normals.
Returns:
tensor of normals of shape (sum(P_n), 3).
tensor of normals of shape (sum(P_n), 3),
or None if there are no normals.
"""
self._compute_packed()
return self._normals_packed
def features_packed(self):
def features_packed(self) -> Optional[torch.Tensor]:
"""
Get the packed representation of the features.
Returns:
tensor of features of shape (sum(P_n), C).
tensor of features of shape (sum(P_n), C),
or None if there are no features
"""
self._compute_packed()
return self._features_packed
@ -483,7 +491,7 @@ class Pointclouds:
"""
return self._num_points_per_cloud
def points_padded(self):
def points_padded(self) -> torch.Tensor:
"""
Get the padded representation of the points.
@ -493,9 +501,10 @@ class Pointclouds:
self._compute_padded()
return self._points_padded
def normals_padded(self):
def normals_padded(self) -> Optional[torch.Tensor]:
"""
Get the padded representation of the normals.
Get the padded representation of the normals,
or None if there are no normals.
Returns:
tensor of normals of shape (N, max(P_n), 3).
@ -503,9 +512,10 @@ class Pointclouds:
self._compute_padded()
return self._normals_padded
def features_padded(self):
def features_padded(self) -> Optional[torch.Tensor]:
"""
Get the padded representation of the features.
Get the padded representation of the features,
or None if there are no features.
Returns:
tensor of features of shape (N, max(P_n), 3).
@ -562,16 +572,18 @@ class Pointclouds:
pad_value=0.0,
equisized=self.equisized,
)
if self.normals_list() is not None:
normals_list = self.normals_list()
if normals_list is not None:
self._normals_padded = struct_utils.list_to_padded(
self.normals_list(),
normals_list,
(self._P, 3),
pad_value=0.0,
equisized=self.equisized,
)
if self.features_list() is not None:
features_list = self.features_list()
if features_list is not None:
self._features_padded = struct_utils.list_to_padded(
self.features_list(),
features_list,
(self._P, self._C),
pad_value=0.0,
equisized=self.equisized,
@ -772,10 +784,12 @@ class Pointclouds:
)
points = self.points_list()[index]
normals, features = None, None
if self.normals_list() is not None:
normals = self.normals_list()[index]
if self.features_list() is not None:
features = self.features_list()[index]
normals_list = self.normals_list()
if normals_list is not None:
normals = normals_list[index]
features_list = self.features_list()
if features_list is not None:
features = features_list[index]
return points, normals, features
# TODO(nikhilar) Move function to a utils file.
@ -1022,13 +1036,15 @@ class Pointclouds:
new_points_list, new_normals_list, new_features_list = [], None, None
for points in self.points_list():
new_points_list.extend(points.clone() for _ in range(N))
if self.normals_list() is not None:
normals_list = self.normals_list()
if normals_list is not None:
new_normals_list = []
for normals in self.normals_list():
for normals in normals_list:
new_normals_list.extend(normals.clone() for _ in range(N))
if self.features_list() is not None:
features_list = self.features_list()
if features_list is not None:
new_features_list = []
for features in self.features_list():
for features in features_list:
new_features_list.extend(features.clone() for _ in range(N))
return self.__class__(
points=new_points_list, normals=new_normals_list, features=new_features_list