some pointcloud typing

Summary: Make clear that features_padded() etc can return None

Reviewed By: patricklabatut

Differential Revision: D31795088

fbshipit-source-id: 7b0bbb6f3b7ad7f7b6e6a727129537af1d1873af
This commit is contained in:
Jeremy Reizenstein 2021-10-28 04:52:53 -07:00 committed by Facebook GitHub Bot
parent 73a14d7266
commit bfeb82efa3

View File

@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from itertools import zip_longest from itertools import zip_longest
from typing import Sequence, Union from typing import List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -240,7 +240,9 @@ class Pointclouds:
if features_C is not None: if features_C is not None:
self._C = features_C self._C = features_C
def _parse_auxiliary_input(self, aux_input): def _parse_auxiliary_input(
self, aux_input
) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor], Optional[int]]:
""" """
Interpret the auxiliary inputs (normals, features) given to __init__. Interpret the auxiliary inputs (normals, features) given to __init__.
@ -323,24 +325,26 @@ class Pointclouds:
Pointclouds object with selected clouds. The tensors are not cloned. Pointclouds object with selected clouds. The tensors are not cloned.
""" """
normals, features = None, None normals, features = None, None
normals_list = self.normals_list()
features_list = self.features_list()
if isinstance(index, int): if isinstance(index, int):
points = [self.points_list()[index]] points = [self.points_list()[index]]
if self.normals_list() is not None: if normals_list is not None:
normals = [self.normals_list()[index]] normals = [normals_list[index]]
if self.features_list() is not None: if features_list is not None:
features = [self.features_list()[index]] features = [features_list[index]]
elif isinstance(index, slice): elif isinstance(index, slice):
points = self.points_list()[index] points = self.points_list()[index]
if self.normals_list() is not None: if normals_list is not None:
normals = self.normals_list()[index] normals = normals_list[index]
if self.features_list() is not None: if features_list is not None:
features = self.features_list()[index] features = features_list[index]
elif isinstance(index, list): elif isinstance(index, list):
points = [self.points_list()[i] for i in index] points = [self.points_list()[i] for i in index]
if self.normals_list() is not None: if normals_list is not None:
normals = [self.normals_list()[i] for i in index] normals = [normals_list[i] for i in index]
if self.features_list() is not None: if features_list is not None:
features = [self.features_list()[i] for i in index] features = [features_list[i] for i in index]
elif isinstance(index, torch.Tensor): elif isinstance(index, torch.Tensor):
if index.dim() != 1 or index.dtype.is_floating_point: if index.dim() != 1 or index.dtype.is_floating_point:
raise IndexError(index) raise IndexError(index)
@ -351,10 +355,10 @@ class Pointclouds:
index = index.squeeze(1) if index.numel() > 0 else index index = index.squeeze(1) if index.numel() > 0 else index
index = index.tolist() index = index.tolist()
points = [self.points_list()[i] for i in index] points = [self.points_list()[i] for i in index]
if self.normals_list() is not None: if normals_list is not None:
normals = [self.normals_list()[i] for i in index] normals = [normals_list[i] for i in index]
if self.features_list() is not None: if features_list is not None:
features = [self.features_list()[i] for i in index] features = [features_list[i] for i in index]
else: else:
raise IndexError(index) raise IndexError(index)
@ -369,7 +373,7 @@ class Pointclouds:
""" """
return self._N == 0 or self.valid.eq(False).all() return self._N == 0 or self.valid.eq(False).all()
def points_list(self): def points_list(self) -> List[torch.Tensor]:
""" """
Get the list representation of the points. Get the list representation of the points.
@ -388,9 +392,10 @@ class Pointclouds:
self._points_list = points_list self._points_list = points_list
return self._points_list return self._points_list
def normals_list(self): def normals_list(self) -> Optional[List[torch.Tensor]]:
""" """
Get the list representation of the normals. Get the list representation of the normals,
or None if there are no normals.
Returns: Returns:
list of tensors of normals of shape (P_n, 3). list of tensors of normals of shape (P_n, 3).
@ -404,9 +409,10 @@ class Pointclouds:
) )
return self._normals_list return self._normals_list
def features_list(self): def features_list(self) -> Optional[List[torch.Tensor]]:
""" """
Get the list representation of the features. Get the list representation of the features,
or None if there are no features.
Returns: Returns:
list of tensors of features of shape (P_n, C). list of tensors of features of shape (P_n, C).
@ -420,7 +426,7 @@ class Pointclouds:
) )
return self._features_list return self._features_list
def points_packed(self): def points_packed(self) -> torch.Tensor:
""" """
Get the packed representation of the points. Get the packed representation of the points.
@ -430,22 +436,24 @@ class Pointclouds:
self._compute_packed() self._compute_packed()
return self._points_packed return self._points_packed
def normals_packed(self): def normals_packed(self) -> Optional[torch.Tensor]:
""" """
Get the packed representation of the normals. Get the packed representation of the normals.
Returns: Returns:
tensor of normals of shape (sum(P_n), 3). tensor of normals of shape (sum(P_n), 3),
or None if there are no normals.
""" """
self._compute_packed() self._compute_packed()
return self._normals_packed return self._normals_packed
def features_packed(self): def features_packed(self) -> Optional[torch.Tensor]:
""" """
Get the packed representation of the features. Get the packed representation of the features.
Returns: Returns:
tensor of features of shape (sum(P_n), C). tensor of features of shape (sum(P_n), C),
or None if there are no features
""" """
self._compute_packed() self._compute_packed()
return self._features_packed return self._features_packed
@ -483,7 +491,7 @@ class Pointclouds:
""" """
return self._num_points_per_cloud return self._num_points_per_cloud
def points_padded(self): def points_padded(self) -> torch.Tensor:
""" """
Get the padded representation of the points. Get the padded representation of the points.
@ -493,9 +501,10 @@ class Pointclouds:
self._compute_padded() self._compute_padded()
return self._points_padded return self._points_padded
def normals_padded(self): def normals_padded(self) -> Optional[torch.Tensor]:
""" """
Get the padded representation of the normals. Get the padded representation of the normals,
or None if there are no normals.
Returns: Returns:
tensor of normals of shape (N, max(P_n), 3). tensor of normals of shape (N, max(P_n), 3).
@ -503,9 +512,10 @@ class Pointclouds:
self._compute_padded() self._compute_padded()
return self._normals_padded return self._normals_padded
def features_padded(self): def features_padded(self) -> Optional[torch.Tensor]:
""" """
Get the padded representation of the features. Get the padded representation of the features,
or None if there are no features.
Returns: Returns:
tensor of features of shape (N, max(P_n), 3). tensor of features of shape (N, max(P_n), 3).
@ -562,16 +572,18 @@ class Pointclouds:
pad_value=0.0, pad_value=0.0,
equisized=self.equisized, equisized=self.equisized,
) )
if self.normals_list() is not None: normals_list = self.normals_list()
if normals_list is not None:
self._normals_padded = struct_utils.list_to_padded( self._normals_padded = struct_utils.list_to_padded(
self.normals_list(), normals_list,
(self._P, 3), (self._P, 3),
pad_value=0.0, pad_value=0.0,
equisized=self.equisized, equisized=self.equisized,
) )
if self.features_list() is not None: features_list = self.features_list()
if features_list is not None:
self._features_padded = struct_utils.list_to_padded( self._features_padded = struct_utils.list_to_padded(
self.features_list(), features_list,
(self._P, self._C), (self._P, self._C),
pad_value=0.0, pad_value=0.0,
equisized=self.equisized, equisized=self.equisized,
@ -772,10 +784,12 @@ class Pointclouds:
) )
points = self.points_list()[index] points = self.points_list()[index]
normals, features = None, None normals, features = None, None
if self.normals_list() is not None: normals_list = self.normals_list()
normals = self.normals_list()[index] if normals_list is not None:
if self.features_list() is not None: normals = normals_list[index]
features = self.features_list()[index] features_list = self.features_list()
if features_list is not None:
features = features_list[index]
return points, normals, features return points, normals, features
# TODO(nikhilar) Move function to a utils file. # TODO(nikhilar) Move function to a utils file.
@ -1022,13 +1036,15 @@ class Pointclouds:
new_points_list, new_normals_list, new_features_list = [], None, None new_points_list, new_normals_list, new_features_list = [], None, None
for points in self.points_list(): for points in self.points_list():
new_points_list.extend(points.clone() for _ in range(N)) new_points_list.extend(points.clone() for _ in range(N))
if self.normals_list() is not None: normals_list = self.normals_list()
if normals_list is not None:
new_normals_list = [] new_normals_list = []
for normals in self.normals_list(): for normals in normals_list:
new_normals_list.extend(normals.clone() for _ in range(N)) new_normals_list.extend(normals.clone() for _ in range(N))
if self.features_list() is not None: features_list = self.features_list()
if features_list is not None:
new_features_list = [] new_features_list = []
for features in self.features_list(): for features in features_list:
new_features_list.extend(features.clone() for _ in range(N)) new_features_list.extend(features.clone() for _ in range(N))
return self.__class__( return self.__class__(
points=new_points_list, normals=new_normals_list, features=new_features_list points=new_points_list, normals=new_normals_list, features=new_features_list