Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)

Re-sync with internal repository
commit 3d3b2fdc46, parent 2480723adf
pytorch3d/renderer/points/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
pytorch3d/renderer/points/rasterize_points.py (new file, 227 lines)
@@ -0,0 +1,227 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import Optional

import torch

from pytorch3d import _C
from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc


# TODO(jcjohns): Support non-square images
def rasterize_points(
    pointclouds,
    image_size: int = 256,
    radius: float = 0.01,
    points_per_pixel: int = 8,
    bin_size: Optional[int] = None,
    max_points_per_bin: Optional[int] = None,
):
    """
    Pointcloud rasterization

    Args:
        pointclouds: A Pointclouds object representing a batch of point clouds to be
            rasterized. This is a batch of N pointclouds, where each point cloud
            can have a different number of points; the coordinates of each point
            are (x, y, z). The coordinates are expected to
            be in normalized device coordinates (NDC): [-1, 1]^3 with the camera at
            (0, 0, 0); the x-axis goes from left-to-right, the y-axis goes from
            top-to-bottom, and the z-axis goes from back-to-front.
        image_size: Integer giving the resolution of the rasterized image.
        radius (Optional): Float giving the radius (in NDC units) of the disk to
            be rasterized for each point.
        points_per_pixel (Optional): We will keep track of this many points per
            pixel, returning the nearest points_per_pixel points along the z-axis.
        bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
            bin_size=0 uses naive rasterization; setting bin_size=None attempts to
            set it heuristically based on the shape of the input. This should not
            affect the output, but can affect the speed of the forward pass.
        max_points_per_bin: Only applicable when using coarse-to-fine rasterization
            (bin_size > 0); this is the maximum number of points allowed within each
            bin. If more than this many points actually fall into a bin, an error
            will be raised. This should not affect the output values, but can affect
            the memory usage in the forward pass.

    Returns:
        3-element tuple containing

        - **idx**: int32 Tensor of shape (N, image_size, image_size, points_per_pixel)
          giving the indices of the nearest points at each pixel, in ascending
          z-order. Concretely `idx[n, y, x, k] = p` means that `points[p]` is the kth
          closest point (along the z-direction) to pixel (y, x) - note that points
          represents the packed points of shape (P, 3).
          Pixels that are hit by fewer than points_per_pixel points are padded with -1.
        - **zbuf**: Tensor of shape (N, image_size, image_size, points_per_pixel)
          giving the z-coordinates of the nearest points at each pixel, sorted in
          z-order. Concretely, if `idx[n, y, x, k] = p` then
          `zbuf[n, y, x, k] = points[p, 2]`. Pixels hit by fewer than
          points_per_pixel points are padded with -1.
        - **dists2**: Tensor of shape (N, image_size, image_size, points_per_pixel)
          giving the squared Euclidean distance (in NDC units) in the x/y plane
          for each point closest to the pixel. Concretely if `idx[n, y, x, k] = p`
          then `dists[n, y, x, k]` is the squared distance between the pixel (y, x)
          and the point `(points[p, 0], points[p, 1])`. Pixels hit with fewer
          than points_per_pixel points are padded with -1.
    """
    points_packed = pointclouds.points_packed()
    cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
    num_points_per_cloud = pointclouds.num_points_per_cloud()

    if bin_size is None:
        if not points_packed.is_cuda:
            # Binned CPU rasterization is not fully implemented
            bin_size = 0
        else:
            # TODO: These heuristics are not well-thought out!
            if image_size <= 64:
                bin_size = 8
            elif image_size <= 256:
                bin_size = 16
            elif image_size <= 512:
                bin_size = 32
            elif image_size <= 1024:
                bin_size = 64

    if max_points_per_bin is None:
        max_points_per_bin = int(max(10000, points_packed.shape[0] / 5))

    # Function.apply cannot take keyword args, so we handle defaults in this
    # wrapper and call apply with positional args only.
    return _RasterizePoints.apply(
        points_packed,
        cloud_to_packed_first_idx,
        num_points_per_cloud,
        image_size,
        radius,
        points_per_pixel,
        bin_size,
        max_points_per_bin,
    )


class _RasterizePoints(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        points,  # (P, 3)
        cloud_to_packed_first_idx,
        num_points_per_cloud,
        image_size: int = 256,
        radius: float = 0.01,
        points_per_pixel: int = 8,
        bin_size: int = 0,
        max_points_per_bin: int = 0,
    ):
        # TODO: Add better error handling for when there are more than
        # max_points_per_bin in any bin.
        args = (
            points,
            cloud_to_packed_first_idx,
            num_points_per_cloud,
            image_size,
            radius,
            points_per_pixel,
            bin_size,
            max_points_per_bin,
        )
        idx, zbuf, dists = _C.rasterize_points(*args)
        ctx.save_for_backward(points, idx)
        return idx, zbuf, dists

    @staticmethod
    def backward(ctx, grad_idx, grad_zbuf, grad_dists):
        grad_points = None
        grad_cloud_to_packed_first_idx = None
        grad_num_points_per_cloud = None
        grad_image_size = None
        grad_radius = None
        grad_points_per_pixel = None
        grad_bin_size = None
        grad_max_points_per_bin = None
        points, idx = ctx.saved_tensors
        args = (points, idx, grad_zbuf, grad_dists)
        grad_points = _C.rasterize_points_backward(*args)
        grads = (
            grad_points,
            grad_cloud_to_packed_first_idx,
            grad_num_points_per_cloud,
            grad_image_size,
            grad_radius,
            grad_points_per_pixel,
            grad_bin_size,
            grad_max_points_per_bin,
        )
        return grads


def rasterize_points_python(
    pointclouds,
    image_size: int = 256,
    radius: float = 0.01,
    points_per_pixel: int = 8,
):
    """
    Naive pure PyTorch implementation of pointcloud rasterization.

    Inputs / Outputs: Same as above
    """
    N = len(pointclouds)
    S, K = image_size, points_per_pixel
    device = pointclouds.device

    points_packed = pointclouds.points_packed()
    cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
    num_points_per_cloud = pointclouds.num_points_per_cloud()

    # Initialize output tensors.
    point_idxs = torch.full(
        (N, S, S, K), fill_value=-1, dtype=torch.int32, device=device
    )
    zbuf = torch.full(
        (N, S, S, K), fill_value=-1, dtype=torch.float32, device=device
    )
    pix_dists = torch.full(
        (N, S, S, K), fill_value=-1, dtype=torch.float32, device=device
    )

    # NDC is from [-1, 1]. Get pixel size using specified image size.
    radius2 = radius * radius

    # Iterate through the batch of point clouds.
    for n in range(N):
        point_start_idx = cloud_to_packed_first_idx[n]
        point_stop_idx = point_start_idx + num_points_per_cloud[n]

        # Iterate through the horizontal lines of the image from top to bottom.
        for yi in range(S):
            # Y coordinate of one end of the image. Reverse the ordering
            # of yi so that +Y is pointing up in the image.
            yfix = S - 1 - yi
            yf = pix_to_ndc(yfix, S)

            # Iterate through pixels on this horizontal line, left to right.
            for xi in range(S):
                # X coordinate of one end of the image. Reverse the ordering
                # of xi so that +X is pointing to the left in the image.
                xfix = S - 1 - xi
                xf = pix_to_ndc(xfix, S)

                top_k_points = []
                # Check whether each point in the batch affects this pixel.
                for p in range(point_start_idx, point_stop_idx):
                    px, py, pz = points_packed[p, :]
                    if pz < 0:
                        continue
                    dx = px - xf
                    dy = py - yf
                    dist2 = dx * dx + dy * dy
                    if dist2 < radius2:
                        top_k_points.append((pz, p, dist2))
                top_k_points.sort()
                if len(top_k_points) > K:
                    top_k_points = top_k_points[:K]
                for k, (pz, p, dist2) in enumerate(top_k_points):
                    zbuf[n, yi, xi, k] = pz
                    point_idxs[n, yi, xi, k] = p
                    pix_dists[n, yi, xi, k] = dist2
    return point_idxs, zbuf, pix_dists
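The example below is an editorial sketch, not part of the commit: it exercises the naive (bin_size=0) path on CPU, assuming the compiled pytorch3d._C extension is available, and checks the output shapes documented above.

# Sketch only (not part of this commit). Assumes the pytorch3d._C extension
# is built; uses the naive rasterization path (bin_size=0) on CPU.
import torch
from pytorch3d.renderer.points.rasterize_points import rasterize_points
from pytorch3d.structures.pointclouds import Pointclouds

points = torch.randn(2, 50, 3)  # batch of 2 clouds, 50 points each, roughly in NDC
pointclouds = Pointclouds(points=points)
idx, zbuf, dists2 = rasterize_points(
    pointclouds, image_size=64, radius=0.05, points_per_pixel=4, bin_size=0
)
# Each output is (N, image_size, image_size, points_per_pixel); pixels hit by
# fewer than points_per_pixel points are padded with -1.
assert idx.shape == (2, 64, 64, 4)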
pytorch3d/structures/pointclouds.py (new file, 992 lines)
@@ -0,0 +1,992 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import torch

from . import utils as struct_utils


class Pointclouds(object):
    """
    This class provides functions for working with batches of 3d point clouds,
    and converting between representations.

    Within Pointclouds, there are three different representations of the data.

    List
      - only used for input as a starting point to convert to other representations.
    Padded
      - has specific batch dimension.
    Packed
      - no batch dimension.
      - has auxiliary variables used to index into the padded representation.

    Example

    Input list of points = [[P_1], [P_2], ... , [P_N]]
    where P_1, ... , P_N are the number of points in each cloud and N is the
    number of clouds.

    # SPHINX IGNORE
    List                       | Padded                  | Packed
    ---------------------------|-------------------------|------------------------
    [[P_1], ... , [P_N]]       | size = (N, max(P_n), 3) | size = (sum(P_n), 3)
                               |                         |
    Example for locations      |                         |
    or colors:                 |                         |
                               |                         |
    P_1 = 3, P_2 = 4, P_3 = 5  | size = (3, 5, 3)        | size = (12, 3)
                               |                         |
    List([                     | tensor([                | tensor([
      [                        |   [                     |   [0.1, 0.3, 0.5],
        [0.1, 0.3, 0.5],       |     [0.1, 0.3, 0.5],    |   [0.5, 0.2, 0.1],
        [0.5, 0.2, 0.1],       |     [0.5, 0.2, 0.1],    |   [0.6, 0.8, 0.7],
        [0.6, 0.8, 0.7]        |     [0.6, 0.8, 0.7],    |   [0.1, 0.3, 0.3],
      ],                       |     [0, 0, 0],          |   [0.6, 0.7, 0.8],
      [                        |     [0, 0, 0]           |   [0.2, 0.3, 0.4],
        [0.1, 0.3, 0.3],       |   ],                    |   [0.1, 0.5, 0.3],
        [0.6, 0.7, 0.8],       |   [                     |   [0.7, 0.3, 0.6],
        [0.2, 0.3, 0.4],       |     [0.1, 0.3, 0.3],    |   [0.2, 0.4, 0.8],
        [0.1, 0.5, 0.3]        |     [0.6, 0.7, 0.8],    |   [0.9, 0.5, 0.2],
      ],                       |     [0.2, 0.3, 0.4],    |   [0.2, 0.3, 0.4],
      [                        |     [0.1, 0.5, 0.3],    |   [0.9, 0.3, 0.8],
        [0.7, 0.3, 0.6],       |     [0, 0, 0]           | ])
        [0.2, 0.4, 0.8],       |   ],                    |
        [0.9, 0.5, 0.2],       |   [                     |
        [0.2, 0.3, 0.4],       |     [0.7, 0.3, 0.6],    |
        [0.9, 0.3, 0.8],       |     [0.2, 0.4, 0.8],    |
      ]                        |     [0.9, 0.5, 0.2],    |
    ])                         |     [0.2, 0.3, 0.4],    |
                               |     [0.9, 0.3, 0.8]     |
                               |   ]                     |
                               | ])                      |
    -----------------------------------------------------------------------------

    Auxiliary variables for packed representation

    Name                       | Size              | Example from above
    ---------------------------|-------------------|-----------------------
                               |                   |
    packed_to_cloud_idx        | size = (sum(P_n)) | tensor([
                               |                   |    0, 0, 0, 1, 1, 1,
                               |                   |    1, 2, 2, 2, 2, 2
                               |                   |  ])
                               |                   | size = (12)
                               |                   |
    cloud_to_packed_first_idx  | size = (N)        | tensor([0, 3, 7])
                               |                   | size = (3)
                               |                   |
    num_points_per_cloud       | size = (N)        | tensor([3, 4, 5])
                               |                   | size = (3)
                               |                   |
    padded_to_packed_idx       | size = (sum(P_n)) | tensor([
                               |                   |    0, 1, 2, 5, 6, 7,
                               |                   |    8, 10, 11, 12, 13,
                               |                   |    14
                               |                   |  ])
                               |                   | size = (12)
    -----------------------------------------------------------------------------
    # SPHINX IGNORE
    """

    _INTERNAL_TENSORS = [
        "_points_packed",
        "_points_padded",
        "_normals_packed",
        "_normals_padded",
        "_features_packed",
        "_features_padded",
        "_packed_to_cloud_idx",
        "_cloud_to_packed_first_idx",
        "_num_points_per_cloud",
        "_padded_to_packed_idx",
        "valid",
        "equisized",
    ]

    def __init__(self, points, normals=None, features=None):
        """
        Args:
            points:
                Can be either

                - List where each element is a tensor of shape (num_points, 3)
                  containing the (x, y, z) coordinates of each point.
                - Padded float tensor with shape (num_clouds, num_points, 3).
            normals:
                Can be either

                - List where each element is a tensor of shape (num_points, 3)
                  containing the normal vector for each point.
                - Padded float tensor of shape (num_clouds, num_points, 3).
            features:
                Can be either

                - List where each element is a tensor of shape (num_points, C)
                  containing the features for the points in the cloud.
                - Padded float tensor of shape (num_clouds, num_points, C).
                where C is the number of channels in the features.
                For example 3 for RGB color.

        Refer to comments above for descriptions of List and Padded
        representations.
        """
        self.device = None

        # Indicates whether the clouds in the list/batch have the same number
        # of points.
        self.equisized = False

        # Boolean indicator for each cloud in the batch.
        # True if cloud has a non-zero number of points, False otherwise.
        self.valid = None

        self._N = 0  # batch size (number of clouds)
        self._P = 0  # (max) number of points per cloud
        self._C = None  # number of channels in the features

        # List of Tensors of points and features.
        self._points_list = None
        self._normals_list = None
        self._features_list = None

        # Number of points per cloud.
        self._num_points_per_cloud = None  # N

        # Packed representation.
        self._points_packed = None  # (sum(P_n), 3)
        self._normals_packed = None  # (sum(P_n), 3)
        self._features_packed = None  # (sum(P_n), C)

        self._packed_to_cloud_idx = None  # sum(P_n)

        # Index of each cloud's first point in the packed points.
        # Assumes packing is sequential.
        self._cloud_to_packed_first_idx = None  # N

        # Padded representation.
        self._points_padded = None  # (N, max(P_n), 3)
        self._normals_padded = None  # (N, max(P_n), 3)
        self._features_padded = None  # (N, max(P_n), C)

        # Index to convert points from flattened padded to packed.
        self._padded_to_packed_idx = None  # N * max_P

        # Identify type of points.
        if isinstance(points, list):
            self._points_list = points
            self._N = len(self._points_list)
            self.device = torch.device("cpu")
            self.valid = torch.zeros(
                (self._N,), dtype=torch.bool, device=self.device
            )
            self._num_points_per_cloud = []

            if self._N > 0:
                for p in self._points_list:
                    if len(p) > 0 and (p.dim() != 2 or p.shape[1] != 3):
                        raise ValueError(
                            "Clouds in list must be of shape Px3 or empty"
                        )

                self.device = self._points_list[0].device
                num_points_per_cloud = torch.tensor(
                    [len(p) for p in self._points_list], device=self.device
                )
                self._P = num_points_per_cloud.max()
                self.valid = torch.tensor(
                    [len(p) > 0 for p in self._points_list],
                    dtype=torch.bool,
                    device=self.device,
                )

                if len(num_points_per_cloud.unique()) == 1:
                    self.equisized = True
                self._num_points_per_cloud = num_points_per_cloud

        elif torch.is_tensor(points):
            if points.dim() != 3 or points.shape[2] != 3:
                raise ValueError("Points tensor has incorrect dimensions.")
            self._points_padded = points
            self._N = self._points_padded.shape[0]
            self._P = self._points_padded.shape[1]
            self.device = self._points_padded.device
            self.valid = torch.ones(
                (self._N,), dtype=torch.bool, device=self.device
            )
            self._num_points_per_cloud = torch.tensor(
                [self._P] * self._N, device=self.device
            )
            self.equisized = True
        else:
            raise ValueError(
                "Points must be either a list or a tensor with \
                    shape (batch_size, P, 3) where P is the maximum number of \
                    points in a cloud."
            )

        # parse normals
        normals_parsed = self._parse_auxiliary_input(normals)
        self._normals_list, self._normals_padded, normals_C = normals_parsed
        if normals_C is not None and normals_C != 3:
            raise ValueError("Normals are expected to be 3-dimensional")

        # parse features
        features_parsed = self._parse_auxiliary_input(features)
        self._features_list, self._features_padded, features_C = features_parsed
        if features_C is not None:
            self._C = features_C

    def _parse_auxiliary_input(self, aux_input):
        """
        Interpret the auxiliary inputs (normals, features) given to __init__.

        Args:
            aux_input:
                Can be either

                - List where each element is a tensor of shape (num_points, C)
                  containing the features for the points in the cloud.
                - Padded float tensor of shape (num_clouds, num_points, C).
                For normals, C = 3.

        Returns:
            3-element tuple of list, padded, num_channels.
            If aux_input is a list, then padded is None. If aux_input is a
            tensor, then list is None.
        """
        if aux_input is None or self._N == 0:
            return None, None, None

        aux_input_C = None

        if isinstance(aux_input, list):
            if len(aux_input) != self._N:
                raise ValueError(
                    "Points and auxiliary input must be the same length."
                )
            for p, d in zip(self._num_points_per_cloud, aux_input):
                if p != d.shape[0]:
                    raise ValueError(
                        "A cloud has mismatched numbers of points and inputs"
                    )
                if p > 0:
                    if d.dim() != 2:
                        raise ValueError(
                            "A cloud auxiliary input must be of shape PxC or empty"
                        )
                    if aux_input_C is None:
                        aux_input_C = d.shape[1]
                    if aux_input_C != d.shape[1]:
                        raise ValueError(
                            "The clouds must have the same number of channels"
                        )
            return aux_input, None, aux_input_C
        elif torch.is_tensor(aux_input):
            if aux_input.dim() != 3:
                raise ValueError(
                    "Auxiliary input tensor has incorrect dimensions."
                )
            if self._N != aux_input.shape[0]:
                raise ValueError("Points and inputs must be the same length.")
            if self._P != aux_input.shape[1]:
                raise ValueError(
                    "Inputs tensor must have the right maximum \
                    number of points in each cloud."
                )
            aux_input_C = aux_input.shape[2]
            return None, aux_input, aux_input_C
        else:
            raise ValueError(
                "Auxiliary input must be either a list or a tensor with \
                    shape (batch_size, P, C) where P is the maximum number of \
                    points in a cloud."
            )

    def __len__(self):
        return self._N

    def __getitem__(self, index):
        """
        Args:
            index: Specifying the index of the cloud to retrieve.
                Can be an int, slice, list of ints or a boolean tensor.

        Returns:
            Pointclouds object with selected clouds. The tensors are not cloned.
        """
        normals, features = None, None
        if isinstance(index, int):
            points = [self.points_list()[index]]
            if self.normals_list() is not None:
                normals = [self.normals_list()[index]]
            if self.features_list() is not None:
                features = [self.features_list()[index]]
        elif isinstance(index, slice):
            points = self.points_list()[index]
            if self.normals_list() is not None:
                normals = self.normals_list()[index]
            if self.features_list() is not None:
                features = self.features_list()[index]
        elif isinstance(index, list):
            points = [self.points_list()[i] for i in index]
            if self.normals_list() is not None:
                normals = [self.normals_list()[i] for i in index]
            if self.features_list() is not None:
                features = [self.features_list()[i] for i in index]
        elif isinstance(index, torch.Tensor):
            if index.dim() != 1 or index.dtype.is_floating_point:
                raise IndexError(index)
            # NOTE consider converting index to cpu for efficiency
            if index.dtype == torch.bool:
                # advanced indexing on a single dimension
                index = index.nonzero()
                index = index.squeeze(1) if index.numel() > 0 else index
                index = index.tolist()
            points = [self.points_list()[i] for i in index]
            if self.normals_list() is not None:
                normals = [self.normals_list()[i] for i in index]
            if self.features_list() is not None:
                features = [self.features_list()[i] for i in index]
        else:
            raise IndexError(index)

        return Pointclouds(points=points, normals=normals, features=features)

    def isempty(self) -> bool:
        """
        Checks whether any cloud is valid.

        Returns:
            bool indicating whether there is any data.
        """
        return self._N == 0 or self.valid.eq(False).all()

    def points_list(self):
        """
        Get the list representation of the points.

        Returns:
            list of tensors of points of shape (P_n, 3).
        """
        if self._points_list is None:
            assert (
                self._points_padded is not None
            ), "points_padded is required to compute points_list."
            points_list = []
            for i in range(self._N):
                points_list.append(
                    self._points_padded[i, : self.num_points_per_cloud()[i]]
                )
            self._points_list = points_list
        return self._points_list

    def normals_list(self):
        """
        Get the list representation of the normals.

        Returns:
            list of tensors of normals of shape (P_n, 3).
        """
        if self._normals_list is None:
            if self._normals_padded is None:
                # No normals provided so return None
                return None
            self._normals_list = []
            for i in range(self._N):
                self._normals_list.append(
                    self._normals_padded[i, : self.num_points_per_cloud()[i]]
                )
        return self._normals_list

    def features_list(self):
        """
        Get the list representation of the features.

        Returns:
            list of tensors of features of shape (P_n, C).
        """
        if self._features_list is None:
            if self._features_padded is None:
                # No features provided so return None
                return None
            self._features_list = []
            for i in range(self._N):
                self._features_list.append(
                    self._features_padded[i, : self.num_points_per_cloud()[i]]
                )
        return self._features_list

    def points_packed(self):
        """
        Get the packed representation of the points.

        Returns:
            tensor of points of shape (sum(P_n), 3).
        """
        self._compute_packed()
        return self._points_packed

    def normals_packed(self):
        """
        Get the packed representation of the normals.

        Returns:
            tensor of normals of shape (sum(P_n), 3).
        """
        self._compute_packed()
        return self._normals_packed

    def features_packed(self):
        """
        Get the packed representation of the features.

        Returns:
            tensor of features of shape (sum(P_n), C).
        """
        self._compute_packed()
        return self._features_packed

    def packed_to_cloud_idx(self):
        """
        Return a 1D tensor x with length equal to the total number of points.
        packed_to_cloud_idx()[i] gives the index of the cloud which contains
        points_packed()[i].

        Returns:
            1D tensor of indices.
        """
        self._compute_packed()
        return self._packed_to_cloud_idx

    def cloud_to_packed_first_idx(self):
        """
        Return a 1D tensor x with length equal to the number of clouds such that
        the first point of the ith cloud is points_packed[x[i]].

        Returns:
            1D tensor of indices of first items.
        """
        self._compute_packed()
        return self._cloud_to_packed_first_idx

    def num_points_per_cloud(self):
        """
        Return a 1D tensor x with length equal to the number of clouds giving
        the number of points in each cloud.

        Returns:
            1D tensor of sizes.
        """
        return self._num_points_per_cloud

    def points_padded(self):
        """
        Get the padded representation of the points.

        Returns:
            tensor of points of shape (N, max(P_n), 3).
        """
        self._compute_padded()
        return self._points_padded

    def normals_padded(self):
        """
        Get the padded representation of the normals.

        Returns:
            tensor of normals of shape (N, max(P_n), 3).
        """
        self._compute_padded()
        return self._normals_padded

    def features_padded(self):
        """
        Get the padded representation of the features.

        Returns:
            tensor of features of shape (N, max(P_n), C).
        """
        self._compute_padded()
        return self._features_padded

    def padded_to_packed_idx(self):
        """
        Return a 1D tensor x with length equal to the total number of points
        such that points_packed()[i] is element x[i] of the flattened padded
        representation.
        The packed representation can be calculated as follows.

        .. code-block:: python

            p = points_padded().reshape(-1, 3)
            points_packed = p[x]

        Returns:
            1D tensor of indices.
        """
        self._compute_packed()
        if self._padded_to_packed_idx is not None:
            return self._padded_to_packed_idx
        if self._N == 0:
            self._padded_to_packed_idx = []
        else:
            self._padded_to_packed_idx = torch.cat(
                [
                    torch.arange(v, dtype=torch.int64, device=self.device)
                    + i * self._P
                    for (i, v) in enumerate(self._num_points_per_cloud)
                ],
                dim=0,
            )
        return self._padded_to_packed_idx

    def _compute_padded(self, refresh: bool = False):
        """
        Computes the padded version from points_list, normals_list and features_list.

        Args:
            refresh: whether to force the recalculation.
        """
        if not (refresh or self._points_padded is None):
            return

        self._normals_padded, self._features_padded = None, None
        if self.isempty():
            self._points_padded = torch.zeros(
                (self._N, 0, 3), device=self.device
            )
        else:
            self._points_padded = struct_utils.list_to_padded(
                self.points_list(),
                (self._P, 3),
                pad_value=0.0,
                equisized=self.equisized,
            )
            if self.normals_list() is not None:
                self._normals_padded = struct_utils.list_to_padded(
                    self.normals_list(),
                    (self._P, 3),
                    pad_value=0.0,
                    equisized=self.equisized,
                )
            if self.features_list() is not None:
                self._features_padded = struct_utils.list_to_padded(
                    self.features_list(),
                    (self._P, self._C),
                    pad_value=0.0,
                    equisized=self.equisized,
                )

    # TODO(nikhilar) Improve performance of _compute_packed.
    def _compute_packed(self, refresh: bool = False):
        """
        Computes the packed version from points_list, normals_list and
        features_list and sets the values of auxiliary tensors.

        Args:
            refresh: Set to True to force recomputation of packed
                representations. Default: False.
        """

        if not (
            refresh
            or any(
                v is None
                for v in [
                    self._points_packed,
                    self._packed_to_cloud_idx,
                    self._cloud_to_packed_first_idx,
                ]
            )
        ):
            return

        # Packed can be calculated from padded or list, so can call the
        # accessor function for the lists.
        points_list = self.points_list()
        normals_list = self.normals_list()
        features_list = self.features_list()
        if self.isempty():
            self._points_packed = torch.zeros(
                (0, 3), dtype=torch.float32, device=self.device
            )
            self._packed_to_cloud_idx = torch.zeros(
                (0,), dtype=torch.int64, device=self.device
            )
            self._cloud_to_packed_first_idx = torch.zeros(
                (0,), dtype=torch.int64, device=self.device
            )
            self._normals_packed = None
            self._features_packed = None
            return

        points_list_to_packed = struct_utils.list_to_packed(points_list)
        self._points_packed = points_list_to_packed[0]
        if not torch.allclose(
            self._num_points_per_cloud, points_list_to_packed[1]
        ):
            raise ValueError("Inconsistent list to packed conversion")
        self._cloud_to_packed_first_idx = points_list_to_packed[2]
        self._packed_to_cloud_idx = points_list_to_packed[3]

        self._normals_packed, self._features_packed = None, None
        if normals_list is not None:
            normals_list_to_packed = struct_utils.list_to_packed(normals_list)
            self._normals_packed = normals_list_to_packed[0]

        if features_list is not None:
            features_list_to_packed = struct_utils.list_to_packed(features_list)
            self._features_packed = features_list_to_packed[0]

    def clone(self):
        """
        Deep copy of Pointclouds object. All internal tensors are cloned
        individually.

        Returns:
            new Pointclouds object.
        """
        # Instantiate the new pointcloud with the representation which is not
        # None (either list or tensor) to save compute.
        new_points, new_normals, new_features = None, None, None
        if self._points_list is not None:
            new_points = [v.clone() for v in self.points_list()]
            normals_list = self.normals_list()
            features_list = self.features_list()
            if normals_list is not None:
                new_normals = [n.clone() for n in normals_list]
            if features_list is not None:
                new_features = [f.clone() for f in features_list]
        elif self._points_padded is not None:
            new_points = self.points_padded().clone()
            normals_padded = self.normals_padded()
            features_padded = self.features_padded()
            if normals_padded is not None:
                new_normals = self.normals_padded().clone()
            if features_padded is not None:
                new_features = self.features_padded().clone()
        other = Pointclouds(
            points=new_points, normals=new_normals, features=new_features
        )
        for k in self._INTERNAL_TENSORS:
            v = getattr(self, k)
            if torch.is_tensor(v):
                setattr(other, k, v.clone())
        return other

    def to(self, device, copy: bool = False):
        """
        Match functionality of torch.Tensor.to().
        If copy = True or the self Tensor is on a different device, the
        returned tensor is a copy of self with the desired torch.device.
        If copy = False and the self Tensor already has the correct torch.device,
        then self is returned.

        Args:
            device: Device id for the new tensor.
            copy: Boolean indicator whether or not to clone self. Default False.

        Returns:
            Pointclouds object.
        """
        if not copy and self.device == device:
            return self
        other = self.clone()
        if self.device != device:
            other.device = device
            if other._N > 0:
                other._points_list = [v.to(device) for v in other.points_list()]
                if other._normals_list is not None:
                    other._normals_list = [
                        n.to(device) for n in other.normals_list()
                    ]
                if other._features_list is not None:
                    other._features_list = [
                        f.to(device) for f in other.features_list()
                    ]
            for k in self._INTERNAL_TENSORS:
                v = getattr(self, k)
                if torch.is_tensor(v):
                    setattr(other, k, v.to(device))
        return other

    def cpu(self):
        return self.to(torch.device("cpu"))

    def cuda(self):
        return self.to(torch.device("cuda"))

    def get_cloud(self, index: int):
        """
        Get tensors for a single cloud from the list representation.

        Args:
            index: Integer in the range [0, N).

        Returns:
            points: Tensor of shape (P, 3).
            normals: Tensor of shape (P, 3).
            features: Tensor of shape (P, C).
        """
        if not isinstance(index, int):
            raise ValueError("Cloud index must be an integer.")
        if index < 0 or index >= self._N:
            raise ValueError(
                "Cloud index must be in the range [0, N) where \
                    N is the number of clouds in the batch."
            )
        points = self.points_list()[index]
        normals, features = None, None
        if self.normals_list() is not None:
            normals = self.normals_list()[index]
        if self.features_list() is not None:
            features = self.features_list()[index]
        return points, normals, features

    # TODO(nikhilar) Move function to a utils file.
    def split(self, split_sizes: list):
        """
        Splits Pointclouds object of size N into a list of Pointclouds objects
        of size len(split_sizes), where the i-th Pointclouds object is of size
        split_sizes[i]. Similar to torch.split().

        Args:
            split_sizes: List of integer sizes of Pointclouds objects to be
                returned.

        Returns:
            list[Pointclouds].
        """
        if not all(isinstance(x, int) for x in split_sizes):
            raise ValueError("Value of split_sizes must be a list of integers.")
        cloudlist = []
        curi = 0
        for i in split_sizes:
            cloudlist.append(self[curi : curi + i])
            curi += i
        return cloudlist

    def offset_(self, offsets_packed):
        """
        Translate the point clouds by an offset. In place operation.

        Args:
            offsets_packed: A Tensor of the same shape as self.points_packed
                giving offsets to be added to all points.

        Returns:
            self.
        """
        points_packed = self.points_packed()
        if offsets_packed.shape != points_packed.shape:
            raise ValueError("Offsets must have dimension (all_p, 3).")
        self._points_packed = points_packed + offsets_packed
        new_points_list = list(
            self._points_packed.split(self.num_points_per_cloud().tolist(), 0)
        )
        # Note that since _compute_packed() has been executed, points_list
        # cannot be None even if not provided during construction.
        self._points_list = new_points_list
        if self._points_padded is not None:
            for i, points in enumerate(new_points_list):
                if len(points) > 0:
                    self._points_padded[i, : points.shape[0], :] = points
        return self

    # TODO(nikhilar) Move out of place operator to a utils file.
    def offset(self, offsets_packed):
        """
        Out of place offset.

        Args:
            offsets_packed: A Tensor of the same shape as self.points_packed
                giving offsets to be added to all points.

        Returns:
            new Pointclouds object.
        """
        new_clouds = self.clone()
        return new_clouds.offset_(offsets_packed)

    def scale_(self, scale):
        """
        Multiply the coordinates of this object by a scalar value.
        - i.e. enlarge/dilate
        In place operation.

        Args:
            scale: A scalar, or a Tensor of shape (N,).

        Returns:
            self.
        """
        if not torch.is_tensor(scale):
            # torch.full expects a size tuple, not an int, as its first argument.
            scale = torch.full((len(self),), scale)
        new_points_list = []
        points_list = self.points_list()
        for i, old_points in enumerate(points_list):
            new_points_list.append(scale[i] * old_points)
        self._points_list = new_points_list
        if self._points_packed is not None:
            self._points_packed = torch.cat(new_points_list, dim=0)
        if self._points_padded is not None:
            for i, points in enumerate(new_points_list):
                if len(points) > 0:
                    self._points_padded[i, : points.shape[0], :] = points
        return self

    def scale(self, scale):
        """
        Out of place scale_.

        Args:
            scale: A scalar, or a Tensor of shape (N,).

        Returns:
            new Pointclouds object.
        """
        new_clouds = self.clone()
        return new_clouds.scale_(scale)

    # TODO(nikhilar) Move function to utils file.
    def get_bounding_boxes(self):
        """
        Compute an axis-aligned bounding box for each cloud.

        Returns:
            bboxes: Tensor of shape (N, 3, 2) where bbox[i, j] gives the
            min and max values of cloud i along the jth coordinate axis.
        """
        all_mins, all_maxes = [], []
        for points in self.points_list():
            cur_mins = points.min(dim=0)[0]  # (3,)
            cur_maxes = points.max(dim=0)[0]  # (3,)
            all_mins.append(cur_mins)
            all_maxes.append(cur_maxes)
        all_mins = torch.stack(all_mins, dim=0)  # (N, 3)
        all_maxes = torch.stack(all_maxes, dim=0)  # (N, 3)
        bboxes = torch.stack([all_mins, all_maxes], dim=2)
        return bboxes

    def extend(self, N: int):
        """
        Create new Pointclouds which contains each cloud N times.

        Args:
            N: number of new copies of each cloud.

        Returns:
            new Pointclouds object.
        """
        if not isinstance(N, int):
            raise ValueError("N must be an integer.")
        if N <= 0:
            raise ValueError("N must be > 0.")

        new_points_list, new_normals_list, new_features_list = [], None, None
        for points in self.points_list():
            new_points_list.extend(points.clone() for _ in range(N))
        if self.normals_list() is not None:
            new_normals_list = []
            for normals in self.normals_list():
                new_normals_list.extend(normals.clone() for _ in range(N))
        if self.features_list() is not None:
            new_features_list = []
            for features in self.features_list():
                new_features_list.extend(features.clone() for _ in range(N))
        return Pointclouds(
            points=new_points_list,
            normals=new_normals_list,
            features=new_features_list,
        )

    def update_padded(
        self,
        new_points_padded,
        new_normals_padded=None,
        new_features_padded=None,
    ):
        """
        Returns a Pointclouds structure with updated padded tensors and copies
        of the auxiliary tensors. This function allows for an update of
        points_padded (and normals and features) without having to explicitly
        convert it to the list representation for heterogeneous batches.

        Args:
            new_points_padded: FloatTensor of shape (N, P, 3)
            new_normals_padded: (optional) FloatTensor of shape (N, P, 3)
            new_features_padded: (optional) FloatTensor of shape (N, P, C)

        Returns:
            Pointclouds with updated padded representations
        """

        def check_shapes(x, size):
            if x.shape[0] != size[0]:
                raise ValueError(
                    "new values must have the same batch dimension."
                )
            if x.shape[1] != size[1]:
                raise ValueError(
                    "new values must have the same number of points."
                )
            if size[2] is not None:
                if x.shape[2] != size[2]:
                    raise ValueError(
                        "new values must have the same number of channels."
                    )

        check_shapes(new_points_padded, [self._N, self._P, 3])
        if new_normals_padded is not None:
            check_shapes(new_normals_padded, [self._N, self._P, 3])
        if new_features_padded is not None:
            check_shapes(new_features_padded, [self._N, self._P, self._C])

        new = Pointclouds(
            points=new_points_padded,
            normals=new_normals_padded,
            features=new_features_padded,
        )

        # overwrite the equisized flag
        new.equisized = self.equisized

        # copy normals
        if new_normals_padded is None:
            # If no normals are provided, keep old ones (shallow copy)
            new._normals_list = self._normals_list
            new._normals_padded = self._normals_padded
            new._normals_packed = self._normals_packed

        # copy features
        if new_features_padded is None:
            # If no features are provided, keep old ones (shallow copy)
            new._features_list = self._features_list
            new._features_padded = self._features_padded
            new._features_packed = self._features_packed

        # copy auxiliary tensors
        copy_tensors = [
            "_packed_to_cloud_idx",
            "_cloud_to_packed_first_idx",
            "_num_points_per_cloud",
            "_padded_to_packed_idx",
            "valid",
        ]
        for k in copy_tensors:
            v = getattr(self, k)
            if torch.is_tensor(v):
                setattr(new, k, v)  # shallow copy

        # update points
        new._points_padded = new_points_padded
        assert new._points_list is None
        assert new._points_packed is None

        # update normals and features if provided
        if new_normals_padded is not None:
            new._normals_padded = new_normals_padded
            new._normals_list = None
            new._normals_packed = None
        if new_features_padded is not None:
            new._features_padded = new_features_padded
            new._features_list = None
            new._features_packed = None
        return new
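To make the three representations concrete, here is an editorial sketch (not part of the commit) using the same 3/4/5-point batch as the class docstring example:

# Sketch only (not part of this commit): list -> padded/packed round trip.
import torch
from pytorch3d.structures.pointclouds import Pointclouds

points_list = [torch.rand(n, 3) for n in (3, 4, 5)]  # heterogeneous batch
clouds = Pointclouds(points=points_list)

clouds.points_padded().shape        # (3, 5, 3): zero-padded to max(P_n)
clouds.points_packed().shape        # (12, 3): all points concatenated
clouds.packed_to_cloud_idx()        # tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])
clouds.cloud_to_packed_first_idx()  # tensor([0, 3, 7])
clouds.num_points_per_cloud()       # tensor([3, 4, 5])

# padded_to_packed_idx maps the flattened padded tensor to packed order:
p = clouds.points_padded().reshape(-1, 3)
assert torch.equal(p[clouds.padded_to_packed_idx()], clouds.points_packed())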
tests/bm_pointclouds.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


from itertools import product
from fvcore.common.benchmark import benchmark

from test_pointclouds import TestPointclouds


def bm_compute_packed_padded_pointclouds() -> None:
    kwargs_list = []
    num_clouds = [32, 128]
    max_p = [100, 10000]
    feats = [1, 10, 300]
    test_cases = product(num_clouds, max_p, feats)
    for case in test_cases:
        n, p, f = case
        kwargs_list.append({"num_clouds": n, "max_p": p, "features": f})
    benchmark(
        TestPointclouds.compute_packed_with_init,
        "COMPUTE_PACKED",
        kwargs_list,
        warmup_iters=1,
    )
    benchmark(
        TestPointclouds.compute_padded_with_init,
        "COMPUTE_PADDED",
        kwargs_list,
        warmup_iters=1,
    )
tests/bm_rasterize_points.py (new file, 52 lines)
@@ -0,0 +1,52 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import torch
from fvcore.common.benchmark import benchmark

from pytorch3d.renderer.points.rasterize_points import (
    rasterize_points,
    rasterize_points_python,
)
from pytorch3d.structures.pointclouds import Pointclouds


def _bm_python_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
    torch.manual_seed(231)
    points = torch.randn(N, P, 3)
    pointclouds = Pointclouds(points=points)
    args = (pointclouds, img_size, radius, pts_per_pxl)
    return lambda: rasterize_points_python(*args)


def _bm_cpu_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
    torch.manual_seed(231)
    points = torch.randn(N, P, 3)
    pointclouds = Pointclouds(points=points)
    args = (pointclouds, img_size, radius, pts_per_pxl)
    return lambda: rasterize_points(*args)


def _bm_cuda_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
    torch.manual_seed(231)
    points = torch.randn(N, P, 3, device=torch.device("cuda"))
    pointclouds = Pointclouds(points=points)
    args = (pointclouds, img_size, radius, pts_per_pxl)
    return lambda: rasterize_points(*args)


def bm_python_vs_cpu() -> None:
    kwargs_list = [
        {"N": 1, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
        {"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
    ]
    benchmark(
        _bm_python_with_init, "RASTERIZE_PYTHON", kwargs_list, warmup_iters=1
    )
    benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1)
    kwargs_list = [
        {"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
        {"N": 4, "P": 1024, "img_size": 128, "radius": 0.05, "pts_per_pxl": 5},
    ]
    benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1)
    benchmark(_bm_cuda_with_init, "RASTERIZE_CUDA", kwargs_list, warmup_iters=1)
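These bm_* entry points are intended for the repo's benchmark runner; as a rough sketch (assuming fvcore is installed, with a GPU available for the CUDA case), the module can also be run directly:

# Sketch only (not part of this commit): fvcore's benchmark() prints a timing
# table for each kwargs dict in kwargs_list.
if __name__ == "__main__":
    bm_python_vs_cpu()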
tests/test_pointclouds.py (new file, 978 lines; the capture below is truncated)
@@ -0,0 +1,978 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
np.random.seed(42)
|
||||
torch.manual_seed(42)
|
||||
|
||||
@staticmethod
|
||||
def init_cloud(
|
||||
num_clouds: int = 3,
|
||||
max_points: int = 100,
|
||||
channels: int = 4,
|
||||
lists_to_tensors: bool = False,
|
||||
with_normals: bool = True,
|
||||
with_features: bool = True,
|
||||
):
|
||||
"""
|
||||
Function to generate a Pointclouds object of N meshes with
|
||||
random number of points.
|
||||
|
||||
Args:
|
||||
num_clouds: Number of clouds to generate.
|
||||
channels: Number of features.
|
||||
max_points: Max number of points per cloud.
|
||||
lists_to_tensors: Determines whether the generated clouds should be
|
||||
constructed from lists (=False) or
|
||||
tensors (=True) of points/normals/features.
|
||||
with_normals: bool whether to include normals
|
||||
with_features: bool whether to include features
|
||||
|
||||
Returns:
|
||||
Pointclouds object.
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
p = torch.randint(max_points, size=(num_clouds,))
|
||||
if lists_to_tensors:
|
||||
p.fill_(p[0])
|
||||
|
||||
points_list = [
|
||||
torch.rand((i, 3), device=device, dtype=torch.float32) for i in p
|
||||
]
|
||||
normals_list, features_list = None, None
|
||||
if with_normals:
|
||||
normals_list = [
|
||||
torch.rand((i, 3), device=device, dtype=torch.float32)
|
||||
for i in p
|
||||
]
|
||||
if with_features:
|
||||
features_list = [
|
||||
torch.rand((i, channels), device=device, dtype=torch.float32)
|
||||
for i in p
|
||||
]
|
||||
|
||||
if lists_to_tensors:
|
||||
points_list = torch.stack(points_list)
|
||||
if with_normals:
|
||||
normals_list = torch.stack(normals_list)
|
||||
if with_features:
|
||||
features_list = torch.stack(features_list)
|
||||
|
||||
return Pointclouds(
|
||||
points_list, normals=normals_list, features=features_list
|
||||
)
|
||||
|
||||
def test_simple(self):
|
||||
device = torch.device("cuda:0")
|
||||
points = [
|
||||
torch.tensor(
|
||||
[[0.1, 0.3, 0.5], [0.5, 0.2, 0.1], [0.6, 0.8, 0.7]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
torch.tensor(
|
||||
[
|
||||
[0.1, 0.3, 0.3],
|
||||
[0.6, 0.7, 0.8],
|
||||
[0.2, 0.3, 0.4],
|
||||
[0.1, 0.5, 0.3],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
torch.tensor(
|
||||
[
|
||||
[0.7, 0.3, 0.6],
|
||||
[0.2, 0.4, 0.8],
|
||||
[0.9, 0.5, 0.2],
|
||||
[0.2, 0.3, 0.4],
|
||||
[0.9, 0.3, 0.8],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
]
|
||||
clouds = Pointclouds(points)
|
||||
|
||||
self.assertClose(
|
||||
(clouds.packed_to_cloud_idx()).cpu(),
|
||||
torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]),
|
||||
)
|
||||
self.assertClose(
|
||||
clouds.cloud_to_packed_first_idx().cpu(), torch.tensor([0, 3, 7])
|
||||
)
|
||||
self.assertClose(
|
||||
clouds.num_points_per_cloud().cpu(), torch.tensor([3, 4, 5])
|
||||
)
|
||||
self.assertClose(
|
||||
clouds.padded_to_packed_idx().cpu(),
|
||||
torch.tensor([0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 14]),
|
||||
)
|
||||
|
||||
def test_all_constructions(self):
|
||||
public_getters = [
|
||||
"points_list",
|
||||
"points_packed",
|
||||
"packed_to_cloud_idx",
|
||||
"cloud_to_packed_first_idx",
|
||||
"num_points_per_cloud",
|
||||
"points_padded",
|
||||
"padded_to_packed_idx",
|
||||
]
|
||||
public_normals_getters = [
|
||||
"normals_list",
|
||||
"normals_packed",
|
||||
"normals_padded",
|
||||
]
|
||||
public_features_getters = [
|
||||
"features_list",
|
||||
"features_packed",
|
||||
"features_padded",
|
||||
]
|
||||
|
||||
lengths = [3, 4, 2]
|
||||
max_len = max(lengths)
|
||||
C = 4
|
||||
|
||||
points_data = [torch.zeros((max_len, 3)).uniform_() for i in lengths]
|
||||
normals_data = [torch.zeros((max_len, 3)).uniform_() for i in lengths]
|
||||
features_data = [torch.zeros((max_len, C)).uniform_() for i in lengths]
|
||||
for length, p, n, f in zip(
|
||||
lengths, points_data, normals_data, features_data
|
||||
):
|
||||
p[length:] = 0.0
|
||||
n[length:] = 0.0
|
||||
f[length:] = 0.0
|
||||
points_list = [d[:length] for length, d in zip(lengths, points_data)]
|
||||
normals_list = [d[:length] for length, d in zip(lengths, normals_data)]
|
||||
features_list = [
|
||||
d[:length] for length, d in zip(lengths, features_data)
|
||||
]
|
||||
points_packed = torch.cat(points_data)
|
||||
normals_packed = torch.cat(normals_data)
|
||||
features_packed = torch.cat(features_data)
|
||||
test_cases_inputs = [
|
||||
("list_0_0", points_list, None, None),
|
||||
("list_1_0", points_list, normals_list, None),
|
||||
("list_0_1", points_list, None, features_list),
|
||||
("list_1_1", points_list, normals_list, features_list),
|
||||
("padded_0_0", points_data, None, None),
|
||||
("padded_1_0", points_data, normals_data, None),
|
||||
("padded_0_1", points_data, None, features_data),
|
||||
("padded_1_1", points_data, normals_data, features_data),
|
||||
("emptylist_emptylist_emptylist", [], [], []),
|
||||
]
|
||||
false_cases_inputs = [
|
||||
(
|
||||
"list_packed",
|
||||
points_list,
|
||||
normals_packed,
|
||||
features_packed,
|
||||
ValueError,
|
||||
),
|
||||
("packed_0", points_packed, None, None, ValueError),
|
||||
]
|
||||
|
||||
for name, points, normals, features in test_cases_inputs:
|
||||
with self.subTest(name=name):
|
||||
p = Pointclouds(points, normals, features)
|
||||
for method in public_getters:
|
||||
self.assertIsNotNone(getattr(p, method)())
|
||||
for method in public_normals_getters:
|
||||
if normals is None or p.isempty():
|
||||
self.assertIsNone(getattr(p, method)())
|
||||
for method in public_features_getters:
|
||||
if features is None or p.isempty():
|
||||
self.assertIsNone(getattr(p, method)())
|
||||
|
||||
for name, points, normals, features, error in false_cases_inputs:
|
||||
with self.subTest(name=name):
|
||||
with self.assertRaises(error):
|
||||
Pointclouds(points, normals, features)
|
||||
|
||||
def test_simple_random_clouds(self):
|
||||
# Define the test object either from lists or tensors.
|
||||
for with_normals in (False, True):
|
||||
for with_features in (False, True):
|
||||
for lists_to_tensors in (False, True):
|
||||
N = 10
|
||||
cloud = self.init_cloud(
|
||||
N,
|
||||
lists_to_tensors=lists_to_tensors,
|
||||
with_normals=with_normals,
|
||||
with_features=with_features,
|
||||
)
|
||||
points_list = cloud.points_list()
|
||||
normals_list = cloud.normals_list()
|
||||
features_list = cloud.features_list()
|
||||
|
||||
# Check batch calculations.
|
||||
points_padded = cloud.points_padded()
|
||||
normals_padded = cloud.normals_padded()
|
||||
features_padded = cloud.features_padded()
|
||||
points_per_cloud = cloud.num_points_per_cloud()
|
||||
|
||||
if not with_normals:
|
||||
self.assertIsNone(normals_list)
|
||||
self.assertIsNone(normals_padded)
|
||||
if not with_features:
|
||||
self.assertIsNone(features_list)
|
||||
self.assertIsNone(features_padded)
|
||||
for n in range(N):
|
||||
p = points_list[n].shape[0]
|
||||
self.assertClose(
|
||||
points_padded[n, :p, :], points_list[n]
|
||||
)
|
||||
if with_normals:
|
||||
norms = normals_list[n].shape[0]
|
||||
self.assertEqual(p, norms)
|
||||
self.assertClose(
|
||||
normals_padded[n, :p, :], normals_list[n]
|
||||
)
|
||||
if with_features:
|
||||
f = features_list[n].shape[0]
|
||||
self.assertEqual(p, f)
|
||||
self.assertClose(
|
||||
features_padded[n, :p, :], features_list[n]
|
||||
)
|
||||
if points_padded.shape[1] > p:
|
||||
self.assertTrue(points_padded[n, p:, :].eq(0).all())
|
||||
if with_features:
|
||||
self.assertTrue(
|
||||
features_padded[n, p:, :].eq(0).all()
|
||||
)
|
||||
self.assertEqual(points_per_cloud[n], p)
|
||||
|
||||
# Check compute packed.
|
||||
points_packed = cloud.points_packed()
|
||||
packed_to_cloud = cloud.packed_to_cloud_idx()
|
||||
cloud_to_packed = cloud.cloud_to_packed_first_idx()
|
||||
normals_packed = cloud.normals_packed()
|
||||
features_packed = cloud.features_packed()
|
||||
if not with_normals:
|
||||
self.assertIsNone(normals_packed)
|
||||
if not with_features:
|
||||
self.assertIsNone(features_packed)
|
||||
|
||||
cur = 0
|
||||
for n in range(N):
|
||||
p = points_list[n].shape[0]
|
||||
self.assertClose(
|
||||
points_packed[cur : cur + p, :], points_list[n]
|
||||
)
|
||||
if with_normals:
|
||||
self.assertClose(
|
||||
normals_packed[cur : cur + p, :],
|
||||
normals_list[n],
|
||||
)
|
||||
if with_features:
|
||||
self.assertClose(
|
||||
features_packed[cur : cur + p, :],
|
||||
features_list[n],
|
||||
)
|
||||
self.assertTrue(
|
||||
packed_to_cloud[cur : cur + p].eq(n).all()
|
||||
)
|
||||
self.assertTrue(cloud_to_packed[n] == cur)
|
||||
cur += p
|
||||
|
||||
def test_allempty(self):
|
||||
clouds = Pointclouds([], [])
|
||||
self.assertEqual(len(clouds), 0)
|
||||
self.assertIsNone(clouds.normals_list())
|
||||
self.assertIsNone(clouds.features_list())
|
||||
self.assertEqual(clouds.points_padded().shape[0], 0)
|
||||
self.assertIsNone(clouds.normals_padded())
|
||||
self.assertIsNone(clouds.features_padded())
|
||||
self.assertEqual(clouds.points_packed().shape[0], 0)
|
||||
self.assertIsNone(clouds.normals_packed())
|
||||
self.assertIsNone(clouds.features_packed())
|
||||
|
||||

    def test_empty(self):
        N, P, C = 10, 100, 2
        device = torch.device("cuda:0")
        points_list = []
        normals_list = []
        features_list = []
        valid = torch.randint(2, size=(N,), dtype=torch.uint8, device=device)
        for n in range(N):
            if valid[n]:
                p = torch.randint(
                    3, high=P, size=(1,), dtype=torch.int32, device=device
                )[0]
                points = torch.rand((p, 3), dtype=torch.float32, device=device)
                normals = torch.rand((p, 3), dtype=torch.float32, device=device)
                features = torch.rand((p, C), dtype=torch.float32, device=device)
            else:
                points = torch.tensor([], dtype=torch.float32, device=device)
                normals = torch.tensor([], dtype=torch.float32, device=device)
                features = torch.tensor([], dtype=torch.int64, device=device)
            points_list.append(points)
            normals_list.append(normals)
            features_list.append(features)

        for with_normals in (False, True):
            for with_features in (False, True):
                this_features, this_normals = None, None
                if with_normals:
                    this_normals = normals_list
                if with_features:
                    this_features = features_list
                clouds = Pointclouds(
                    points=points_list,
                    normals=this_normals,
                    features=this_features,
                )
                points_padded = clouds.points_padded()
                normals_padded = clouds.normals_padded()
                features_padded = clouds.features_padded()
                if not with_normals:
                    self.assertIsNone(normals_padded)
                if not with_features:
                    self.assertIsNone(features_padded)
                points_per_cloud = clouds.num_points_per_cloud()
                for n in range(N):
                    p = len(points_list[n])
                    if p > 0:
                        self.assertClose(points_padded[n, :p, :], points_list[n])
                        if with_normals:
                            self.assertClose(
                                normals_padded[n, :p, :], normals_list[n]
                            )
                        if with_features:
                            self.assertClose(
                                features_padded[n, :p, :], features_list[n]
                            )
                    if points_padded.shape[1] > p:
                        self.assertTrue(points_padded[n, p:, :].eq(0).all())
                        if with_normals:
                            self.assertTrue(normals_padded[n, p:, :].eq(0).all())
                        if with_features:
                            self.assertTrue(features_padded[n, p:, :].eq(0).all())
                    self.assertTrue(points_per_cloud[n] == p)

    def test_clone_list(self):
        N = 5
        clouds = self.init_cloud(N, 100, 5)
        for force in (False, True):
            if force:
                clouds.points_packed()

            new_clouds = clouds.clone()

            # Check cloned and original objects do not share tensors.
            self.assertSeparate(
                new_clouds.points_list()[0], clouds.points_list()[0]
            )
            self.assertSeparate(
                new_clouds.normals_list()[0], clouds.normals_list()[0]
            )
            self.assertSeparate(
                new_clouds.features_list()[0], clouds.features_list()[0]
            )
            for attrib in [
                "points_packed",
                "normals_packed",
                "features_packed",
                "points_padded",
                "normals_padded",
                "features_padded",
            ]:
                self.assertSeparate(
                    getattr(new_clouds, attrib)(), getattr(clouds, attrib)()
                )

            self.assertCloudsEqual(clouds, new_clouds)

    def test_clone_tensor(self):
        N = 5
        clouds = self.init_cloud(N, 100, 5, lists_to_tensors=True)
        for force in (False, True):
            if force:
                clouds.points_packed()

            new_clouds = clouds.clone()

            # Check cloned and original objects do not share tensors.
            self.assertSeparate(
                new_clouds.points_list()[0], clouds.points_list()[0]
            )
            self.assertSeparate(
                new_clouds.normals_list()[0], clouds.normals_list()[0]
            )
            self.assertSeparate(
                new_clouds.features_list()[0], clouds.features_list()[0]
            )
            for attrib in [
                "points_packed",
                "normals_packed",
                "features_packed",
                "points_padded",
                "normals_padded",
                "features_padded",
            ]:
                self.assertSeparate(
                    getattr(new_clouds, attrib)(), getattr(clouds, attrib)()
                )

            self.assertCloudsEqual(clouds, new_clouds)

    def assertCloudsEqual(self, cloud1, cloud2):
        N = len(cloud1)
        self.assertEqual(N, len(cloud2))

        for i in range(N):
            self.assertClose(cloud1.points_list()[i], cloud2.points_list()[i])
            self.assertClose(cloud1.normals_list()[i], cloud2.normals_list()[i])
            self.assertClose(cloud1.features_list()[i], cloud2.features_list()[i])
        has_normals = cloud1.normals_list() is not None
        self.assertTrue(has_normals == (cloud2.normals_list() is not None))
        has_features = cloud1.features_list() is not None
        self.assertTrue(has_features == (cloud2.features_list() is not None))

        # Check padded & packed.
        self.assertClose(cloud1.points_padded(), cloud2.points_padded())
        self.assertClose(cloud1.points_packed(), cloud2.points_packed())
        if has_normals:
            self.assertClose(cloud1.normals_padded(), cloud2.normals_padded())
            self.assertClose(cloud1.normals_packed(), cloud2.normals_packed())
        if has_features:
            self.assertClose(cloud1.features_padded(), cloud2.features_padded())
            self.assertClose(cloud1.features_packed(), cloud2.features_packed())
        self.assertClose(
            cloud1.packed_to_cloud_idx(), cloud2.packed_to_cloud_idx()
        )
        self.assertClose(
            cloud1.cloud_to_packed_first_idx(),
            cloud2.cloud_to_packed_first_idx(),
        )
        self.assertClose(
            cloud1.num_points_per_cloud(), cloud2.num_points_per_cloud()
        )
        self.assertClose(
            cloud1.padded_to_packed_idx(), cloud2.padded_to_packed_idx()
        )
        self.assertTrue(all(cloud1.valid == cloud2.valid))
        self.assertTrue(cloud1.equisized == cloud2.equisized)

    def test_offset(self):
        def naive_offset(clouds, offsets_packed):
            new_points_packed = clouds.points_packed() + offsets_packed
            new_points_list = list(
                new_points_packed.split(clouds.num_points_per_cloud().tolist(), 0)
            )
            return Pointclouds(
                points=new_points_list,
                normals=clouds.normals_list(),
                features=clouds.features_list(),
            )

        N = 5
        clouds = self.init_cloud(N, 100, 10)
        all_p = clouds.points_packed().size(0)
        points_per_cloud = clouds.num_points_per_cloud()
        for force in (False, True):
            if force:
                clouds._compute_packed(refresh=True)
                clouds._compute_padded()
                clouds.padded_to_packed_idx()

            deform = torch.rand(
                (all_p, 3), dtype=torch.float32, device=clouds.device
            )
            new_clouds_naive = naive_offset(clouds, deform)

            new_clouds = clouds.offset(deform)

            points_cumsum = torch.cumsum(points_per_cloud, 0).tolist()
            points_cumsum.insert(0, 0)
            for i in range(N):
                self.assertClose(
                    new_clouds.points_list()[i],
                    clouds.points_list()[i]
                    + deform[points_cumsum[i] : points_cumsum[i + 1]],
                )
                self.assertClose(
                    clouds.normals_list()[i], new_clouds_naive.normals_list()[i]
                )
                self.assertClose(
                    clouds.features_list()[i], new_clouds_naive.features_list()[i]
                )
            self.assertCloudsEqual(new_clouds, new_clouds_naive)

    def test_scale(self):
        def naive_scale(cloud, scale):
            if not torch.is_tensor(scale):
                # torch.full expects a size tuple, not a bare int.
                scale = torch.full((len(cloud),), scale)
            new_points_list = [
                scale[i] * points.clone()
                for (i, points) in enumerate(cloud.points_list())
            ]
            return Pointclouds(
                new_points_list, cloud.normals_list(), cloud.features_list()
            )

        N = 5
        clouds = self.init_cloud(N, 100, 10)
        for force in (False, True):
            if force:
                clouds._compute_packed(refresh=True)
                clouds._compute_padded()
                clouds.padded_to_packed_idx()
            scales = torch.rand(N)
            new_clouds_naive = naive_scale(clouds, scales)
            new_clouds = clouds.scale(scales)
            for i in range(N):
                self.assertClose(
                    scales[i] * clouds.points_list()[i],
                    new_clouds.points_list()[i],
                )
                self.assertClose(
                    clouds.normals_list()[i], new_clouds_naive.normals_list()[i]
                )
                self.assertClose(
                    clouds.features_list()[i], new_clouds_naive.features_list()[i]
                )
            self.assertCloudsEqual(new_clouds, new_clouds_naive)

    def test_extend_list(self):
        N = 10
        clouds = self.init_cloud(N, 100, 10)
        for force in (False, True):
            if force:
                # Force some computes to happen.
                clouds._compute_packed(refresh=True)
                clouds._compute_padded()
                clouds.padded_to_packed_idx()
            new_clouds = clouds.extend(N)
            self.assertEqual(len(clouds) * N, len(new_clouds))
            for i in range(len(clouds)):
                for n in range(N):
                    self.assertClose(
                        clouds.points_list()[i],
                        new_clouds.points_list()[i * N + n],
                    )
                    self.assertClose(
                        clouds.normals_list()[i],
                        new_clouds.normals_list()[i * N + n],
                    )
                    self.assertClose(
                        clouds.features_list()[i],
                        new_clouds.features_list()[i * N + n],
                    )
                    self.assertTrue(
                        clouds.valid[i] == new_clouds.valid[i * N + n]
                    )
            self.assertAllSeparate(
                clouds.points_list()
                + new_clouds.points_list()
                + clouds.normals_list()
                + new_clouds.normals_list()
                + clouds.features_list()
                + new_clouds.features_list()
            )
            self.assertIsNone(new_clouds._points_packed)
            self.assertIsNone(new_clouds._normals_packed)
            self.assertIsNone(new_clouds._features_packed)
            self.assertIsNone(new_clouds._points_padded)
            self.assertIsNone(new_clouds._normals_padded)
            self.assertIsNone(new_clouds._features_padded)

        with self.assertRaises(ValueError):
            clouds.extend(N=-1)

    def test_to_list(self):
        cloud = self.init_cloud(5, 100, 10)
        device = torch.device("cuda:1")

        new_cloud = cloud.to(device)
        self.assertTrue(new_cloud.device == device)
        self.assertTrue(cloud.device == torch.device("cuda:0"))
        for attrib in [
            "points_padded",
            "points_packed",
            "normals_padded",
            "normals_packed",
            "features_padded",
            "features_packed",
            "num_points_per_cloud",
            "cloud_to_packed_first_idx",
            "padded_to_packed_idx",
        ]:
            self.assertClose(
                getattr(new_cloud, attrib)().cpu(),
                getattr(cloud, attrib)().cpu(),
            )
        for i in range(len(cloud)):
            self.assertClose(
                cloud.points_list()[i].cpu(), new_cloud.points_list()[i].cpu()
            )
            self.assertClose(
                cloud.normals_list()[i].cpu(), new_cloud.normals_list()[i].cpu()
            )
            self.assertClose(
                cloud.features_list()[i].cpu(),
                new_cloud.features_list()[i].cpu(),
            )
        self.assertTrue(all(cloud.valid.cpu() == new_cloud.valid.cpu()))
        self.assertTrue(cloud.equisized == new_cloud.equisized)
        self.assertTrue(cloud._N == new_cloud._N)
        self.assertTrue(cloud._P == new_cloud._P)
        self.assertTrue(cloud._C == new_cloud._C)

    def test_to_tensor(self):
        cloud = self.init_cloud(5, 100, 10, lists_to_tensors=True)
        device = torch.device("cuda:1")

        new_cloud = cloud.to(device)
        self.assertTrue(new_cloud.device == device)
        self.assertTrue(cloud.device == torch.device("cuda:0"))
        for attrib in [
            "points_padded",
            "points_packed",
            "normals_padded",
            "normals_packed",
            "features_padded",
            "features_packed",
            "num_points_per_cloud",
            "cloud_to_packed_first_idx",
            "padded_to_packed_idx",
        ]:
            self.assertClose(
                getattr(new_cloud, attrib)().cpu(),
                getattr(cloud, attrib)().cpu(),
            )
        for i in range(len(cloud)):
            self.assertClose(
                cloud.points_list()[i].cpu(), new_cloud.points_list()[i].cpu()
            )
            self.assertClose(
                cloud.normals_list()[i].cpu(), new_cloud.normals_list()[i].cpu()
            )
            self.assertClose(
                cloud.features_list()[i].cpu(),
                new_cloud.features_list()[i].cpu(),
            )
        self.assertTrue(all(cloud.valid.cpu() == new_cloud.valid.cpu()))
        self.assertTrue(cloud.equisized == new_cloud.equisized)
        self.assertTrue(cloud._N == new_cloud._N)
        self.assertTrue(cloud._P == new_cloud._P)
        self.assertTrue(cloud._C == new_cloud._C)

    def test_split(self):
        clouds = self.init_cloud(5, 100, 10)
        split_sizes = [2, 3]
        split_clouds = clouds.split(split_sizes)
        self.assertEqual(len(split_clouds[0]), 2)
        self.assertTrue(
            split_clouds[0].points_list()
            == [clouds.get_cloud(0)[0], clouds.get_cloud(1)[0]]
        )
        self.assertEqual(len(split_clouds[1]), 3)
        self.assertTrue(
            split_clouds[1].points_list()
            == [
                clouds.get_cloud(2)[0],
                clouds.get_cloud(3)[0],
                clouds.get_cloud(4)[0],
            ]
        )

        split_sizes = [2, 0.3]
        with self.assertRaises(ValueError):
            clouds.split(split_sizes)

    def test_get_cloud(self):
        clouds = self.init_cloud(2, 100, 10)
        for i in range(len(clouds)):
            points, normals, features = clouds.get_cloud(i)
            self.assertClose(points, clouds.points_list()[i])
            self.assertClose(normals, clouds.normals_list()[i])
            self.assertClose(features, clouds.features_list()[i])

        with self.assertRaises(ValueError):
            clouds.get_cloud(5)
        with self.assertRaises(ValueError):
            clouds.get_cloud(0.2)

    def test_get_bounding_boxes(self):
        device = torch.device("cuda:0")
        points_list = []
        for size in [10]:
            points = torch.rand((size, 3), dtype=torch.float32, device=device)
            points_list.append(points)

        mins = torch.min(points, dim=0)[0]
        maxs = torch.max(points, dim=0)[0]
        bboxes_gt = torch.stack([mins, maxs], dim=1).unsqueeze(0)
        clouds = Pointclouds(points_list)
        bboxes = clouds.get_bounding_boxes()
        self.assertClose(bboxes_gt, bboxes)

    def test_padded_to_packed_idx(self):
        device = torch.device("cuda:0")
        points_list = []
        npoints = [10, 20, 30]
        for p in npoints:
            points = torch.rand((p, 3), dtype=torch.float32, device=device)
            points_list.append(points)

        clouds = Pointclouds(points_list)

        padded_to_packed_idx = clouds.padded_to_packed_idx()
        points_packed = clouds.points_packed()
        points_padded = clouds.points_padded()
        points_padded_flat = points_padded.view(-1, 3)

        self.assertClose(points_padded_flat[padded_to_packed_idx], points_packed)

        idx = padded_to_packed_idx.view(-1, 1).expand(-1, 3)
        self.assertClose(points_padded_flat.gather(0, idx), points_packed)
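
    # A minimal reference sketch of what padded_to_packed_idx must contain for
    # the test above, assuming the row-major flattening used by view(-1, 3):
    # the flattened padded row of point i in cloud n is n * max_P + i.
    @staticmethod
    def _padded_to_packed_idx_reference(npoints):
        max_p = max(npoints)
        return torch.tensor(
            [n * max_p + i for n, p in enumerate(npoints) for i in range(p)],
            dtype=torch.int64,
        )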

    def test_getitem(self):
        device = torch.device("cuda:0")
        clouds = self.init_cloud(3, 10, 100)

        def check_equal(selected, indices):
            for selectedIdx, index in indices:
                self.assertClose(
                    selected.points_list()[selectedIdx],
                    clouds.points_list()[index],
                )
                self.assertClose(
                    selected.normals_list()[selectedIdx],
                    clouds.normals_list()[index],
                )
                self.assertClose(
                    selected.features_list()[selectedIdx],
                    clouds.features_list()[index],
                )

        # int index
        index = 1
        clouds_selected = clouds[index]
        self.assertEqual(len(clouds_selected), 1)
        check_equal(clouds_selected, [(0, 1)])

        # list index
        index = [1, 2]
        clouds_selected = clouds[index]
        self.assertEqual(len(clouds_selected), len(index))
        check_equal(clouds_selected, enumerate(index))

        # slice index
        index = slice(0, 2, 1)
        clouds_selected = clouds[index]
        self.assertEqual(len(clouds_selected), 2)
        check_equal(clouds_selected, [(0, 0), (1, 1)])

        # bool tensor
        index = torch.tensor([1, 0, 1], dtype=torch.bool, device=device)
        clouds_selected = clouds[index]
        self.assertEqual(len(clouds_selected), index.sum())
        check_equal(clouds_selected, [(0, 0), (1, 2)])

        # int tensor
        index = torch.tensor([1, 2], dtype=torch.int64, device=device)
        clouds_selected = clouds[index]
        self.assertEqual(len(clouds_selected), index.numel())
        check_equal(clouds_selected, enumerate(index.tolist()))

        # invalid index
        index = torch.tensor([1, 0, 1], dtype=torch.float32, device=device)
        with self.assertRaises(IndexError):
            clouds_selected = clouds[index]
        index = 1.2
        with self.assertRaises(IndexError):
            clouds_selected = clouds[index]

    def test_update_padded(self):
        N, P, C = 5, 100, 4
        for with_normfeat in (True, False):
            for with_new_normfeat in (True, False):
                clouds = self.init_cloud(
                    N,
                    P,
                    C,
                    with_normals=with_normfeat,
                    with_features=with_normfeat,
                )

                num_points_per_cloud = clouds.num_points_per_cloud()

                # Initialize new points, normals and features.
                new_points = torch.rand(
                    clouds.points_padded().shape, device=clouds.device
                )
                new_points_list = [
                    new_points[i, : num_points_per_cloud[i]] for i in range(N)
                ]
                new_normals, new_normals_list = None, None
                new_features, new_features_list = None, None
                if with_new_normfeat:
                    new_normals = torch.rand(
                        clouds.points_padded().shape, device=clouds.device
                    )
                    new_normals_list = [
                        new_normals[i, : num_points_per_cloud[i]]
                        for i in range(N)
                    ]
                    feat_shape = [
                        clouds.points_padded().shape[0],
                        clouds.points_padded().shape[1],
                        C,
                    ]
                    new_features = torch.rand(feat_shape, device=clouds.device)
                    new_features_list = [
                        new_features[i, : num_points_per_cloud[i]]
                        for i in range(N)
                    ]

                # Update.
                new_clouds = clouds.update_padded(
                    new_points, new_normals, new_features
                )
                self.assertIsNone(new_clouds._points_list)
                self.assertIsNone(new_clouds._points_packed)

                self.assertEqual(new_clouds.equisized, clouds.equisized)
                self.assertTrue(all(new_clouds.valid == clouds.valid))

                self.assertClose(new_clouds.points_padded(), new_points)
                self.assertClose(
                    new_clouds.points_packed(), torch.cat(new_points_list)
                )
                for i in range(N):
                    self.assertClose(
                        new_clouds.points_list()[i], new_points_list[i]
                    )

                if with_new_normfeat:
                    for i in range(N):
                        self.assertClose(
                            new_clouds.normals_list()[i], new_normals_list[i]
                        )
                        self.assertClose(
                            new_clouds.features_list()[i], new_features_list[i]
                        )
                    self.assertClose(new_clouds.normals_padded(), new_normals)
                    self.assertClose(
                        new_clouds.normals_packed(), torch.cat(new_normals_list)
                    )
                    self.assertClose(new_clouds.features_padded(), new_features)
                    self.assertClose(
                        new_clouds.features_packed(),
                        torch.cat(new_features_list),
                    )
                else:
                    if with_normfeat:
                        for i in range(N):
                            self.assertClose(
                                new_clouds.normals_list()[i],
                                clouds.normals_list()[i],
                            )
                            self.assertClose(
                                new_clouds.features_list()[i],
                                clouds.features_list()[i],
                            )
                            self.assertNotSeparate(
                                new_clouds.normals_list()[i],
                                clouds.normals_list()[i],
                            )
                            self.assertNotSeparate(
                                new_clouds.features_list()[i],
                                clouds.features_list()[i],
                            )

                        self.assertClose(
                            new_clouds.normals_padded(), clouds.normals_padded()
                        )
                        self.assertClose(
                            new_clouds.normals_packed(), clouds.normals_packed()
                        )
                        self.assertClose(
                            new_clouds.features_padded(),
                            clouds.features_padded(),
                        )
                        self.assertClose(
                            new_clouds.features_packed(),
                            clouds.features_packed(),
                        )
                        self.assertNotSeparate(
                            new_clouds.normals_padded(), clouds.normals_padded()
                        )
                        self.assertNotSeparate(
                            new_clouds.features_padded(),
                            clouds.features_padded(),
                        )
                    else:
                        self.assertIsNone(new_clouds.normals_list())
                        self.assertIsNone(new_clouds.features_list())
                        self.assertIsNone(new_clouds.normals_padded())
                        self.assertIsNone(new_clouds.features_padded())
                        self.assertIsNone(new_clouds.normals_packed())
                        self.assertIsNone(new_clouds.features_packed())

                for attrib in [
                    "num_points_per_cloud",
                    "cloud_to_packed_first_idx",
                    "padded_to_packed_idx",
                ]:
                    self.assertClose(
                        getattr(new_clouds, attrib)(), getattr(clouds, attrib)()
                    )
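
    # A minimal usage sketch of update_padded, assuming only the behavior
    # asserted above: a new padded points tensor of the same shape reuses the
    # original clouds' bookkeeping (validity, sizes and index maps) without
    # recomputing it.
    @staticmethod
    def _update_padded_sketch(clouds):
        deformed = clouds.points_padded() + 0.1
        return clouds.update_padded(deformed, None, None)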

    @staticmethod
    def compute_packed_with_init(
        num_clouds: int = 10, max_p: int = 100, features: int = 300
    ):
        clouds = TestPointclouds.init_cloud(num_clouds, max_p, features)
        torch.cuda.synchronize()

        def compute_packed():
            clouds._compute_packed(refresh=True)
            torch.cuda.synchronize()

        return compute_packed

    @staticmethod
    def compute_padded_with_init(
        num_clouds: int = 10, max_p: int = 100, features: int = 300
    ):
        clouds = TestPointclouds.init_cloud(num_clouds, max_p, features)
        torch.cuda.synchronize()

        def compute_padded():
            clouds._compute_padded(refresh=True)
            torch.cuda.synchronize()

        return compute_padded
525
tests/test_rasterize_points.py
Normal file
@@ -0,0 +1,525 @@

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import unittest

import numpy as np
import torch

from pytorch3d import _C
from pytorch3d.renderer.points.rasterize_points import (
    rasterize_points,
    rasterize_points_python,
)
from pytorch3d.structures.pointclouds import Pointclouds

from common_testing import TestCaseMixin


class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
    def test_python_simple_cpu(self):
        self._simple_test_case(
            rasterize_points_python, torch.device("cpu"), bin_size=-1
        )

    def test_naive_simple_cpu(self):
        device = torch.device("cpu")
        self._simple_test_case(rasterize_points, device)

    def test_naive_simple_cuda(self):
        device = torch.device("cuda")
        self._simple_test_case(rasterize_points, device, bin_size=0)

    def test_python_behind_camera(self):
        self._test_behind_camera(
            rasterize_points_python, torch.device("cpu"), bin_size=-1
        )

    def test_cpu_behind_camera(self):
        self._test_behind_camera(rasterize_points, torch.device("cpu"))

    def test_cuda_behind_camera(self):
        self._test_behind_camera(
            rasterize_points, torch.device("cuda"), bin_size=0
        )

    def test_cpp_vs_naive_vs_binned(self):
        # Make sure that the backward pass runs for all pathways.
        N = 2
        P = 1000
        image_size = 32
        radius = 0.1
        points_per_pixel = 3
        points1 = torch.randn(P, 3, requires_grad=True)
        points2 = torch.randn(P // 2, 3, requires_grad=True)
        pointclouds = Pointclouds(points=[points1, points2])
        grad_zbuf = torch.randn(N, image_size, image_size, points_per_pixel)
        grad_dists = torch.randn(N, image_size, image_size, points_per_pixel)

        # Option I: CPU, naive
        idx1, zbuf1, dists1 = rasterize_points(
            pointclouds, image_size, radius, points_per_pixel, bin_size=0
        )
        loss = (zbuf1 * grad_zbuf).sum() + (dists1 * grad_dists).sum()
        loss.backward()
        grad1 = points1.grad.data.clone()

        # Option II: CUDA, naive
        points1_cuda = points1.cuda().detach().clone().requires_grad_(True)
        points2_cuda = points2.cuda().detach().clone().requires_grad_(True)
        pointclouds = Pointclouds(points=[points1_cuda, points2_cuda])
        grad_zbuf = grad_zbuf.cuda()
        grad_dists = grad_dists.cuda()
        idx2, zbuf2, dists2 = rasterize_points(
            pointclouds, image_size, radius, points_per_pixel, bin_size=0
        )
        loss = (zbuf2 * grad_zbuf).sum() + (dists2 * grad_dists).sum()
        loss.backward()
        idx2 = idx2.data.cpu().clone()
        zbuf2 = zbuf2.data.cpu().clone()
        dists2 = dists2.data.cpu().clone()
        grad2 = points1_cuda.grad.data.cpu().clone()

        # Option III: CUDA, binned
        points1_cuda = points1.cuda().detach().clone().requires_grad_(True)
        points2_cuda = points2.cuda().detach().clone().requires_grad_(True)
        pointclouds = Pointclouds(points=[points1_cuda, points2_cuda])
        idx3, zbuf3, dists3 = rasterize_points(
            pointclouds, image_size, radius, points_per_pixel, bin_size=32
        )
        loss = (zbuf3 * grad_zbuf).sum() + (dists3 * grad_dists).sum()
        points1.grad.data.zero_()
        loss.backward()
        idx3 = idx3.data.cpu().clone()
        zbuf3 = zbuf3.data.cpu().clone()
        dists3 = dists3.data.cpu().clone()
        grad3 = points1_cuda.grad.data.cpu().clone()

        # Make sure everything was the same.
        idx12_same = (idx1 == idx2).all().item()
        idx13_same = (idx1 == idx3).all().item()
        zbuf12_same = (zbuf1 == zbuf2).all().item()
        zbuf13_same = (zbuf1 == zbuf3).all().item()
        dists12_diff = (dists1 - dists2).abs().max().item()
        dists13_diff = (dists1 - dists3).abs().max().item()
        self.assertTrue(idx12_same)
        self.assertTrue(idx13_same)
        self.assertTrue(zbuf12_same)
        self.assertTrue(zbuf13_same)
        self.assertTrue(dists12_diff < 1e-6)
        self.assertTrue(dists13_diff < 1e-6)

        diff12 = (grad1 - grad2).abs().max().item()
        diff13 = (grad1 - grad3).abs().max().item()
        diff23 = (grad2 - grad3).abs().max().item()
        self.assertTrue(diff12 < 5e-6)
        self.assertTrue(diff13 < 5e-6)
        self.assertTrue(diff23 < 5e-6)
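
    # The scalar losses above use a standard trick: for a fixed tensor g,
    # backpropagating loss = (out * g).sum() is equivalent to calling
    # out.backward(g), so two implementations can be compared through a single
    # scalar. A tiny self-contained illustration:
    @staticmethod
    def _scalar_loss_trick_sketch():
        x = torch.randn(4, requires_grad=True)
        g = torch.randn(4)
        (2 * x * g).sum().backward()
        # d/dx of sum(2 * x * g) is 2 * g, independent of x.
        assert torch.allclose(x.grad, 2 * g)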

    def test_python_vs_cpu_naive(self):
        torch.manual_seed(231)
        image_size = 32
        radius = 0.1
        points_per_pixel = 3

        # Test a batch of homogeneous point clouds.
        N = 2
        P = 17
        points = torch.randn(N, P, 3, requires_grad=True)
        pointclouds = Pointclouds(points=points)
        args = (pointclouds, image_size, radius, points_per_pixel)
        self._compare_impls(
            rasterize_points_python,
            rasterize_points,
            args,
            args,
            points,
            points,
            compare_grads=True,
        )

        # Test a batch of heterogeneous point clouds.
        P2 = 10
        points1 = torch.randn(P, 3, requires_grad=True)
        points2 = torch.randn(P2, 3)
        pointclouds = Pointclouds(points=[points1, points2])
        args = (pointclouds, image_size, radius, points_per_pixel)
        self._compare_impls(
            rasterize_points_python,
            rasterize_points,
            args,
            args,
            points1,  # check gradients for first element in batch
            points1,
            compare_grads=True,
        )

    def test_cpu_vs_cuda_naive(self):
        torch.manual_seed(231)
        image_size = 64
        radius = 0.1
        points_per_pixel = 5

        # Test a homogeneous point cloud batch.
        N = 2
        P = 1000
        bin_size = 0
        points_cpu = torch.rand(N, P, 3, requires_grad=True)
        points_cuda = points_cpu.cuda().detach().requires_grad_(True)
        pointclouds_cpu = Pointclouds(points=points_cpu)
        pointclouds_cuda = Pointclouds(points=points_cuda)
        args_cpu = (
            pointclouds_cpu,
            image_size,
            radius,
            points_per_pixel,
            bin_size,
        )
        args_cuda = (
            pointclouds_cuda,
            image_size,
            radius,
            points_per_pixel,
            bin_size,
        )
        self._compare_impls(
            rasterize_points,
            rasterize_points,
            args_cpu,
            args_cuda,
            points_cpu,
            points_cuda,
            compare_grads=True,
        )
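
    # For reference, a hedged usage sketch of the function under test: the
    # three outputs all have shape (N, image_size, image_size,
    # points_per_pixel), with -1 padding wherever fewer than points_per_pixel
    # points cover a pixel (see the expected tensors in _simple_test_case).
    @staticmethod
    def _rasterize_points_usage_sketch():
        clouds = Pointclouds(points=[torch.rand(10, 3)])
        idx, zbuf, dists = rasterize_points(clouds, 8, 0.1, 2)
        assert idx.shape == (1, 8, 8, 2)
        assert zbuf.shape == dists.shape == idx.shape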

    def _compare_impls(
        self,
        fn1,
        fn2,
        args1,
        args2,
        grad_var1=None,
        grad_var2=None,
        compare_grads=False,
    ):
        idx1, zbuf1, dist1 = fn1(*args1)
        torch.manual_seed(231)
        grad_zbuf = torch.randn_like(zbuf1)
        grad_dist = torch.randn_like(dist1)
        loss = (zbuf1 * grad_zbuf).sum() + (dist1 * grad_dist).sum()
        if compare_grads:
            loss.backward()
            grad_points1 = grad_var1.grad.data.clone().cpu()

        idx2, zbuf2, dist2 = fn2(*args2)
        grad_zbuf = grad_zbuf.to(zbuf2)
        grad_dist = grad_dist.to(dist2)
        loss = (zbuf2 * grad_zbuf).sum() + (dist2 * grad_dist).sum()
        if compare_grads:
            # Clear grad_var1.grad in case args1 and args2 reused the same tensor.
            grad_var1.grad.data.zero_()
            loss.backward()
            grad_points2 = grad_var2.grad.data.clone().cpu()

        self.assertEqual((idx1.cpu() == idx2.cpu()).all().item(), 1)
        self.assertEqual((zbuf1.cpu() == zbuf2.cpu()).all().item(), 1)
        self.assertClose(dist1.cpu(), dist2.cpu())
        if compare_grads:
            self.assertTrue(torch.allclose(grad_points1, grad_points2, atol=2e-6))

    def _test_behind_camera(self, rasterize_points_fn, device, bin_size=None):
        # Test case where all points are behind the camera -- nothing should
        # get rasterized.
        N = 2
        P = 32
        xy = torch.randn(N, P, 2)
        z = torch.randn(N, P, 1).abs().mul(-1)  # Make them all negative.
        points = torch.cat([xy, z], dim=2).to(device)
        image_size = 16
        points_per_pixel = 3
        radius = 0.2
        idx_expected = torch.full(
            (N, 16, 16, 3), fill_value=-1, dtype=torch.int32, device=device
        )
        zbuf_expected = torch.full(
            (N, 16, 16, 3), fill_value=-1, dtype=torch.float32, device=device
        )
        dists_expected = zbuf_expected.clone()
        pointclouds = Pointclouds(points=points)
        if bin_size == -1:
            # simple python case with no binning
            idx, zbuf, dists = rasterize_points_fn(
                pointclouds, image_size, radius, points_per_pixel
            )
        else:
            idx, zbuf, dists = rasterize_points_fn(
                pointclouds, image_size, radius, points_per_pixel, bin_size
            )
        idx_same = (idx == idx_expected).all().item() == 1
        zbuf_same = (zbuf == zbuf_expected).all().item() == 1

        self.assertTrue(idx_same)
        self.assertTrue(zbuf_same)
        self.assertTrue(torch.allclose(dists, dists_expected))

    def _simple_test_case(self, rasterize_points_fn, device, bin_size=0):
        # Create two point clouds with different numbers of points.
        # fmt: off
        points1 = torch.tensor(
            [
                [0.0, 0.0,  0.0],  # noqa: E241
                [0.4, 0.0,  0.1],  # noqa: E241
                [0.0, 0.4,  0.2],  # noqa: E241
                [0.0, 0.0, -0.1],  # noqa: E241  Points with negative z should be skipped.
            ],
            device=device,
        )
        points2 = torch.tensor(
            [
                [0.0, 0.0,  0.0],  # noqa: E241
                [0.4, 0.0,  0.1],  # noqa: E241
                [0.0, 0.4,  0.2],  # noqa: E241
                [0.0, 0.0, -0.1],  # noqa: E241  Points with negative z should be skipped.
                [0.0, 0.0, -0.7],  # noqa: E241  Points with negative z should be skipped.
            ],
            device=device,
        )
        # fmt: on
        pointclouds = Pointclouds(points=[points1, points2])

        image_size = 5
        points_per_pixel = 2
        radius = 0.5

        # The expected output values. Note that in the outputs, the world space
        # +Y is up, and the world space +X is left.
        idx1_expected = torch.full(
            (1, 5, 5, 2), fill_value=-1, dtype=torch.int32, device=device
        )
        # fmt: off
        idx1_expected[0, :, :, 0] = torch.tensor([
            [-1, -1,  2, -1, -1],  # noqa: E241
            [-1,  1,  0,  2, -1],  # noqa: E241
            [ 1,  0,  0,  0, -1],  # noqa: E241 E201
            [-1,  1,  0, -1, -1],  # noqa: E241
            [-1, -1, -1, -1, -1],  # noqa: E241
        ], device=device)
        idx1_expected[0, :, :, 1] = torch.tensor([
            [-1, -1, -1, -1, -1],  # noqa: E241
            [-1,  2,  2, -1, -1],  # noqa: E241
            [-1,  1,  1, -1, -1],  # noqa: E241
            [-1, -1, -1, -1, -1],  # noqa: E241
            [-1, -1, -1, -1, -1],  # noqa: E241
        ], device=device)
        # fmt: on

        zbuf1_expected = torch.full(
            (1, 5, 5, 2), fill_value=100, dtype=torch.float32, device=device
        )
        # fmt: off
        zbuf1_expected[0, :, :, 0] = torch.tensor([
            [-1.0, -1.0,  0.2, -1.0, -1.0],  # noqa: E241
            [-1.0,  0.1,  0.0,  0.2, -1.0],  # noqa: E241
            [ 0.1,  0.0,  0.0,  0.0, -1.0],  # noqa: E241 E201
            [-1.0,  0.1,  0.0, -1.0, -1.0],  # noqa: E241
            [-1.0, -1.0, -1.0, -1.0, -1.0],  # noqa: E241
        ], device=device)
        zbuf1_expected[0, :, :, 1] = torch.tensor([
            [-1.0, -1.0, -1.0, -1.0, -1.0],  # noqa: E241
            [-1.0,  0.2,  0.2, -1.0, -1.0],  # noqa: E241
            [-1.0,  0.1,  0.1, -1.0, -1.0],  # noqa: E241
            [-1.0, -1.0, -1.0, -1.0, -1.0],  # noqa: E241
            [-1.0, -1.0, -1.0, -1.0, -1.0],  # noqa: E241
        ], device=device)
        # fmt: on

        dists1_expected = torch.full(
            (1, 5, 5, 2), fill_value=0.0, dtype=torch.float32, device=device
        )
        # fmt: off
        dists1_expected[0, :, :, 0] = torch.tensor([
            [-1.00, -1.00,  0.16, -1.00, -1.00],  # noqa: E241
            [-1.00,  0.16,  0.16,  0.16, -1.00],  # noqa: E241
            [ 0.16,  0.16,  0.00,  0.16, -1.00],  # noqa: E241 E201
            [-1.00,  0.16,  0.16, -1.00, -1.00],  # noqa: E241
            [-1.00, -1.00, -1.00, -1.00, -1.00],  # noqa: E241
        ], device=device)
        dists1_expected[0, :, :, 1] = torch.tensor([
            [-1.00, -1.00, -1.00, -1.00, -1.00],  # noqa: E241
            [-1.00,  0.16,  0.00, -1.00, -1.00],  # noqa: E241
            [-1.00,  0.00,  0.16, -1.00, -1.00],  # noqa: E241
            [-1.00, -1.00, -1.00, -1.00, -1.00],  # noqa: E241
            [-1.00, -1.00, -1.00, -1.00, -1.00],  # noqa: E241
        ], device=device)
        # fmt: on

        if bin_size == -1:
            # simple python case with no binning
            idx, zbuf, dists = rasterize_points_fn(
                pointclouds, image_size, radius, points_per_pixel
            )
        else:
            idx, zbuf, dists = rasterize_points_fn(
                pointclouds, image_size, radius, points_per_pixel, bin_size
            )

        # Check the first point cloud.
        idx_same = (idx[0, ...] == idx1_expected).all().item() == 1
        if not idx_same:
            print(idx[0, :, :, 0])
            print(idx[0, :, :, 1])
        zbuf_same = (zbuf[0, ...] == zbuf1_expected).all().item() == 1
        dist_same = torch.allclose(dists[0, ...], dists1_expected)
        self.assertTrue(idx_same)
        self.assertTrue(zbuf_same)
        self.assertTrue(dist_same)

        # Check the second point cloud - the indices in idx refer to points in
        # the pointclouds.points_packed() tensor. In the second point cloud,
        # two points are behind the screen - the expected indices are the same
        # as for the first point cloud, but offset by the number of points in
        # the first point cloud.
        num_points_per_cloud = pointclouds.num_points_per_cloud()
        idx1_expected[idx1_expected >= 0] += num_points_per_cloud[0]

        idx_same = (idx[1, ...] == idx1_expected).all().item() == 1
        zbuf_same = (zbuf[1, ...] == zbuf1_expected).all().item() == 1
        self.assertTrue(idx_same)
        self.assertTrue(zbuf_same)
        self.assertTrue(torch.allclose(dists[1, ...], dists1_expected))
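
    # A minimal sketch (using only behavior demonstrated above): recovering
    # per-cloud point indices from the packed indices that rasterization
    # returns, by subtracting each cloud's first packed index and leaving the
    # -1 "no point" padding untouched.
    @staticmethod
    def _unpack_idx_sketch(idx, pointclouds):
        start = pointclouds.cloud_to_packed_first_idx()
        idx_per_cloud = idx.clone()
        for n in range(len(pointclouds)):
            mask = idx_per_cloud[n] >= 0
            idx_per_cloud[n][mask] -= int(start[n])
        return idx_per_cloud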

    def test_coarse_cpu(self):
        return self._test_coarse_rasterize(torch.device("cpu"))

    def test_coarse_cuda(self):
        return self._test_coarse_rasterize(torch.device("cuda"))

    def test_compare_coarse_cpu_vs_cuda(self):
        torch.manual_seed(231)
        N = 3
        max_P = 1000
        image_size = 64
        radius = 0.1
        bin_size = 16
        max_points_per_bin = 500

        # Create heterogeneous point clouds.
        points = []
        for _ in range(N):
            p = np.random.choice(max_P)
            points.append(torch.randn(p, 3))

        pointclouds = Pointclouds(points=points)
        points_packed = pointclouds.points_packed()
        cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
        num_points_per_cloud = pointclouds.num_points_per_cloud()
        args = (
            points_packed,
            cloud_to_packed_first_idx,
            num_points_per_cloud,
            image_size,
            radius,
            bin_size,
            max_points_per_bin,
        )
        bp_cpu = _C._rasterize_points_coarse(*args)

        pointclouds_cuda = pointclouds.to("cuda:0")
        points_packed = pointclouds_cuda.points_packed()
        cloud_to_packed_first_idx = pointclouds_cuda.cloud_to_packed_first_idx()
        num_points_per_cloud = pointclouds_cuda.num_points_per_cloud()
        args = (
            points_packed,
            cloud_to_packed_first_idx,
            num_points_per_cloud,
            image_size,
            radius,
            bin_size,
            max_points_per_bin,
        )
        bp_cuda = _C._rasterize_points_coarse(*args)

        # Bin points might not be the same: the CUDA version might write them
        # in any order. But if we sort the non-(-1) elements of the CUDA
        # output then they should be the same.
        for n in range(N):
            for by in range(bp_cpu.shape[1]):
                for bx in range(bp_cpu.shape[2]):
                    K = (bp_cpu[n, by, bx] != -1).sum().item()
                    idxs_cpu = bp_cpu[n, by, bx].tolist()
                    idxs_cuda = bp_cuda[n, by, bx].tolist()
                    idxs_cuda[:K] = sorted(idxs_cuda[:K])
                    self.assertEqual(idxs_cpu, idxs_cuda)

    def _test_coarse_rasterize(self, device):
        #
        # Note that +Y is up and +X is left in the diagram below.
        #
        #  (4)        |2
        #             |
        #             |
        #             |
        #             |1
        #             |
        #      (1)    |
        #             |    (2)
        # ____________(0)__(5)___________________
        #   2     1   |        -1           -2
        #             |
        #      (3)    |
        #             |
        #             |-1
        #             |
        #
        # Locations of the points are shown by their indices in parentheses.
        # The screen bounding box is between [-1, 1] in both the x and y
        # directions.
        #
        # These points are interesting because:
        # (0) falls into two bins;
        # (1) and (2) fall into one bin;
        # (3) is out-of-bounds, but its disk is in-bounds;
        # (4) is out-of-bounds, and its entire disk is also out-of-bounds;
        # (5) has a negative z-value, so it should be skipped.
        # fmt: off
        points = torch.tensor(
            [
                [ 0.5,  0.0,  0.0],  # noqa: E241, E201
                [ 0.5,  0.5,  0.1],  # noqa: E241, E201
                [-0.3,  0.4,  0.0],  # noqa: E241
                [ 1.1, -0.5,  0.2],  # noqa: E241, E201
                [ 2.0,  2.0,  0.3],  # noqa: E241, E201
                [ 0.0,  0.0, -0.1],  # noqa: E241, E201
            ],
            device=device,
        )
        # fmt: on
        image_size = 16
        radius = 0.2
        bin_size = 8
        max_points_per_bin = 5

        bin_points_expected = -1 * torch.ones(
            1, 2, 2, 5, dtype=torch.int32, device=device
        )
        # Note that the order is only deterministic here for CUDA if all points
        # fit in one chunk. This will be the case for this small example, but
        # to properly exercise coordinated writes among multiple chunks we need
        # to use a bigger test case.
        bin_points_expected[0, 1, 0, :2] = torch.tensor([0, 3])
        bin_points_expected[0, 0, 1, 0] = torch.tensor([2])
        bin_points_expected[0, 0, 0, :2] = torch.tensor([0, 1])

        pointclouds = Pointclouds(points=[points])
        args = (
            pointclouds.points_packed(),
            pointclouds.cloud_to_packed_first_idx(),
            pointclouds.num_points_per_cloud(),
            image_size,
            radius,
            bin_size,
            max_points_per_bin,
        )
        bin_points = _C._rasterize_points_coarse(*args)
        bin_points_same = (bin_points == bin_points_expected).all()
        self.assertTrue(bin_points_same.item() == 1)
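

# A rough sketch (not the C++ implementation) of the coarse binning rule that
# the expectations above encode: a point's disk is assigned to every bin whose
# NDC extent it overlaps. Assumes the z > 0 cull has already happened, +Y up
# and +X left as in the diagram, and bin (0, 0) at the top-left of the image.
def candidate_bins(x, y, radius, image_size, bin_size):
    num_bins = 1 + (image_size - 1) // bin_size  # ceil division
    ndc_per_bin = 2.0 * bin_size / image_size  # NDC extent of one bin
    bins = []
    for by in range(num_bins):
        y_hi = 1.0 - by * ndc_per_bin  # +Y is up: row 0 is the top of NDC
        y_lo = y_hi - ndc_per_bin
        if y + radius < y_lo or y - radius > y_hi:
            continue
        for bx in range(num_bins):
            x_hi = 1.0 - bx * ndc_per_bin  # +X is left: column 0 is near +1
            x_lo = x_hi - ndc_per_bin
            if x + radius < x_lo or x - radius > x_hi:
                continue
            bins.append((by, bx))
    return bins


# For the six points above this reproduces the expected bins, e.g. point (0)
# at (0.5, 0.0) with radius 0.2 straddles y = 0 and lands in bins (0, 0) and
# (1, 0), while point (4) at (2.0, 2.0) overlaps nothing.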