mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
some pointcloud typing
Summary: Make clear that features_padded() etc can return None Reviewed By: patricklabatut Differential Revision: D31795088 fbshipit-source-id: 7b0bbb6f3b7ad7f7b6e6a727129537af1d1873af
This commit is contained in:
parent
73a14d7266
commit
bfeb82efa3
@ -5,7 +5,7 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from itertools import zip_longest
|
||||
from typing import Sequence, Union
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -240,7 +240,9 @@ class Pointclouds:
|
||||
if features_C is not None:
|
||||
self._C = features_C
|
||||
|
||||
def _parse_auxiliary_input(self, aux_input):
|
||||
def _parse_auxiliary_input(
|
||||
self, aux_input
|
||||
) -> Tuple[Optional[List[torch.Tensor]], Optional[torch.Tensor], Optional[int]]:
|
||||
"""
|
||||
Interpret the auxiliary inputs (normals, features) given to __init__.
|
||||
|
||||
@ -323,24 +325,26 @@ class Pointclouds:
|
||||
Pointclouds object with selected clouds. The tensors are not cloned.
|
||||
"""
|
||||
normals, features = None, None
|
||||
normals_list = self.normals_list()
|
||||
features_list = self.features_list()
|
||||
if isinstance(index, int):
|
||||
points = [self.points_list()[index]]
|
||||
if self.normals_list() is not None:
|
||||
normals = [self.normals_list()[index]]
|
||||
if self.features_list() is not None:
|
||||
features = [self.features_list()[index]]
|
||||
if normals_list is not None:
|
||||
normals = [normals_list[index]]
|
||||
if features_list is not None:
|
||||
features = [features_list[index]]
|
||||
elif isinstance(index, slice):
|
||||
points = self.points_list()[index]
|
||||
if self.normals_list() is not None:
|
||||
normals = self.normals_list()[index]
|
||||
if self.features_list() is not None:
|
||||
features = self.features_list()[index]
|
||||
if normals_list is not None:
|
||||
normals = normals_list[index]
|
||||
if features_list is not None:
|
||||
features = features_list[index]
|
||||
elif isinstance(index, list):
|
||||
points = [self.points_list()[i] for i in index]
|
||||
if self.normals_list() is not None:
|
||||
normals = [self.normals_list()[i] for i in index]
|
||||
if self.features_list() is not None:
|
||||
features = [self.features_list()[i] for i in index]
|
||||
if normals_list is not None:
|
||||
normals = [normals_list[i] for i in index]
|
||||
if features_list is not None:
|
||||
features = [features_list[i] for i in index]
|
||||
elif isinstance(index, torch.Tensor):
|
||||
if index.dim() != 1 or index.dtype.is_floating_point:
|
||||
raise IndexError(index)
|
||||
@ -351,10 +355,10 @@ class Pointclouds:
|
||||
index = index.squeeze(1) if index.numel() > 0 else index
|
||||
index = index.tolist()
|
||||
points = [self.points_list()[i] for i in index]
|
||||
if self.normals_list() is not None:
|
||||
normals = [self.normals_list()[i] for i in index]
|
||||
if self.features_list() is not None:
|
||||
features = [self.features_list()[i] for i in index]
|
||||
if normals_list is not None:
|
||||
normals = [normals_list[i] for i in index]
|
||||
if features_list is not None:
|
||||
features = [features_list[i] for i in index]
|
||||
else:
|
||||
raise IndexError(index)
|
||||
|
||||
@ -369,7 +373,7 @@ class Pointclouds:
|
||||
"""
|
||||
return self._N == 0 or self.valid.eq(False).all()
|
||||
|
||||
def points_list(self):
|
||||
def points_list(self) -> List[torch.Tensor]:
|
||||
"""
|
||||
Get the list representation of the points.
|
||||
|
||||
@ -388,9 +392,10 @@ class Pointclouds:
|
||||
self._points_list = points_list
|
||||
return self._points_list
|
||||
|
||||
def normals_list(self):
|
||||
def normals_list(self) -> Optional[List[torch.Tensor]]:
|
||||
"""
|
||||
Get the list representation of the normals.
|
||||
Get the list representation of the normals,
|
||||
or None if there are no normals.
|
||||
|
||||
Returns:
|
||||
list of tensors of normals of shape (P_n, 3).
|
||||
@ -404,9 +409,10 @@ class Pointclouds:
|
||||
)
|
||||
return self._normals_list
|
||||
|
||||
def features_list(self):
|
||||
def features_list(self) -> Optional[List[torch.Tensor]]:
|
||||
"""
|
||||
Get the list representation of the features.
|
||||
Get the list representation of the features,
|
||||
or None if there are no features.
|
||||
|
||||
Returns:
|
||||
list of tensors of features of shape (P_n, C).
|
||||
@ -420,7 +426,7 @@ class Pointclouds:
|
||||
)
|
||||
return self._features_list
|
||||
|
||||
def points_packed(self):
|
||||
def points_packed(self) -> torch.Tensor:
|
||||
"""
|
||||
Get the packed representation of the points.
|
||||
|
||||
@ -430,22 +436,24 @@ class Pointclouds:
|
||||
self._compute_packed()
|
||||
return self._points_packed
|
||||
|
||||
def normals_packed(self):
|
||||
def normals_packed(self) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Get the packed representation of the normals.
|
||||
|
||||
Returns:
|
||||
tensor of normals of shape (sum(P_n), 3).
|
||||
tensor of normals of shape (sum(P_n), 3),
|
||||
or None if there are no normals.
|
||||
"""
|
||||
self._compute_packed()
|
||||
return self._normals_packed
|
||||
|
||||
def features_packed(self):
|
||||
def features_packed(self) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Get the packed representation of the features.
|
||||
|
||||
Returns:
|
||||
tensor of features of shape (sum(P_n), C).
|
||||
tensor of features of shape (sum(P_n), C),
|
||||
or None if there are no features
|
||||
"""
|
||||
self._compute_packed()
|
||||
return self._features_packed
|
||||
@ -483,7 +491,7 @@ class Pointclouds:
|
||||
"""
|
||||
return self._num_points_per_cloud
|
||||
|
||||
def points_padded(self):
|
||||
def points_padded(self) -> torch.Tensor:
|
||||
"""
|
||||
Get the padded representation of the points.
|
||||
|
||||
@ -493,9 +501,10 @@ class Pointclouds:
|
||||
self._compute_padded()
|
||||
return self._points_padded
|
||||
|
||||
def normals_padded(self):
|
||||
def normals_padded(self) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Get the padded representation of the normals.
|
||||
Get the padded representation of the normals,
|
||||
or None if there are no normals.
|
||||
|
||||
Returns:
|
||||
tensor of normals of shape (N, max(P_n), 3).
|
||||
@ -503,9 +512,10 @@ class Pointclouds:
|
||||
self._compute_padded()
|
||||
return self._normals_padded
|
||||
|
||||
def features_padded(self):
|
||||
def features_padded(self) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Get the padded representation of the features.
|
||||
Get the padded representation of the features,
|
||||
or None if there are no features.
|
||||
|
||||
Returns:
|
||||
tensor of features of shape (N, max(P_n), 3).
|
||||
@ -562,16 +572,18 @@ class Pointclouds:
|
||||
pad_value=0.0,
|
||||
equisized=self.equisized,
|
||||
)
|
||||
if self.normals_list() is not None:
|
||||
normals_list = self.normals_list()
|
||||
if normals_list is not None:
|
||||
self._normals_padded = struct_utils.list_to_padded(
|
||||
self.normals_list(),
|
||||
normals_list,
|
||||
(self._P, 3),
|
||||
pad_value=0.0,
|
||||
equisized=self.equisized,
|
||||
)
|
||||
if self.features_list() is not None:
|
||||
features_list = self.features_list()
|
||||
if features_list is not None:
|
||||
self._features_padded = struct_utils.list_to_padded(
|
||||
self.features_list(),
|
||||
features_list,
|
||||
(self._P, self._C),
|
||||
pad_value=0.0,
|
||||
equisized=self.equisized,
|
||||
@ -772,10 +784,12 @@ class Pointclouds:
|
||||
)
|
||||
points = self.points_list()[index]
|
||||
normals, features = None, None
|
||||
if self.normals_list() is not None:
|
||||
normals = self.normals_list()[index]
|
||||
if self.features_list() is not None:
|
||||
features = self.features_list()[index]
|
||||
normals_list = self.normals_list()
|
||||
if normals_list is not None:
|
||||
normals = normals_list[index]
|
||||
features_list = self.features_list()
|
||||
if features_list is not None:
|
||||
features = features_list[index]
|
||||
return points, normals, features
|
||||
|
||||
# TODO(nikhilar) Move function to a utils file.
|
||||
@ -1022,13 +1036,15 @@ class Pointclouds:
|
||||
new_points_list, new_normals_list, new_features_list = [], None, None
|
||||
for points in self.points_list():
|
||||
new_points_list.extend(points.clone() for _ in range(N))
|
||||
if self.normals_list() is not None:
|
||||
normals_list = self.normals_list()
|
||||
if normals_list is not None:
|
||||
new_normals_list = []
|
||||
for normals in self.normals_list():
|
||||
for normals in normals_list:
|
||||
new_normals_list.extend(normals.clone() for _ in range(N))
|
||||
if self.features_list() is not None:
|
||||
features_list = self.features_list()
|
||||
if features_list is not None:
|
||||
new_features_list = []
|
||||
for features in self.features_list():
|
||||
for features in features_list:
|
||||
new_features_list.extend(features.clone() for _ in range(N))
|
||||
return self.__class__(
|
||||
points=new_points_list, normals=new_normals_list, features=new_features_list
|
||||
|
Loading…
x
Reference in New Issue
Block a user