diff --git a/pytorch3d/structures/pointclouds.py b/pytorch3d/structures/pointclouds.py index 7c9cd4cc..d23e74f8 100644 --- a/pytorch3d/structures/pointclouds.py +++ b/pytorch3d/structures/pointclouds.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from itertools import zip_longest -from typing import Sequence, Union +from typing import List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -240,7 +240,9 @@ class Pointclouds: if features_C is not None: self._C = features_C - def _parse_auxiliary_input(self, aux_input): + def _parse_auxiliary_input( + self, aux_input + ) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor], Optional[int]]: """ Interpret the auxiliary inputs (normals, features) given to __init__. @@ -323,24 +325,26 @@ class Pointclouds: Pointclouds object with selected clouds. The tensors are not cloned. """ normals, features = None, None + normals_list = self.normals_list() + features_list = self.features_list() if isinstance(index, int): points = [self.points_list()[index]] - if self.normals_list() is not None: - normals = [self.normals_list()[index]] - if self.features_list() is not None: - features = [self.features_list()[index]] + if normals_list is not None: + normals = [normals_list[index]] + if features_list is not None: + features = [features_list[index]] elif isinstance(index, slice): points = self.points_list()[index] - if self.normals_list() is not None: - normals = self.normals_list()[index] - if self.features_list() is not None: - features = self.features_list()[index] + if normals_list is not None: + normals = normals_list[index] + if features_list is not None: + features = features_list[index] elif isinstance(index, list): points = [self.points_list()[i] for i in index] - if self.normals_list() is not None: - normals = [self.normals_list()[i] for i in index] - if self.features_list() is not None: - features = [self.features_list()[i] for i in index] + if normals_list is not None: + normals = [normals_list[i] for i in index] + if features_list is not None: + features = [features_list[i] for i in index] elif isinstance(index, torch.Tensor): if index.dim() != 1 or index.dtype.is_floating_point: raise IndexError(index) @@ -351,10 +355,10 @@ class Pointclouds: index = index.squeeze(1) if index.numel() > 0 else index index = index.tolist() points = [self.points_list()[i] for i in index] - if self.normals_list() is not None: - normals = [self.normals_list()[i] for i in index] - if self.features_list() is not None: - features = [self.features_list()[i] for i in index] + if normals_list is not None: + normals = [normals_list[i] for i in index] + if features_list is not None: + features = [features_list[i] for i in index] else: raise IndexError(index) @@ -369,7 +373,7 @@ class Pointclouds: """ return self._N == 0 or self.valid.eq(False).all() - def points_list(self): + def points_list(self) -> List[torch.Tensor]: """ Get the list representation of the points. @@ -388,9 +392,10 @@ class Pointclouds: self._points_list = points_list return self._points_list - def normals_list(self): + def normals_list(self) -> Optional[List[torch.Tensor]]: """ - Get the list representation of the normals. + Get the list representation of the normals, + or None if there are no normals. Returns: list of tensors of normals of shape (P_n, 3). @@ -404,9 +409,10 @@ class Pointclouds: ) return self._normals_list - def features_list(self): + def features_list(self) -> Optional[List[torch.Tensor]]: """ - Get the list representation of the features. + Get the list representation of the features, + or None if there are no features. Returns: list of tensors of features of shape (P_n, C). @@ -420,7 +426,7 @@ class Pointclouds: ) return self._features_list - def points_packed(self): + def points_packed(self) -> torch.Tensor: """ Get the packed representation of the points. @@ -430,22 +436,24 @@ class Pointclouds: self._compute_packed() return self._points_packed - def normals_packed(self): + def normals_packed(self) -> Optional[torch.Tensor]: """ Get the packed representation of the normals. Returns: - tensor of normals of shape (sum(P_n), 3). + tensor of normals of shape (sum(P_n), 3), + or None if there are no normals. """ self._compute_packed() return self._normals_packed - def features_packed(self): + def features_packed(self) -> Optional[torch.Tensor]: """ Get the packed representation of the features. Returns: - tensor of features of shape (sum(P_n), C). + tensor of features of shape (sum(P_n), C), + or None if there are no features """ self._compute_packed() return self._features_packed @@ -483,7 +491,7 @@ class Pointclouds: """ return self._num_points_per_cloud - def points_padded(self): + def points_padded(self) -> torch.Tensor: """ Get the padded representation of the points. @@ -493,9 +501,10 @@ class Pointclouds: self._compute_padded() return self._points_padded - def normals_padded(self): + def normals_padded(self) -> Optional[torch.Tensor]: """ - Get the padded representation of the normals. + Get the padded representation of the normals, + or None if there are no normals. Returns: tensor of normals of shape (N, max(P_n), 3). @@ -503,9 +512,10 @@ class Pointclouds: self._compute_padded() return self._normals_padded - def features_padded(self): + def features_padded(self) -> Optional[torch.Tensor]: """ - Get the padded representation of the features. + Get the padded representation of the features, + or None if there are no features. Returns: tensor of features of shape (N, max(P_n), 3). @@ -562,16 +572,18 @@ class Pointclouds: pad_value=0.0, equisized=self.equisized, ) - if self.normals_list() is not None: + normals_list = self.normals_list() + if normals_list is not None: self._normals_padded = struct_utils.list_to_padded( - self.normals_list(), + normals_list, (self._P, 3), pad_value=0.0, equisized=self.equisized, ) - if self.features_list() is not None: + features_list = self.features_list() + if features_list is not None: self._features_padded = struct_utils.list_to_padded( - self.features_list(), + features_list, (self._P, self._C), pad_value=0.0, equisized=self.equisized, @@ -772,10 +784,12 @@ class Pointclouds: ) points = self.points_list()[index] normals, features = None, None - if self.normals_list() is not None: - normals = self.normals_list()[index] - if self.features_list() is not None: - features = self.features_list()[index] + normals_list = self.normals_list() + if normals_list is not None: + normals = normals_list[index] + features_list = self.features_list() + if features_list is not None: + features = features_list[index] return points, normals, features # TODO(nikhilar) Move function to a utils file. @@ -1022,13 +1036,15 @@ class Pointclouds: new_points_list, new_normals_list, new_features_list = [], None, None for points in self.points_list(): new_points_list.extend(points.clone() for _ in range(N)) - if self.normals_list() is not None: + normals_list = self.normals_list() + if normals_list is not None: new_normals_list = [] - for normals in self.normals_list(): + for normals in normals_list: new_normals_list.extend(normals.clone() for _ in range(N)) - if self.features_list() is not None: + features_list = self.features_list() + if features_list is not None: new_features_list = [] - for features in self.features_list(): + for features in features_list: new_features_list.extend(features.clone() for _ in range(N)) return self.__class__( points=new_points_list, normals=new_normals_list, features=new_features_list