# TODO(jcjohns): Support non-square images
def rasterize_points(
    pointclouds,
    image_size: int = 256,
    radius: float = 0.01,
    points_per_pixel: int = 8,
    bin_size: Optional[int] = None,
    max_points_per_bin: Optional[int] = None,
):
    """
    Pointcloud rasterization

    Args:
        pointclouds: A Pointclouds object representing a batch of point clouds to be
            rasterized. This is a batch of N pointclouds, where each point cloud
            can have a different number of points; the coordinates of each point
            are (x, y, z). The coordinates are expected to
            be in normalized device coordinates (NDC): [-1, 1]^3 with the camera at
            (0, 0, 0); the x-axis goes from left-to-right, the y-axis goes from
            top-to-bottom, and the z-axis goes from back-to-front.
        image_size: Integer giving the resolution of the rasterized image
        radius (Optional): Float giving the radius (in NDC units) of the disk to
            be rasterized for each point.
        points_per_pixel (Optional): We will keep track of this many points per
            pixel, returning the nearest points_per_pixel points along the z-axis
        bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
            bin_size=0 uses naive rasterization; setting bin_size=None attempts to
            set it heuristically based on the shape of the input. This should not
            affect the output, but can affect the speed of the forward pass.
        max_points_per_bin: Only applicable when using coarse-to-fine rasterization
            (bin_size > 0); this is the maximum number of points allowed within each
            bin. If more than this many points actually fall into a bin, an error
            will be raised. This should not affect the output values, but can affect
            the memory usage in the forward pass. If None, it defaults to
            max(10000, P / 5) where P is the total number of packed points.

    Returns:
        3-element tuple containing

        - **idx**: int32 Tensor of shape (N, image_size, image_size, points_per_pixel)
          giving the indices of the nearest points at each pixel, in ascending
          z-order. Concretely `idx[n, y, x, k] = p` means that `points[p]` is the kth
          closest point (along the z-direction) to pixel (y, x) - note that points
          represents the packed points of shape (P, 3).
          Pixels that are hit by fewer than points_per_pixel are padded with -1.
        - **zbuf**: Tensor of shape (N, image_size, image_size, points_per_pixel)
          giving the z-coordinates of the nearest points at each pixel, sorted in
          z-order. Concretely, if `idx[n, y, x, k] = p` then
          `zbuf[n, y, x, k] = points[p, 2]`. Pixels hit by fewer than
          points_per_pixel are padded with -1
        - **dists2**: Tensor of shape (N, image_size, image_size, points_per_pixel)
          giving the squared Euclidean distance (in NDC units) in the x/y plane
          for each point closest to the pixel. Concretely if `idx[n, y, x, k] = p`
          then `dists2[n, y, x, k]` is the squared distance between the pixel (y, x)
          and the point `(points[p, 0], points[p, 1])`. Pixels hit with fewer
          than points_per_pixel are padded with -1.
    """
    points_packed = pointclouds.points_packed()
    cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
    num_points_per_cloud = pointclouds.num_points_per_cloud()

    if bin_size is None:
        if not points_packed.is_cuda:
            # Binned CPU rasterization not fully implemented
            bin_size = 0
        elif image_size <= 64:
            bin_size = 8
        else:
            # TODO: These heuristics are not well-thought out!
            # Yields 16 for sizes <= 256, 32 for <= 512 and 64 for <= 1024.
            # Larger images keep doubling the bin size so that bin_size is
            # always an int; previously it stayed None above 1024.
            bin_size, size_limit = 16, 256
            while image_size > size_limit:
                bin_size *= 2
                size_limit *= 2

    if max_points_per_bin is None:
        max_points_per_bin = int(max(10000, points_packed.shape[0] / 5))

    # Function.apply cannot take keyword args, so we handle defaults in this
    # wrapper and call apply with positional args only
    return _RasterizePoints.apply(
        points_packed,
        cloud_to_packed_first_idx,
        num_points_per_cloud,
        image_size,
        radius,
        points_per_pixel,
        bin_size,
        max_points_per_bin,
    )
class _RasterizePoints(torch.autograd.Function):
    """
    Autograd wrapper around the compiled point rasterizer exposed through `_C`.

    The forward pass delegates to `_C.rasterize_points` and the backward pass
    to `_C.rasterize_points_backward`. Only the packed points receive a
    gradient; every other forward input gets None.
    """

    @staticmethod
    def forward(
        ctx,
        points,  # (P, 3)
        cloud_to_packed_first_idx,
        num_points_per_cloud,
        image_size: int = 256,
        radius: float = 0.01,
        points_per_pixel: int = 8,
        bin_size: int = 0,
        max_points_per_bin: int = 0,
    ):
        # TODO: Add better error handling for when there are more than
        # max_points_per_bin in any bin.
        idx, zbuf, dists = _C.rasterize_points(
            points,
            cloud_to_packed_first_idx,
            num_points_per_cloud,
            image_size,
            radius,
            points_per_pixel,
            bin_size,
            max_points_per_bin,
        )
        # The backward op needs the input points and the per-pixel indices.
        ctx.save_for_backward(points, idx)
        return idx, zbuf, dists

    @staticmethod
    def backward(ctx, grad_idx, grad_zbuf, grad_dists):
        points, idx = ctx.saved_tensors
        grad_points = _C.rasterize_points_backward(
            points, idx, grad_zbuf, grad_dists
        )
        # One slot per forward input (excluding ctx): the gradient for points,
        # followed by None for cloud_to_packed_first_idx, num_points_per_cloud,
        # image_size, radius, points_per_pixel, bin_size and max_points_per_bin.
        return (grad_points, None, None, None, None, None, None, None)
def rasterize_points_python(
    pointclouds,
    image_size: int = 256,
    radius: float = 0.01,
    points_per_pixel: int = 8,
):
    """
    Naive pure PyTorch implementation of pointcloud rasterization.

    Every pixel of every image is tested against every point of the
    corresponding cloud, so this is only suitable as a slow reference.

    Inputs / Outputs: Same as rasterize_points (without the binning arguments).
    """
    num_clouds = len(pointclouds)
    size, keep = image_size, points_per_pixel
    device = pointclouds.device

    packed = pointclouds.points_packed()
    first_idx = pointclouds.cloud_to_packed_first_idx()
    counts = pointclouds.num_points_per_cloud()

    # Initialize output tensors; -1 marks slots that no point fills.
    out_shape = (num_clouds, size, size, keep)
    point_idxs = torch.full(
        out_shape, fill_value=-1, dtype=torch.int32, device=device
    )
    zbuf = torch.full(
        out_shape, fill_value=-1, dtype=torch.float32, device=device
    )
    pix_dists = torch.full(
        out_shape, fill_value=-1, dtype=torch.float32, device=device
    )

    # A point covers a pixel when its squared x/y distance is below this.
    radius_sq = radius * radius

    # Iterate through the batch of point clouds.
    for n in range(num_clouds):
        start = first_idx[n]
        stop = start + counts[n]

        # Walk the image rows; the row index is flipped before the
        # conversion to NDC.
        for yi in range(size):
            yf = pix_to_ndc(size - 1 - yi, size)

            # Walk the pixels of this row; the column index is flipped too.
            for xi in range(size):
                xf = pix_to_ndc(size - 1 - xi, size)

                # Collect (z, packed index, squared distance) for each point
                # of this cloud that covers the pixel.
                hits = []
                for p in range(start, stop):
                    px, py, pz = packed[p, :]
                    if pz < 0:
                        continue
                    dx = px - xf
                    dy = py - yf
                    dist2 = dx * dx + dy * dy
                    if dist2 < radius_sq:
                        hits.append((pz, p, dist2))

                # Keep the `keep` entries with the smallest z, nearest first.
                for k, (pz, p, dist2) in enumerate(sorted(hits)[:keep]):
                    zbuf[n, yi, xi, k] = pz
                    point_idxs[n, yi, xi, k] = p
                    pix_dists[n, yi, xi, k] = dist2
    return point_idxs, zbuf, pix_dists
+ - has auxillary variables used to index into the padded representation. + + Example + + Input list of points = [[P_1], [P_2], ... , [P_N]] + where P_1, ... , P_N are the number of points in each cloud and N is the + number of clouds. + + # SPHINX IGNORE + List | Padded | Packed + ---------------------------|-------------------------|------------------------ + [[P_1], ... , [P_N]] | size = (N, max(P_n), 3) | size = (sum(P_n), 3) + | | + Example for locations | | + or colors: | | + | | + P_1 = 3, P_2 = 4, P_3 = 5 | size = (3, 5, 3) | size = (12, 3) + | | + List([ | tensor([ | tensor([ + [ | [ | [0.1, 0.3, 0.5], + [0.1, 0.3, 0.5], | [0.1, 0.3, 0.5], | [0.5, 0.2, 0.1], + [0.5, 0.2, 0.1], | [0.5, 0.2, 0.1], | [0.6, 0.8, 0.7], + [0.6, 0.8, 0.7] | [0.6, 0.8, 0.7], | [0.1, 0.3, 0.3], + ], | [0, 0, 0], | [0.6, 0.7, 0.8], + [ | [0, 0, 0] | [0.2, 0.3, 0.4], + [0.1, 0.3, 0.3], | ], | [0.1, 0.5, 0.3], + [0.6, 0.7, 0.8], | [ | [0.7, 0.3, 0.6], + [0.2, 0.3, 0.4], | [0.1, 0.3, 0.3], | [0.2, 0.4, 0.8], + [0.1, 0.5, 0.3] | [0.6, 0.7, 0.8], | [0.9, 0.5, 0.2], + ], | [0.2, 0.3, 0.4], | [0.2, 0.3, 0.4], + [ | [0.1, 0.5, 0.3], | [0.9, 0.3, 0.8], + [0.7, 0.3, 0.6], | [0, 0, 0] | ]) + [0.2, 0.4, 0.8], | ], | + [0.9, 0.5, 0.2], | [ | + [0.2, 0.3, 0.4], | [0.7, 0.3, 0.6], | + [0.9, 0.3, 0.8], | [0.2, 0.4, 0.8], | + ] | [0.9, 0.5, 0.2], | + ]) | [0.2, 0.3, 0.4], | + | [0.9, 0.3, 0.8] | + | ] | + | ]) | + ----------------------------------------------------------------------------- + + Auxillary variables for packed representation + + Name | Size | Example from above + -------------------------------|---------------------|----------------------- + | | + packed_to_cloud_idx | size = (sum(P_n)) | tensor([ + | | 0, 0, 0, 1, 1, 1, + | | 1, 2, 2, 2, 2, 2 + | | )] + | | size = (12) + | | + cloud_to_packed_first_idx | size = (N) | tensor([0, 3, 7]) + | | size = (3) + | | + num_points_per_cloud | size = (N) | tensor([3, 4, 5]) + | | size = (3) + | | + padded_to_packed_idx | size = (sum(P_n)) | 
tensor([ + | | 0, 1, 2, 5, 6, 7, + | | 8, 10, 11, 12, 13, + | | 14 + | | )] + | | size = (12) + ----------------------------------------------------------------------------- + # SPHINX IGNORE + """ + + _INTERNAL_TENSORS = [ + "_points_packed", + "_points_padded", + "_normals_packed", + "_normals_padded", + "_features_packed", + "_features_padded", + "_packed_to_cloud_idx", + "_cloud_to_packed_first_idx", + "_num_points_per_cloud", + "_padded_to_packed_idx", + "valid", + "equisized", + ] + + def __init__(self, points, normals=None, features=None): + """ + Args: + points: + Can be either + + - List where each element is a tensor of shape (num_points, 3) + containing the (x, y, z) coordinates of each point. + - Padded float tensor with shape (num_clouds, num_points, 3). + normals: + Can be either + + - List where each element is a tensor of shape (num_points, 3) + containing the normal vector for each point. + - Padded float tensor of shape (num_clouds, num_points, 3). + features: + Can be either + + - List where each element is a tensor of shape (num_points, C) + containing the features for the points in the cloud. + - Padded float tensor of shape (num_clouds, num_points, C). + where C is the number of channels in the features. + For example 3 for RGB color. + + Refer to comments above for descriptions of List and Padded + representations. + """ + self.device = None + + # Indicates whether the clouds in the list/batch have the same number + # of points. + self.equisized = False + + # Boolean indicator for each cloud in the batch. + # True if cloud has non zero number of points, False otherwise. + self.valid = None + + self._N = 0 # batch size (number of clouds) + self._P = 0 # (max) number of points per cloud + self._C = None # number of channels in the features + + # List of Tensors of points and features. + self._points_list = None + self._normals_list = None + self._features_list = None + + # Number of points per cloud. 
+ self._num_points_per_cloud = None # N + + # Packed representation. + self._points_packed = None # (sum(P_n), 3) + self._normals_packed = None # (sum(P_n), 3) + self._features_packed = None # (sum(P_n), C) + + self._packed_to_cloud_idx = None # sum(P_n) + + # Index of each cloud's first point in the packed points. + # Assumes packing is sequential. + self._cloud_to_packed_first_idx = None # N + + # Padded representation. + self._points_padded = None # (N, max(P_n), 3) + self._normals_padded = None # (N, max(P_n), 3) + self._features_padded = None # (N, max(P_n), C) + + # Index to convert points from flattened padded to packed. + self._padded_to_packed_idx = None # N * max_P + + # Identify type of points. + if isinstance(points, list): + self._points_list = points + self._N = len(self._points_list) + self.device = torch.device("cpu") + self.valid = torch.zeros( + (self._N,), dtype=torch.bool, device=self.device + ) + self._num_points_per_cloud = [] + + if self._N > 0: + for p in self._points_list: + if len(p) > 0 and (p.dim() != 2 or p.shape[1] != 3): + raise ValueError( + "Clouds in list must be of shape Px3 or empty" + ) + + self.device = self._points_list[0].device + num_points_per_cloud = torch.tensor( + [len(p) for p in self._points_list], device=self.device + ) + self._P = num_points_per_cloud.max() + self.valid = torch.tensor( + [len(p) > 0 for p in self._points_list], + dtype=torch.bool, + device=self.device, + ) + + if len(num_points_per_cloud.unique()) == 1: + self.equisized = True + self._num_points_per_cloud = num_points_per_cloud + + elif torch.is_tensor(points): + if points.dim() != 3 or points.shape[2] != 3: + raise ValueError("Points tensor has incorrect dimensions.") + self._points_padded = points + self._N = self._points_padded.shape[0] + self._P = self._points_padded.shape[1] + self.device = self._points_padded.device + self.valid = torch.ones( + (self._N,), dtype=torch.bool, device=self.device + ) + self._num_points_per_cloud = torch.tensor( + 
[self._P] * self._N, device=self.device + ) + self.equisized = True + else: + raise ValueError( + "Points must be either a list or a tensor with \ + shape (batch_size, P, 3) where P is the maximum number of \ + points in a cloud." + ) + + # parse normals + normals_parsed = self._parse_auxiliary_input(normals) + self._normals_list, self._normals_padded, normals_C = normals_parsed + if normals_C is not None and normals_C != 3: + raise ValueError("Normals are expected to be 3-dimensional") + + # parse features + features_parsed = self._parse_auxiliary_input(features) + self._features_list, self._features_padded, features_C = features_parsed + if features_C is not None: + self._C = features_C + + def _parse_auxiliary_input(self, aux_input): + """ + Interpret the auxiliary inputs (normals, features) given to __init__. + + Args: + aux_input: + Can be either + + - List where each element is a tensor of shape (num_points, C) + containing the features for the points in the cloud. + - Padded float tensor of shape (num_clouds, num_points, C). + For normals, C = 3 + + Returns: + 3-element tuple of list, padded, num_channels. + If aux_input is list, then padded is None. If aux_input is a tensor, then list is None. + """ + if aux_input is None or self._N == 0: + return None, None, None + + aux_input_C = None + + if isinstance(aux_input, list): + if len(aux_input) != self._N: + raise ValueError( + "Points and auxiliary input must be the same length." 
+ ) + for p, d in zip(self._num_points_per_cloud, aux_input): + if p != d.shape[0]: + raise ValueError( + "A cloud has mismatched numbers of points and inputs" + ) + if p > 0: + if d.dim() != 2: + raise ValueError( + "A cloud auxiliary input must be of shape PxC or empty" + ) + if aux_input_C is None: + aux_input_C = d.shape[1] + if aux_input_C != d.shape[1]: + raise ValueError( + "The clouds must have the same number of channels" + ) + return aux_input, None, aux_input_C + elif torch.is_tensor(aux_input): + if aux_input.dim() != 3: + raise ValueError( + "Auxiliary input tensor has incorrect dimensions." + ) + if self._N != aux_input.shape[0]: + raise ValueError("Points and inputs must be the same length.") + if self._P != aux_input.shape[1]: + raise ValueError( + "Inputs tensor must have the right maximum \ + number of points in each cloud." + ) + aux_input_C = aux_input.shape[2] + return None, aux_input, aux_input_C + else: + raise ValueError( + "Auxiliary input must be either a list or a tensor with \ + shape (batch_size, P, C) where P is the maximum number of \ + points in a cloud." + ) + + def __len__(self): + return self._N + + def __getitem__(self, index): + """ + Args: + index: Specifying the index of the cloud to retrieve. + Can be an int, slice, list of ints or a boolean tensor. + + Returns: + Pointclouds object with selected clouds. The tensors are not cloned. 
+ """ + normals, features = None, None + if isinstance(index, int): + points = [self.points_list()[index]] + if self.normals_list() is not None: + normals = [self.normals_list()[index]] + if self.features_list() is not None: + features = [self.features_list()[index]] + elif isinstance(index, slice): + points = self.points_list()[index] + if self.normals_list() is not None: + normals = self.normals_list()[index] + if self.features_list() is not None: + features = self.features_list()[index] + elif isinstance(index, list): + points = [self.points_list()[i] for i in index] + if self.normals_list() is not None: + normals = [self.normals_list()[i] for i in index] + if self.features_list() is not None: + features = [self.features_list()[i] for i in index] + elif isinstance(index, torch.Tensor): + if index.dim() != 1 or index.dtype.is_floating_point: + raise IndexError(index) + # NOTE consider converting index to cpu for efficiency + if index.dtype == torch.bool: + # advanced indexing on a single dimension + index = index.nonzero() + index = index.squeeze(1) if index.numel() > 0 else index + index = index.tolist() + points = [self.points_list()[i] for i in index] + if self.normals_list() is not None: + normals = [self.normals_list()[i] for i in index] + if self.features_list() is not None: + features = [self.features_list()[i] for i in index] + else: + raise IndexError(index) + + return Pointclouds(points=points, normals=normals, features=features) + + def isempty(self) -> bool: + """ + Checks whether any cloud is valid. + + Returns: + bool indicating whether there is any data. + """ + return self._N == 0 or self.valid.eq(False).all() + + def points_list(self): + """ + Get the list representation of the points. + + Returns: + list of tensors of points of shape (P_n, 3). + """ + if self._points_list is None: + assert ( + self._points_padded is not None + ), "points_padded is required to compute points_list." 
+ points_list = [] + for i in range(self._N): + points_list.append( + self._points_padded[i, : self.num_points_per_cloud()[i]] + ) + self._points_list = points_list + return self._points_list + + def normals_list(self): + """ + Get the list representation of the normals. + + Returns: + list of tensors of normals of shape (P_n, 3). + """ + if self._normals_list is None: + if self._normals_padded is None: + # No normals provided so return None + return None + self._normals_list = [] + for i in range(self._N): + self._normals_list.append( + self._normals_padded[i, : self.num_points_per_cloud()[i]] + ) + return self._normals_list + + def features_list(self): + """ + Get the list representation of the features. + + Returns: + list of tensors of features of shape (P_n, C). + """ + if self._features_list is None: + if self._features_padded is None: + # No features provided so return None + return None + self._features_list = [] + for i in range(self._N): + self._features_list.append( + self._features_padded[i, : self.num_points_per_cloud()[i]] + ) + return self._features_list + + def points_packed(self): + """ + Get the packed representation of the points. + + Returns: + tensor of points of shape (sum(P_n), 3). + """ + self._compute_packed() + return self._points_packed + + def normals_packed(self): + """ + Get the packed representation of the normals. + + Returns: + tensor of normals of shape (sum(P_n), 3). + """ + self._compute_packed() + return self._normals_packed + + def features_packed(self): + """ + Get the packed representation of the features. + + Returns: + tensor of features of shape (sum(P_n), C). + """ + self._compute_packed() + return self._features_packed + + def packed_to_cloud_idx(self): + """ + Return a 1D tensor x with length equal to the total number of points. + packed_to_cloud_idx()[i] gives the index of the cloud which contains + points_packed()[i]. + + Returns: + 1D tensor of indices. 
+ """ + self._compute_packed() + return self._packed_to_cloud_idx + + def cloud_to_packed_first_idx(self): + """ + Return a 1D tensor x with length equal to the number of clouds such that + the first point of the ith cloud is points_packed[x[i]]. + + Returns: + 1D tensor of indices of first items. + """ + self._compute_packed() + return self._cloud_to_packed_first_idx + + def num_points_per_cloud(self): + """ + Return a 1D tensor x with length equal to the number of clouds giving + the number of points in each cloud. + + Returns: + 1D tensor of sizes. + """ + return self._num_points_per_cloud + + def points_padded(self): + """ + Get the padded representation of the points. + + Returns: + tensor of points of shape (N, max(P_n), 3). + """ + self._compute_padded() + return self._points_padded + + def normals_padded(self): + """ + Get the padded representation of the normals. + + Returns: + tensor of normals of shape (N, max(P_n), 3). + """ + self._compute_padded() + return self._normals_padded + + def features_padded(self): + """ + Get the padded representation of the features. + + Returns: + tensor of features of shape (N, max(P_n), 3). + """ + self._compute_padded() + return self._features_padded + + def padded_to_packed_idx(self): + """ + Return a 1D tensor x with length equal to the total number of points + such that points_packed()[i] is element x[i] of the flattened padded + representation. + The packed representation can be calculated as follows. + + .. code-block:: python + + p = points_padded().reshape(-1, 3) + points_packed = p[x] + + Returns: + 1D tensor of indices. 
+ """ + self._compute_packed() + if self._padded_to_packed_idx is not None: + return self._padded_to_packed_idx + if self._N == 0: + self._padded_to_packed_idx = [] + else: + self._padded_to_packed_idx = torch.cat( + [ + torch.arange(v, dtype=torch.int64, device=self.device) + + i * self._P + for (i, v) in enumerate(self._num_points_per_cloud) + ], + dim=0, + ) + return self._padded_to_packed_idx + + def _compute_padded(self, refresh: bool = False): + """ + Computes the padded version from points_list, normals_list and features_list. + + Args: + refresh: whether to force the recalculation. + """ + if not (refresh or self._points_padded is None): + return + + self._normals_padded, self._features_padded = None, None + if self.isempty(): + self._points_padded = torch.zeros( + (self._N, 0, 3), device=self.device + ) + else: + self._points_padded = struct_utils.list_to_padded( + self.points_list(), + (self._P, 3), + pad_value=0.0, + equisized=self.equisized, + ) + if self.normals_list() is not None: + self._normals_padded = struct_utils.list_to_padded( + self.normals_list(), + (self._P, 3), + pad_value=0.0, + equisized=self.equisized, + ) + if self.features_list() is not None: + self._features_padded = struct_utils.list_to_padded( + self.features_list(), + (self._P, self._C), + pad_value=0.0, + equisized=self.equisized, + ) + + # TODO(nikhilar) Improve performance of _compute_packed. + def _compute_packed(self, refresh: bool = False): + """ + Computes the packed version from points_list, normals_list and + features_list and sets the values of auxillary tensors. + + Args: + refresh: Set to True to force recomputation of packed + representations. Default: False. + """ + + if not ( + refresh + or any( + v is None + for v in [ + self._points_packed, + self._packed_to_cloud_idx, + self._cloud_to_packed_first_idx, + ] + ) + ): + return + + # Packed can be calculated from padded or list, so can call the + # accessor function for the lists. 
+ points_list = self.points_list() + normals_list = self.normals_list() + features_list = self.features_list() + if self.isempty(): + self._points_packed = torch.zeros( + (0, 3), dtype=torch.float32, device=self.device + ) + self._packed_to_cloud_idx = torch.zeros( + (0,), dtype=torch.int64, device=self.device + ) + self._cloud_to_packed_first_idx = torch.zeros( + (0,), dtype=torch.int64, device=self.device + ) + self._normals_packed = None + self._features_packed = None + return + + points_list_to_packed = struct_utils.list_to_packed(points_list) + self._points_packed = points_list_to_packed[0] + if not torch.allclose( + self._num_points_per_cloud, points_list_to_packed[1] + ): + raise ValueError("Inconsistent list to packed conversion") + self._cloud_to_packed_first_idx = points_list_to_packed[2] + self._packed_to_cloud_idx = points_list_to_packed[3] + + self._normals_packed, self._features_packed = None, None + if normals_list is not None: + normals_list_to_packed = struct_utils.list_to_packed(normals_list) + self._normals_packed = normals_list_to_packed[0] + + if features_list is not None: + features_list_to_packed = struct_utils.list_to_packed(features_list) + self._features_packed = features_list_to_packed[0] + + def clone(self): + """ + Deep copy of Pointclouds object. All internal tensors are cloned + individually. + + Returns: + new Pointclouds object. + """ + # instantiate new pointcloud with the representation which is not None + # (either list or tensor) to save compute. 
+ new_points, new_normals, new_features = None, None, None + if self._points_list is not None: + new_points = [v.clone() for v in self.points_list()] + normals_list = self.normals_list() + features_list = self.features_list() + if normals_list is not None: + new_normals = [n.clone() for n in normals_list] + if features_list is not None: + new_features = [f.clone() for f in features_list] + elif self._points_padded is not None: + new_points = self.points_padded().clone() + normals_padded = self.normals_padded() + features_padded = self.features_padded() + if normals_padded is not None: + new_normals = self.normals_padded().clone() + if features_padded is not None: + new_features = self.features_padded().clone() + other = Pointclouds( + points=new_points, normals=new_normals, features=new_features + ) + for k in self._INTERNAL_TENSORS: + v = getattr(self, k) + if torch.is_tensor(v): + setattr(other, k, v.clone()) + return other + + def to(self, device, copy: bool = False): + """ + Match functionality of torch.Tensor.to() + If copy = True or the self Tensor is on a different device, the + returned tensor is a copy of self with the desired torch.device. + If copy = False and the self Tensor already has the correct torch.device, + then self is returned. + + Args: + device: Device id for the new tensor. + copy: Boolean indicator whether or not to clone self. Default False. + + Returns: + Pointclouds object. 
+ """ + if not copy and self.device == device: + return self + other = self.clone() + if self.device != device: + other.device = device + if other._N > 0: + other._points_list = [v.to(device) for v in other.points_list()] + if other._normals_list is not None: + other._normals_list = [ + n.to(device) for n in other.normals_list() + ] + if other._features_list is not None: + other._features_list = [ + f.to(device) for f in other.features_list() + ] + for k in self._INTERNAL_TENSORS: + v = getattr(self, k) + if torch.is_tensor(v): + setattr(other, k, v.to(device)) + return other + + def cpu(self): + return self.to(torch.device("cpu")) + + def cuda(self): + return self.to(torch.device("cuda")) + + def get_cloud(self, index: int): + """ + Get tensors for a single cloud from the list representation. + + Args: + index: Integer in the range [0, N). + + Returns: + points: Tensor of shape (P, 3). + normals: Tensor of shape (P, 3) + features: LongTensor of shape (P, C). + """ + if not isinstance(index, int): + raise ValueError("Cloud index must be an integer.") + if index < 0 or index > self._N: + raise ValueError( + "Cloud index must be in the range [0, N) where \ + N is the number of clouds in the batch." + ) + points = self.points_list()[index] + normals, features = None, None + if self.normals_list() is not None: + normals = self.normals_list()[index] + if self.features_list() is not None: + features = self.features_list()[index] + return points, normals, features + + # TODO(nikhilar) Move function to a utils file. + def split(self, split_sizes: list): + """ + Splits Pointclouds object of size N into a list of Pointclouds objects + of size len(split_sizes), where the i-th Pointclouds object is of size + split_sizes[i]. Similar to torch.split(). + + Args: + split_sizes: List of integer sizes of Pointclouds objects to be + returned. + + Returns: + list[PointClouds]. 
+ """ + if not all(isinstance(x, int) for x in split_sizes): + raise ValueError("Value of split_sizes must be a list of integers.") + cloudlist = [] + curi = 0 + for i in split_sizes: + cloudlist.append(self[curi : curi + i]) + curi += i + return cloudlist + + def offset_(self, offsets_packed): + """ + Translate the point clouds by an offset. In place operation. + + Args: + offsets_packed: A Tensor of the same shape as self.points_packed + giving offsets to be added to all points. + Returns: + self. + """ + points_packed = self.points_packed() + if offsets_packed.shape != points_packed.shape: + raise ValueError("Offsets must have dimension (all_p, 3).") + self._points_packed = points_packed + offsets_packed + new_points_list = list( + self._points_packed.split(self.num_points_per_cloud().tolist(), 0) + ) + # Note that since _compute_packed() has been executed, points_list + # cannot be None even if not provided during construction. + self._points_list = new_points_list + if self._points_padded is not None: + for i, points in enumerate(new_points_list): + if len(points) > 0: + self._points_padded[i, : points.shape[0], :] = points + return self + + # TODO(nikhilar) Move out of place operator to a utils file. + def offset(self, offsets_packed): + """ + Out of place offset. + + Args: + offsets_packed: A Tensor of the same shape as self.points_packed + giving offsets to be added to all points. + Returns: + new Pointclouds object. + """ + new_clouds = self.clone() + return new_clouds.offset_(offsets_packed) + + def scale_(self, scale): + """ + Multiply the coordinates of this object by a scalar value. + - i.e. enlarge/dilate + In place operation. + + Args: + scale: A scalar, or a Tensor of shape (N,). + + Returns: + self. 
+ """ + if not torch.is_tensor(scale): + scale = torch.full(len(self), scale) + new_points_list = [] + points_list = self.points_list() + for i, old_points in enumerate(points_list): + new_points_list.append(scale[i] * old_points) + self._points_list = new_points_list + if self._points_packed is not None: + self._points_packed = torch.cat(new_points_list, dim=0) + if self._points_padded is not None: + for i, points in enumerate(new_points_list): + if len(points) > 0: + self._points_padded[i, : points.shape[0], :] = points + return self + + def scale(self, scale): + """ + Out of place scale_. + + Args: + scale: A scalar, or a Tensor of shape (N,). + + Returns: + new Pointclouds object. + """ + new_clouds = self.clone() + return new_clouds.scale_(scale) + + # TODO(nikhilar) Move function to utils file. + def get_bounding_boxes(self): + """ + Compute an axis-aligned bounding box for each cloud. + + Returns: + bboxes: Tensor of shape (N, 3, 2) where bbox[i, j] gives the + min and max values of cloud i along the jth coordinate axis. + """ + all_mins, all_maxes = [], [] + for points in self.points_list(): + cur_mins = points.min(dim=0)[0] # (3,) + cur_maxes = points.max(dim=0)[0] # (3,) + all_mins.append(cur_mins) + all_maxes.append(cur_maxes) + all_mins = torch.stack(all_mins, dim=0) # (N, 3) + all_maxes = torch.stack(all_maxes, dim=0) # (N, 3) + bboxes = torch.stack([all_mins, all_maxes], dim=2) + return bboxes + + def extend(self, N: int): + """ + Create new Pointclouds which contains each cloud N times. + + Args: + N: number of new copies of each cloud. + + Returns: + new Pointclouds object. 
def update_padded(
    self,
    new_points_padded,
    new_normals_padded=None,
    new_features_padded=None,
):
    """
    Build a new Pointclouds whose padded points (and optionally normals and
    features) are replaced by the given tensors, reusing the auxiliary
    index tensors of this object. This avoids converting to the list
    representation for heterogeneous batches.

    Args:
        new_points_padded: FloatTensor of shape (N, P, 3)
        new_normals_padded: (optional) FloatTensor of shape (N, P, 3)
        new_features_padded: (optional) FloatTensors of shape (N, P, C)

    Returns:
        Pointcloud with updated padded representations

    Raises:
        ValueError: if a new tensor disagrees with this object in batch
            size, number of points, or number of channels.
    """

    def _require_shape(tensor, expected):
        # expected is [batch, points, channels]; channels may be None, in
        # which case the last dimension is not checked.
        batch, num_points, channels = expected
        if tensor.shape[0] != batch:
            raise ValueError("new values must have the same batch dimension.")
        if tensor.shape[1] != num_points:
            raise ValueError("new values must have the same number of points.")
        if channels is not None and tensor.shape[2] != channels:
            raise ValueError(
                "new values must have the same number of channels."
            )

    normals_given = new_normals_padded is not None
    features_given = new_features_padded is not None

    _require_shape(new_points_padded, [self._N, self._P, 3])
    if normals_given:
        _require_shape(new_normals_padded, [self._N, self._P, 3])
    if features_given:
        _require_shape(new_features_padded, [self._N, self._P, self._C])

    updated = Pointclouds(
        points=new_points_padded,
        normals=new_normals_padded,
        features=new_features_padded,
    )

    # overwrite the equisized flag
    updated.equisized = self.equisized

    # Auxiliary tensors are shared with this object (shallow copy); only
    # attributes that currently hold a tensor are carried over.
    for attr_name in (
        "_packed_to_cloud_idx",
        "_cloud_to_packed_first_idx",
        "_num_points_per_cloud",
        "_padded_to_packed_idx",
        "valid",
    ):
        value = getattr(self, attr_name)
        if torch.is_tensor(value):
            setattr(updated, attr_name, value)

    # Points: only the padded representation is populated.
    updated._points_padded = new_points_padded
    assert updated._points_list is None
    assert updated._points_packed is None

    # Normals: take the new padded tensor if given, else keep the old
    # representations (shallow copy).
    if normals_given:
        updated._normals_padded = new_normals_padded
        updated._normals_list = None
        updated._normals_packed = None
    else:
        updated._normals_list = self._normals_list
        updated._normals_padded = self._normals_padded
        updated._normals_packed = self._normals_packed

    # Features: same policy as normals.
    if features_given:
        updated._features_padded = new_features_padded
        updated._features_list = None
        updated._features_packed = None
    else:
        updated._features_list = self._features_list
        updated._features_padded = self._features_padded
        updated._features_packed = self._features_packed

    return updated
def bm_compute_packed_padded_pointclouds() -> None:
    """
    Benchmark the packed and padded computations of Pointclouds over every
    combination of batch size, maximum points per cloud and feature width.
    """
    cloud_counts = [32, 128]
    point_caps = [100, 10000]
    feature_dims = [1, 10, 300]
    kwargs_list = [
        {"num_clouds": n, "max_p": p, "features": f}
        for n, p, f in product(cloud_counts, point_caps, feature_dims)
    ]
    # Both benchmarks run over the same grid of configurations.
    targets = (
        (TestPointclouds.compute_packed_with_init, "COMPUTE_PACKED"),
        (TestPointclouds.compute_padded_with_init, "COMPUTE_PADDED"),
    )
    for bench_fn, label in targets:
        benchmark(bench_fn, label, kwargs_list, warmup_iters=1)
def bm_python_vs_cpu() -> None:
    """
    Benchmark point rasterization: first the Python implementation against
    the CPU implementation on small inputs, then the CPU implementation
    against CUDA on larger inputs.

    The CUDA benchmark is skipped when no CUDA device is available, so the
    Python/CPU numbers can still be collected on CPU-only hosts.
    """
    kwargs_list = [
        {"N": 1, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
        {"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
    ]
    benchmark(
        _bm_python_with_init, "RASTERIZE_PYTHON", kwargs_list, warmup_iters=1
    )
    benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1)
    kwargs_list = [
        {"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
        {"N": 4, "P": 1024, "img_size": 128, "radius": 0.05, "pts_per_pxl": 5},
    ]
    benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1)
    # _bm_cuda_with_init allocates its inputs on a CUDA device, which raises
    # on machines without one.
    if torch.cuda.is_available():
        benchmark(
            _bm_cuda_with_init, "RASTERIZE_CUDA", kwargs_list, warmup_iters=1
        )
def test_simple(self):
    """
    Build three clouds with 3, 4 and 5 points and check the index tensors
    derived from them against hand-written expected values.
    """
    device = torch.device("cuda:0")
    raw_coords = [
        [[0.1, 0.3, 0.5], [0.5, 0.2, 0.1], [0.6, 0.8, 0.7]],
        [
            [0.1, 0.3, 0.3],
            [0.6, 0.7, 0.8],
            [0.2, 0.3, 0.4],
            [0.1, 0.5, 0.3],
        ],
        [
            [0.7, 0.3, 0.6],
            [0.2, 0.4, 0.8],
            [0.9, 0.5, 0.2],
            [0.2, 0.3, 0.4],
            [0.9, 0.3, 0.8],
        ],
    ]
    points = [
        torch.tensor(coords, dtype=torch.float32, device=device)
        for coords in raw_coords
    ]
    clouds = Pointclouds(points)

    # (getter name, expected values), checked in this order.
    expectations = (
        ("packed_to_cloud_idx", [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]),
        ("cloud_to_packed_first_idx", [0, 3, 7]),
        ("num_points_per_cloud", [3, 4, 5]),
        (
            "padded_to_packed_idx",
            [0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 14],
        ),
    )
    for getter_name, expected in expectations:
        actual = getattr(clouds, getter_name)().cpu()
        self.assertClose(actual, torch.tensor(expected))
"normals_list", + "normals_packed", + "normals_padded", + ] + public_features_getters = [ + "features_list", + "features_packed", + "features_padded", + ] + + lengths = [3, 4, 2] + max_len = max(lengths) + C = 4 + + points_data = [torch.zeros((max_len, 3)).uniform_() for i in lengths] + normals_data = [torch.zeros((max_len, 3)).uniform_() for i in lengths] + features_data = [torch.zeros((max_len, C)).uniform_() for i in lengths] + for length, p, n, f in zip( + lengths, points_data, normals_data, features_data + ): + p[length:] = 0.0 + n[length:] = 0.0 + f[length:] = 0.0 + points_list = [d[:length] for length, d in zip(lengths, points_data)] + normals_list = [d[:length] for length, d in zip(lengths, normals_data)] + features_list = [ + d[:length] for length, d in zip(lengths, features_data) + ] + points_packed = torch.cat(points_data) + normals_packed = torch.cat(normals_data) + features_packed = torch.cat(features_data) + test_cases_inputs = [ + ("list_0_0", points_list, None, None), + ("list_1_0", points_list, normals_list, None), + ("list_0_1", points_list, None, features_list), + ("list_1_1", points_list, normals_list, features_list), + ("padded_0_0", points_data, None, None), + ("padded_1_0", points_data, normals_data, None), + ("padded_0_1", points_data, None, features_data), + ("padded_1_1", points_data, normals_data, features_data), + ("emptylist_emptylist_emptylist", [], [], []), + ] + false_cases_inputs = [ + ( + "list_packed", + points_list, + normals_packed, + features_packed, + ValueError, + ), + ("packed_0", points_packed, None, None, ValueError), + ] + + for name, points, normals, features in test_cases_inputs: + with self.subTest(name=name): + p = Pointclouds(points, normals, features) + for method in public_getters: + self.assertIsNotNone(getattr(p, method)()) + for method in public_normals_getters: + if normals is None or p.isempty(): + self.assertIsNone(getattr(p, method)()) + for method in public_features_getters: + if features is None or 
p.isempty(): + self.assertIsNone(getattr(p, method)()) + + for name, points, normals, features, error in false_cases_inputs: + with self.subTest(name=name): + with self.assertRaises(error): + Pointclouds(points, normals, features) + + def test_simple_random_clouds(self): + # Define the test object either from lists or tensors. + for with_normals in (False, True): + for with_features in (False, True): + for lists_to_tensors in (False, True): + N = 10 + cloud = self.init_cloud( + N, + lists_to_tensors=lists_to_tensors, + with_normals=with_normals, + with_features=with_features, + ) + points_list = cloud.points_list() + normals_list = cloud.normals_list() + features_list = cloud.features_list() + + # Check batch calculations. + points_padded = cloud.points_padded() + normals_padded = cloud.normals_padded() + features_padded = cloud.features_padded() + points_per_cloud = cloud.num_points_per_cloud() + + if not with_normals: + self.assertIsNone(normals_list) + self.assertIsNone(normals_padded) + if not with_features: + self.assertIsNone(features_list) + self.assertIsNone(features_padded) + for n in range(N): + p = points_list[n].shape[0] + self.assertClose( + points_padded[n, :p, :], points_list[n] + ) + if with_normals: + norms = normals_list[n].shape[0] + self.assertEqual(p, norms) + self.assertClose( + normals_padded[n, :p, :], normals_list[n] + ) + if with_features: + f = features_list[n].shape[0] + self.assertEqual(p, f) + self.assertClose( + features_padded[n, :p, :], features_list[n] + ) + if points_padded.shape[1] > p: + self.assertTrue(points_padded[n, p:, :].eq(0).all()) + if with_features: + self.assertTrue( + features_padded[n, p:, :].eq(0).all() + ) + self.assertEqual(points_per_cloud[n], p) + + # Check compute packed. 
def test_empty(self):
    """
    Check the padded tensors of a batch that mixes empty and non-empty
    clouds, with and without normals and features.

    For every cloud the first p rows of each padded tensor must match the
    list entry, the remaining rows must be zero, and num_points_per_cloud
    must report p.
    """
    N, P, C = 10, 100, 2
    device = torch.device("cuda:0")
    points_list = []
    normals_list = []
    features_list = []
    # Randomly pick which of the N clouds are non-empty.
    valid = torch.randint(2, size=(N,), dtype=torch.uint8, device=device)
    for n in range(N):
        if valid[n]:
            # Convert the sampled count to a plain int before using it as
            # a tensor dimension.
            p = torch.randint(
                3, high=P, size=(1,), dtype=torch.int32, device=device
            )[0].item()
            points = torch.rand((p, 3), dtype=torch.float32, device=device)
            normals = torch.rand((p, 3), dtype=torch.float32, device=device)
            features = torch.rand(
                (p, C), dtype=torch.float32, device=device
            )
        else:
            points = torch.tensor([], dtype=torch.float32, device=device)
            normals = torch.tensor([], dtype=torch.float32, device=device)
            # Use the same dtype as the non-empty feature tensors so the
            # batch does not mix float32 and int64 features.
            features = torch.tensor([], dtype=torch.float32, device=device)
        points_list.append(points)
        normals_list.append(normals)
        features_list.append(features)

    for with_normals in (False, True):
        for with_features in (False, True):
            this_features, this_normals = None, None
            if with_normals:
                this_normals = normals_list
            if with_features:
                this_features = features_list
            clouds = Pointclouds(
                points=points_list,
                normals=this_normals,
                features=this_features,
            )
            points_padded = clouds.points_padded()
            normals_padded = clouds.normals_padded()
            features_padded = clouds.features_padded()
            if not with_normals:
                self.assertIsNone(normals_padded)
            if not with_features:
                self.assertIsNone(features_padded)
            points_per_cloud = clouds.num_points_per_cloud()
            for n in range(N):
                p = len(points_list[n])
                if p > 0:
                    self.assertClose(
                        points_padded[n, :p, :], points_list[n]
                    )
                    if with_normals:
                        self.assertClose(
                            normals_padded[n, :p, :], normals_list[n]
                        )
                    if with_features:
                        self.assertClose(
                            features_padded[n, :p, :], features_list[n]
                        )
                # Everything past the real points must be zero padding.
                if points_padded.shape[1] > p:
                    self.assertTrue(points_padded[n, p:, :].eq(0).all())
                    if with_normals:
                        self.assertTrue(
                            normals_padded[n, p:, :].eq(0).all()
                        )
                    if with_features:
                        self.assertTrue(
                            features_padded[n, p:, :].eq(0).all()
                        )
                self.assertTrue(points_per_cloud[n] == p)
def assertCloudsEqual(self, cloud1, cloud2):
    """
    Assert that two Pointclouds objects hold the same data.

    Compares the list, padded and packed representations of points (and of
    normals / features when present), the auxiliary index tensors, and the
    `valid` and `equisized` attributes.

    Args:
        cloud1: first Pointclouds object.
        cloud2: second Pointclouds object.
    """
    N = len(cloud1)
    self.assertEqual(N, len(cloud2))

    # Decide up front whether normals/features exist: the list getters
    # return None when they are absent, and indexing None would raise
    # TypeError instead of giving a meaningful assertion.
    has_normals = cloud1.normals_list() is not None
    self.assertTrue(has_normals == (cloud2.normals_list() is not None))
    has_features = cloud1.features_list() is not None
    self.assertTrue(has_features == (cloud2.features_list() is not None))

    for i in range(N):
        self.assertClose(cloud1.points_list()[i], cloud2.points_list()[i])
        if has_normals:
            self.assertClose(
                cloud1.normals_list()[i], cloud2.normals_list()[i]
            )
        if has_features:
            self.assertClose(
                cloud1.features_list()[i], cloud2.features_list()[i]
            )

    # check padded & packed
    self.assertClose(cloud1.points_padded(), cloud2.points_padded())
    self.assertClose(cloud1.points_packed(), cloud2.points_packed())
    if has_normals:
        self.assertClose(cloud1.normals_padded(), cloud2.normals_padded())
        self.assertClose(cloud1.normals_packed(), cloud2.normals_packed())
    if has_features:
        self.assertClose(cloud1.features_padded(), cloud2.features_padded())
        self.assertClose(cloud1.features_packed(), cloud2.features_packed())

    # check auxiliary index tensors
    self.assertClose(
        cloud1.packed_to_cloud_idx(), cloud2.packed_to_cloud_idx()
    )
    self.assertClose(
        cloud1.cloud_to_packed_first_idx(),
        cloud2.cloud_to_packed_first_idx(),
    )
    self.assertClose(
        cloud1.num_points_per_cloud(), cloud2.num_points_per_cloud()
    )
    self.assertClose(
        cloud1.padded_to_packed_idx(), cloud2.padded_to_packed_idx()
    )
    self.assertTrue(all(cloud1.valid == cloud2.valid))
    self.assertTrue(cloud1.equisized == cloud2.equisized)
def test_scale(self):
    """
    Check Pointclouds.scale against a naive reference that rebuilds the
    clouds from individually scaled point tensors.

    The comparison runs twice: once on the freshly built clouds and once
    after forcing the packed/padded representations to be computed.
    """

    def naive_scale(cloud, scale):
        # Reference implementation: scale each cloud's points separately
        # and construct a new Pointclouds from the results.
        if not torch.is_tensor(scale):
            # torch.full expects a shape tuple as its first argument.
            scale = torch.full((len(cloud),), scale)
        new_points_list = [
            scale[i] * points.clone()
            for (i, points) in enumerate(cloud.points_list())
        ]
        return Pointclouds(
            new_points_list, cloud.normals_list(), cloud.features_list()
        )

    N = 5
    clouds = self.init_cloud(N, 100, 10)
    for force in (False, True):
        if force:
            # force the cached representations to exist before scaling
            clouds._compute_packed(refresh=True)
            clouds._compute_padded()
            clouds.padded_to_packed_idx()
        scales = torch.rand(N)
        new_clouds_naive = naive_scale(clouds, scales)
        new_clouds = clouds.scale(scales)
        for i in range(N):
            self.assertClose(
                scales[i] * clouds.points_list()[i],
                new_clouds.points_list()[i],
            )
            # Normals and features are passed through unscaled.
            self.assertClose(
                clouds.normals_list()[i], new_clouds_naive.normals_list()[i]
            )
            self.assertClose(
                clouds.features_list()[i],
                new_clouds_naive.features_list()[i],
            )
        self.assertCloudsEqual(new_clouds, new_clouds_naive)
def test_to_list(self):
    """
    Check that Pointclouds.to moves a list-constructed batch to another
    CUDA device without changing its contents, and leaves the original on
    its device.

    Skipped when fewer than two CUDA devices are present, since the test
    moves data from cuda:0 to cuda:1.
    """
    if torch.cuda.device_count() < 2:
        self.skipTest("test_to_list requires at least two CUDA devices.")

    cloud = self.init_cloud(5, 100, 10)
    device = torch.device("cuda:1")

    new_cloud = cloud.to(device)
    self.assertTrue(new_cloud.device == device)
    self.assertTrue(cloud.device == torch.device("cuda:0"))
    for attrib in [
        "points_padded",
        "points_packed",
        "normals_padded",
        "normals_packed",
        "features_padded",
        "features_packed",
        "num_points_per_cloud",
        "cloud_to_packed_first_idx",
        "padded_to_packed_idx",
    ]:
        # Compare on CPU because the two clouds live on different devices.
        self.assertClose(
            getattr(new_cloud, attrib)().cpu(),
            getattr(cloud, attrib)().cpu(),
        )
    for i in range(len(cloud)):
        self.assertClose(
            cloud.points_list()[i].cpu(), new_cloud.points_list()[i].cpu()
        )
        self.assertClose(
            cloud.normals_list()[i].cpu(), new_cloud.normals_list()[i].cpu()
        )
        self.assertClose(
            cloud.features_list()[i].cpu(),
            new_cloud.features_list()[i].cpu(),
        )
    self.assertTrue(all(cloud.valid.cpu() == new_cloud.valid.cpu()))
    self.assertTrue(cloud.equisized == new_cloud.equisized)
    self.assertTrue(cloud._N == new_cloud._N)
    self.assertTrue(cloud._P == new_cloud._P)
    self.assertTrue(cloud._C == new_cloud._C)
def test_get_cloud(self):
    """
    Check that get_cloud(i) returns the i-th points, normals and features,
    and that an out-of-range or non-integer index raises ValueError.
    """
    clouds = self.init_cloud(2, 100, 10)
    for idx in range(len(clouds)):
        points, normals, features = clouds.get_cloud(idx)
        self.assertClose(points, clouds.points_list()[idx])
        self.assertClose(normals, clouds.normals_list()[idx])
        self.assertClose(features, clouds.features_list()[idx])

    # 5 is past the end of the two-cloud batch; 0.2 is not an integer.
    for bad_index in (5, 0.2):
        with self.assertRaises(ValueError):
            clouds.get_cloud(bad_index)
def test_padded_to_packed_idx(self):
    """
    Check that padded_to_packed_idx selects, from the flattened padded
    points, exactly the rows of the packed points. Both plain indexing and
    `gather` are exercised.
    """
    device = torch.device("cuda:0")
    npoints = [10, 20, 30]
    points_list = [
        torch.rand((count, 3), dtype=torch.float32, device=device)
        for count in npoints
    ]
    clouds = Pointclouds(points_list)

    pad_to_pack = clouds.padded_to_packed_idx()
    packed = clouds.points_packed()
    padded_flat = clouds.points_padded().view(-1, 3)

    # Direct row indexing.
    self.assertClose(padded_flat[pad_to_pack], packed)

    # Same selection through gather, with the index expanded over xyz.
    gather_idx = pad_to_pack.view(-1, 1).expand(-1, 3)
    self.assertClose(padded_flat.gather(0, gather_idx), packed)
def test_update_padded(self):
    """
    Check that `update_padded` returns clouds whose padded, packed and list
    views agree with the replacement tensors.

    Every combination of "the source cloud has normals/features" and "new
    normals/features are supplied" is exercised. When no replacements are
    supplied, the normals/features of the result must match those of the
    source cloud and pass `assertNotSeparate`, or be None when the source
    cloud had none.
    """
    N, P, C = 5, 100, 4
    for with_normfeat in (True, False):
        for with_new_normfeat in (True, False):
            clouds = self.init_cloud(
                N,
                P,
                C,
                with_normals=with_normfeat,
                with_features=with_normfeat,
            )
            counts = clouds.num_points_per_cloud()
            padded_shape = clouds.points_padded().shape

            # Replacement points are always supplied; the per-cloud list
            # is the padded tensor with the padding rows stripped.
            new_points = torch.rand(padded_shape, device=clouds.device)
            new_points_list = [
                cloud[:count] for cloud, count in zip(new_points, counts)
            ]

            # Replacement normals/features are optional.
            new_normals = new_normals_list = None
            new_features = new_features_list = None
            if with_new_normfeat:
                new_normals = torch.rand(padded_shape, device=clouds.device)
                new_normals_list = [
                    cloud[:count]
                    for cloud, count in zip(new_normals, counts)
                ]
                new_features = torch.rand(
                    [padded_shape[0], padded_shape[1], C],
                    device=clouds.device,
                )
                new_features_list = [
                    cloud[:count]
                    for cloud, count in zip(new_features, counts)
                ]

            new_clouds = clouds.update_padded(
                new_points, new_normals, new_features
            )

            # The list/packed caches must not be populated by the update.
            self.assertIsNone(new_clouds._points_list)
            self.assertIsNone(new_clouds._points_packed)

            self.assertEqual(new_clouds.equisized, clouds.equisized)
            self.assertTrue(all(new_clouds.valid == clouds.valid))

            # Points in all three representations.
            self.assertClose(new_clouds.points_padded(), new_points)
            self.assertClose(
                new_clouds.points_packed(), torch.cat(new_points_list)
            )
            for actual, expected in zip(
                new_clouds.points_list(), new_points_list
            ):
                self.assertClose(actual, expected)

            if with_new_normfeat:
                # The replacements must show up in every representation.
                for i in range(N):
                    self.assertClose(
                        new_clouds.normals_list()[i], new_normals_list[i]
                    )
                    self.assertClose(
                        new_clouds.features_list()[i], new_features_list[i]
                    )
                self.assertClose(new_clouds.normals_padded(), new_normals)
                self.assertClose(
                    new_clouds.normals_packed(), torch.cat(new_normals_list)
                )
                self.assertClose(new_clouds.features_padded(), new_features)
                self.assertClose(
                    new_clouds.features_packed(),
                    torch.cat(new_features_list),
                )
            elif with_normfeat:
                # Nothing new supplied: values carry over from the source.
                for i in range(N):
                    self.assertClose(
                        new_clouds.normals_list()[i],
                        clouds.normals_list()[i],
                    )
                    self.assertClose(
                        new_clouds.features_list()[i],
                        clouds.features_list()[i],
                    )
                    self.assertNotSeparate(
                        new_clouds.normals_list()[i],
                        clouds.normals_list()[i],
                    )
                    self.assertNotSeparate(
                        new_clouds.features_list()[i],
                        clouds.features_list()[i],
                    )

                self.assertClose(
                    new_clouds.normals_padded(), clouds.normals_padded()
                )
                self.assertClose(
                    new_clouds.normals_packed(), clouds.normals_packed()
                )
                self.assertClose(
                    new_clouds.features_padded(), clouds.features_padded()
                )
                self.assertClose(
                    new_clouds.features_packed(), clouds.features_packed()
                )
                self.assertNotSeparate(
                    new_clouds.normals_padded(), clouds.normals_padded()
                )
                self.assertNotSeparate(
                    new_clouds.features_padded(), clouds.features_padded()
                )
            else:
                # The source had no normals/features and none were given.
                self.assertIsNone(new_clouds.normals_list())
                self.assertIsNone(new_clouds.features_list())
                self.assertIsNone(new_clouds.normals_padded())
                self.assertIsNone(new_clouds.features_padded())
                self.assertIsNone(new_clouds.normals_packed())
                self.assertIsNone(new_clouds.features_packed())

            # Index bookkeeping must be unchanged by the update.
            for attrib in (
                "num_points_per_cloud",
                "cloud_to_packed_first_idx",
                "padded_to_packed_idx",
            ):
                self.assertClose(
                    getattr(new_clouds, attrib)(), getattr(clouds, attrib)()
                )


@staticmethod
def compute_packed_with_init(
    num_clouds: int = 10, max_p: int = 100, features: int = 300
):
    """
    Build a batch of clouds and return a zero-argument callable that
    recomputes its packed representation.

    `torch.cuda.synchronize()` is called after the setup and after every
    recomputation, so a CUDA-enabled torch build is required.
    """
    pointclouds = TestPointclouds.init_cloud(num_clouds, max_p, features)
    torch.cuda.synchronize()

    def compute_packed():
        pointclouds._compute_packed(refresh=True)
        torch.cuda.synchronize()

    return compute_packed


@staticmethod
def compute_padded_with_init(
    num_clouds: int = 10, max_p: int = 100, features: int = 300
):
    """
    Build a batch of clouds and return a zero-argument callable that
    recomputes its padded representation.

    `torch.cuda.synchronize()` is called after the setup and after every
    recomputation, so a CUDA-enabled torch build is required.
    """
    pointclouds = TestPointclouds.init_cloud(num_clouds, max_p, features)
    torch.cuda.synchronize()

    def compute_padded():
        pointclouds._compute_padded(refresh=True)
        torch.cuda.synchronize()

    return compute_padded
import numpy as np
import unittest
import torch

from pytorch3d import _C
from pytorch3d.renderer.points.rasterize_points import (
    rasterize_points,
    rasterize_points_python,
)
from pytorch3d.structures.pointclouds import Pointclouds

from common_testing import TestCaseMixin


class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
    """
    Tests for point cloud rasterization.

    The pure Python, C++ (CPU) and CUDA code paths are checked against
    hand-computed outputs and against each other. Tests with "cuda" in their
    name, as well as `test_cpp_vs_naive_vs_binned`, need a CUDA device.
    """

    def test_python_simple_cpu(self):
        self._simple_test_case(
            rasterize_points_python, torch.device("cpu"), bin_size=-1
        )

    def test_naive_simple_cpu(self):
        device = torch.device("cpu")
        self._simple_test_case(rasterize_points, device)

    def test_naive_simple_cuda(self):
        device = torch.device("cuda")
        self._simple_test_case(rasterize_points, device, bin_size=0)

    def test_python_behind_camera(self):
        self._test_behind_camera(
            rasterize_points_python, torch.device("cpu"), bin_size=-1
        )

    def test_cpu_behind_camera(self):
        self._test_behind_camera(rasterize_points, torch.device("cpu"))

    def test_cuda_behind_camera(self):
        self._test_behind_camera(
            rasterize_points, torch.device("cuda"), bin_size=0
        )

    def test_cpp_vs_naive_vs_binned(self):
        """
        Run forward and backward on CPU (naive), CUDA (naive) and CUDA
        (binned) and check that outputs and gradients agree.
        """
        # Seed the generator so that a failure can be reproduced.
        torch.manual_seed(231)
        N = 2
        P = 1000
        image_size = 32
        radius = 0.1
        points_per_pixel = 3
        points1 = torch.randn(P, 3, requires_grad=True)
        points2 = torch.randn(P // 2, 3, requires_grad=True)
        grad_zbuf = torch.randn(N, image_size, image_size, points_per_pixel)
        grad_dists = torch.randn(N, image_size, image_size, points_per_pixel)

        # Option I: CPU, naive
        pointclouds = Pointclouds(points=[points1, points2])
        idx1, zbuf1, dists1 = rasterize_points(
            pointclouds, image_size, radius, points_per_pixel, bin_size=0
        )
        loss = (zbuf1 * grad_zbuf).sum() + (dists1 * grad_dists).sum()
        loss.backward()
        grad1 = points1.grad.detach().clone()

        grad_zbuf_cuda = grad_zbuf.cuda()
        grad_dists_cuda = grad_dists.cuda()

        def run_on_cuda(bin_size):
            # Fresh leaf tensors, so gradients never accumulate across runs.
            p1 = points1.cuda().detach().clone().requires_grad_(True)
            p2 = points2.cuda().detach().clone().requires_grad_(True)
            clouds = Pointclouds(points=[p1, p2])
            idx, zbuf, dists = rasterize_points(
                clouds, image_size, radius, points_per_pixel, bin_size=bin_size
            )
            loss = (zbuf * grad_zbuf_cuda).sum()
            loss = loss + (dists * grad_dists_cuda).sum()
            loss.backward()
            return (
                idx.detach().cpu().clone(),
                zbuf.detach().cpu().clone(),
                dists.detach().cpu().clone(),
                p1.grad.detach().cpu().clone(),
            )

        # Option II: CUDA, naive
        idx2, zbuf2, dists2, grad2 = run_on_cuda(bin_size=0)
        # Option III: CUDA, binned
        idx3, zbuf3, dists3, grad3 = run_on_cuda(bin_size=32)

        # Make sure everything was the same
        for idx, zbuf, dists in ((idx2, zbuf2, dists2), (idx3, zbuf3, dists3)):
            self.assertTrue((idx1 == idx).all().item())
            self.assertTrue((zbuf1 == zbuf).all().item())
            self.assertLess((dists1 - dists).abs().max().item(), 1e-6)

        self.assertLess((grad1 - grad2).abs().max().item(), 5e-6)
        self.assertLess((grad1 - grad3).abs().max().item(), 5e-6)
        self.assertLess((grad2 - grad3).abs().max().item(), 5e-6)

    def test_python_vs_cpu_naive(self):
        torch.manual_seed(231)
        image_size = 32
        radius = 0.1
        points_per_pixel = 3

        # Test a batch of homogeneous point clouds.
        N = 2
        P = 17
        points = torch.randn(N, P, 3, requires_grad=True)
        pointclouds = Pointclouds(points=points)
        args = (pointclouds, image_size, radius, points_per_pixel)
        self._compare_impls(
            rasterize_points_python,
            rasterize_points,
            args,
            args,
            points,
            points,
            compare_grads=True,
        )

        # Test a batch of heterogeneous point clouds.
        P2 = 10
        points1 = torch.randn(P, 3, requires_grad=True)
        points2 = torch.randn(P2, 3)
        pointclouds = Pointclouds(points=[points1, points2])
        args = (pointclouds, image_size, radius, points_per_pixel)
        self._compare_impls(
            rasterize_points_python,
            rasterize_points,
            args,
            args,
            points1,  # check gradients for first element in batch
            points1,
            compare_grads=True,
        )

    def test_cpu_vs_cuda_naive(self):
        torch.manual_seed(231)
        image_size = 64
        radius = 0.1
        points_per_pixel = 5

        # Test homogeneous point cloud batch.
        N = 2
        P = 1000
        bin_size = 0
        points_cpu = torch.rand(N, P, 3, requires_grad=True)
        points_cuda = points_cpu.cuda().detach().requires_grad_(True)
        pointclouds_cpu = Pointclouds(points=points_cpu)
        pointclouds_cuda = Pointclouds(points=points_cuda)
        args_cpu = (
            pointclouds_cpu,
            image_size,
            radius,
            points_per_pixel,
            bin_size,
        )
        args_cuda = (
            pointclouds_cuda,
            image_size,
            radius,
            points_per_pixel,
            bin_size,
        )
        self._compare_impls(
            rasterize_points,
            rasterize_points,
            args_cpu,
            args_cuda,
            points_cpu,
            points_cuda,
            compare_grads=True,
        )

    def _compare_impls(
        self,
        fn1,
        fn2,
        args1,
        args2,
        grad_var1=None,
        grad_var2=None,
        compare_grads=False,
    ):
        """
        Check that two rasterization functions produce the same outputs and,
        optionally, the same gradients.

        Args:
            fn1, fn2: callables returning an (idx, zbuf, dists) tuple.
            args1, args2: positional arguments for fn1 and fn2.
            grad_var1, grad_var2: tensors whose `.grad` is compared after
                backpropagating through the outputs of fn1 and fn2. They may
                be the same tensor. Required if compare_grads is True.
            compare_grads: whether to run the backward pass and compare
                gradients.

        Raises:
            ValueError: if compare_grads is True but a grad variable is
                missing.
        """
        if compare_grads and (grad_var1 is None or grad_var2 is None):
            raise ValueError("compare_grads=True needs grad_var1 and grad_var2")

        idx1, zbuf1, dist1 = fn1(*args1)
        torch.manual_seed(231)
        grad_zbuf = torch.randn_like(zbuf1)
        grad_dist = torch.randn_like(dist1)
        loss = (zbuf1 * grad_zbuf).sum() + (dist1 * grad_dist).sum()
        if compare_grads:
            loss.backward()
            grad_points1 = grad_var1.grad.detach().clone().cpu()

        idx2, zbuf2, dist2 = fn2(*args2)
        # Reuse the same upstream gradients on the device/dtype of fn2.
        grad_zbuf = grad_zbuf.to(zbuf2)
        grad_dist = grad_dist.to(dist2)
        loss = (zbuf2 * grad_zbuf).sum() + (dist2 * grad_dist).sum()
        if compare_grads:
            # clear points1.grad in case args1 and args2 reused the same tensor
            grad_var1.grad.zero_()
            loss.backward()
            grad_points2 = grad_var2.grad.detach().clone().cpu()

        self.assertTrue((idx1.cpu() == idx2.cpu()).all().item())
        self.assertTrue((zbuf1.cpu() == zbuf2.cpu()).all().item())
        self.assertClose(dist1.cpu(), dist2.cpu())
        if compare_grads:
            self.assertTrue(
                torch.allclose(grad_points1, grad_points2, atol=2e-6)
            )

    def _test_behind_camera(self, rasterize_points_fn, device, bin_size=None):
        """
        Rasterize points that all have negative z and check that every
        output entry is the padding value -1.

        Passing bin_size=-1 calls rasterize_points_fn without a bin_size
        argument (used for the pure Python implementation).
        """
        # Test case where all points are behind the camera -- nothing should
        # get rasterized
        N = 2
        P = 32
        xy = torch.randn(N, P, 2)
        z = torch.randn(N, P, 1).abs().mul(-1)  # Make them all negative
        points = torch.cat([xy, z], dim=2).to(device)
        image_size = 16
        points_per_pixel = 3
        radius = 0.2
        out_shape = (N, image_size, image_size, points_per_pixel)
        idx_expected = torch.full(
            out_shape, fill_value=-1, dtype=torch.int32, device=device
        )
        zbuf_expected = torch.full(
            out_shape, fill_value=-1, dtype=torch.float32, device=device
        )
        dists_expected = zbuf_expected.clone()
        pointclouds = Pointclouds(points=points)
        if bin_size == -1:
            # simple python case with no binning
            idx, zbuf, dists = rasterize_points_fn(
                pointclouds, image_size, radius, points_per_pixel
            )
        else:
            idx, zbuf, dists = rasterize_points_fn(
                pointclouds, image_size, radius, points_per_pixel, bin_size
            )

        self.assertTrue((idx == idx_expected).all().item())
        self.assertTrue((zbuf == zbuf_expected).all().item())
        self.assertTrue(torch.allclose(dists, dists_expected))

    def _simple_test_case(self, rasterize_points_fn, device, bin_size=0):
        """
        Rasterize two tiny point clouds onto a 5x5 image and compare idx,
        zbuf and dists against hand-computed values.

        Passing bin_size=-1 calls rasterize_points_fn without a bin_size
        argument (used for the pure Python implementation).
        """
        # Create two pointclouds with different numbers of points.
        # fmt: off
        points1 = torch.tensor(
            [
                [0.0, 0.0,  0.0],  # noqa: E241
                [0.4, 0.0,  0.1],  # noqa: E241
                [0.0, 0.4,  0.2],  # noqa: E241
                [0.0, 0.0, -0.1],  # noqa: E241 Points with negative z should be skipped
            ],
            device=device,
        )
        points2 = torch.tensor(
            [
                [0.0, 0.0,  0.0],  # noqa: E241
                [0.4, 0.0,  0.1],  # noqa: E241
                [0.0, 0.4,  0.2],  # noqa: E241
                [0.0, 0.0, -0.1],  # noqa: E241 Points with negative z should be skipped
                [0.0, 0.0, -0.7],  # noqa: E241 Points with negative z should be skipped
            ],
            device=device,
        )
        # fmt: on
        pointclouds = Pointclouds(points=[points1, points2])

        image_size = 5
        points_per_pixel = 2
        radius = 0.5

        # The expected output values. Note that in the outputs, the world space
        # +Y is up, and the world space +X is left.
        idx1_expected = torch.full(
            (1, 5, 5, 2), fill_value=-1, dtype=torch.int32, device=device
        )
        # fmt: off
        idx1_expected[0, :, :, 0] = torch.tensor([
            [-1, -1,  2, -1, -1],  # noqa: E241
            [-1,  1,  0,  2, -1],  # noqa: E241
            [ 1,  0,  0,  0, -1],  # noqa: E241 E201
            [-1,  1,  0, -1, -1],  # noqa: E241
            [-1, -1, -1, -1, -1],  # noqa: E241
        ], device=device)
        idx1_expected[0, :, :, 1] = torch.tensor([
            [-1, -1, -1, -1, -1],  # noqa: E241
            [-1,  2,  2, -1, -1],  # noqa: E241
            [-1,  1,  1, -1, -1],  # noqa: E241
            [-1, -1, -1, -1, -1],  # noqa: E241
            [-1, -1, -1, -1, -1],  # noqa: E241
        ], device=device)
        # fmt: on

        zbuf1_expected = torch.full(
            (1, 5, 5, 2), fill_value=100, dtype=torch.float32, device=device
        )
        # fmt: off
        zbuf1_expected[0, :, :, 0] = torch.tensor([
            [-1.0, -1.0,  0.2, -1.0, -1.0],  # noqa: E241
            [-1.0,  0.1,  0.0,  0.2, -1.0],  # noqa: E241
            [ 0.1,  0.0,  0.0,  0.0, -1.0],  # noqa: E241 E201
            [-1.0,  0.1,  0.0, -1.0, -1.0],  # noqa: E241
            [-1.0, -1.0, -1.0, -1.0, -1.0]   # noqa: E241
        ], device=device)
        zbuf1_expected[0, :, :, 1] = torch.tensor([
            [-1.0, -1.0, -1.0, -1.0, -1.0],  # noqa: E241
            [-1.0,  0.2,  0.2, -1.0, -1.0],  # noqa: E241
            [-1.0,  0.1,  0.1, -1.0, -1.0],  # noqa: E241
            [-1.0, -1.0, -1.0, -1.0, -1.0],  # noqa: E241
            [-1.0, -1.0, -1.0, -1.0, -1.0],  # noqa: E241
        ], device=device)
        # fmt: on

        dists1_expected = torch.full(
            (1, 5, 5, 2), fill_value=0.0, dtype=torch.float32, device=device
        )
        # fmt: off
        dists1_expected[0, :, :, 0] = torch.tensor([
            [-1.00, -1.00,  0.16, -1.00, -1.00],  # noqa: E241
            [-1.00,  0.16,  0.16,  0.16, -1.00],  # noqa: E241
            [ 0.16,  0.16,  0.00,  0.16, -1.00],  # noqa: E241 E201
            [-1.00,  0.16,  0.16, -1.00, -1.00],  # noqa: E241
            [-1.00, -1.00, -1.00, -1.00, -1.00],  # noqa: E241
        ], device=device)
        dists1_expected[0, :, :, 1] = torch.tensor([
            [-1.00, -1.00, -1.00, -1.00, -1.00],  # noqa: E241
            [-1.00,  0.16,  0.00, -1.00, -1.00],  # noqa: E241
            [-1.00,  0.00,  0.16, -1.00, -1.00],  # noqa: E241
            [-1.00, -1.00, -1.00, -1.00, -1.00],  # noqa: E241
            [-1.00, -1.00, -1.00, -1.00, -1.00],  # noqa: E241
        ], device=device)
        # fmt: on

        if bin_size == -1:
            # simple python case with no binning
            idx, zbuf, dists = rasterize_points_fn(
                pointclouds, image_size, radius, points_per_pixel
            )
        else:
            idx, zbuf, dists = rasterize_points_fn(
                pointclouds, image_size, radius, points_per_pixel, bin_size
            )

        # check first point cloud
        idx_same = (idx[0, ...] == idx1_expected).all().item()
        zbuf_same = (zbuf[0, ...] == zbuf1_expected).all().item()
        dist_same = torch.allclose(dists[0, ...], dists1_expected)
        # Show the offending indices in the failure message.
        self.assertTrue(
            idx_same,
            msg="unexpected idx for the first cloud:\n{}\n{}".format(
                idx[0, :, :, 0], idx[0, :, :, 1]
            ),
        )
        self.assertTrue(zbuf_same)
        self.assertTrue(dist_same)

        # Check second point cloud - the indices in idx refer to points in the
        # pointclouds.points_packed() tensor. In the second point cloud,
        # two points are behind the screen - the expected indices are the same
        # as for the first pointcloud but offset by the number of points in
        # the first pointcloud.
        num_points_per_cloud = pointclouds.num_points_per_cloud()
        idx1_expected[idx1_expected >= 0] += num_points_per_cloud[0]

        idx_same = (idx[1, ...] == idx1_expected).all().item()
        zbuf_same = (zbuf[1, ...] == zbuf1_expected).all().item()
        self.assertTrue(idx_same)
        self.assertTrue(zbuf_same)
        self.assertTrue(torch.allclose(dists[1, ...], dists1_expected))

    def test_coarse_cpu(self):
        self._test_coarse_rasterize(torch.device("cpu"))

    def test_coarse_cuda(self):
        self._test_coarse_rasterize(torch.device("cuda"))

    def test_compare_coarse_cpu_vs_cuda(self):
        """
        Check that coarse rasterization assigns the same points to each bin
        on CPU and on CUDA, ignoring the order within a bin.
        """
        torch.manual_seed(231)
        # torch.manual_seed does not affect NumPy, so the cloud sizes come
        # from a dedicated, seeded generator to keep the test reproducible.
        rng = np.random.RandomState(231)
        N = 3
        max_P = 1000
        image_size = 64
        radius = 0.1
        bin_size = 16
        max_points_per_bin = 500

        # create heterogeneous point clouds (at least one point each)
        points = []
        for _ in range(N):
            p = int(rng.randint(1, max_P))
            points.append(torch.randn(p, 3))

        pointclouds = Pointclouds(points=points)

        def coarse_bins(clouds):
            return _C._rasterize_points_coarse(
                clouds.points_packed(),
                clouds.cloud_to_packed_first_idx(),
                clouds.num_points_per_cloud(),
                image_size,
                radius,
                bin_size,
                max_points_per_bin,
            )

        bp_cpu = coarse_bins(pointclouds)
        # Copy the CUDA result to the host once, not once per bin below.
        bp_cuda = coarse_bins(pointclouds.to("cuda:0")).cpu()

        # Bin points might not be the same: CUDA version might write them in
        # any order. But if we sort the non-(-1) elements of the CUDA output
        # then they should be the same.
        for n in range(N):
            for by in range(bp_cpu.shape[1]):
                for bx in range(bp_cpu.shape[2]):
                    K = (bp_cpu[n, by, bx] != -1).sum().item()
                    idxs_cpu = bp_cpu[n, by, bx].tolist()
                    idxs_cuda = bp_cuda[n, by, bx].tolist()
                    idxs_cuda[:K] = sorted(idxs_cuda[:K])
                    self.assertEqual(idxs_cpu, idxs_cuda)

    def _test_coarse_rasterize(self, device):
        """
        Run coarse rasterization on six hand-placed points and compare the
        per-bin point indices against the expected assignment.
        """
        #
        #  Note that +Y is up and +X is left in the diagram below.
        #
        #  (4)              |2
        #                   |
        #                   |
        #                   |
        #                   |1
        #                   |
        #            (1)    |
        #                   | (2)
        # ____________(0)__(5)___________________
        # 2        1        |        -1       -2
        #                   |
        #       (3)         |
        #                   |
        #                   |-1
        #                   |
        #
        # Each point is drawn as (k), where k is its index in `points`. The
        # screen bounding box is between [-1, 1] in both the x and y
        # directions.
        #
        # These points are interesting because:
        # (0) Falls into two bins;
        # (1) and (2) fall into one bin;
        # (3) is out-of-bounds, but its disk is in-bounds;
        # (4) is out-of-bounds, and its entire disk is also out-of-bounds
        # (5) has a negative z-value, so it should be skipped
        # fmt: off
        points = torch.tensor(
            [
                [ 0.5,  0.0,  0.0],  # noqa: E241, E201
                [ 0.5,  0.5,  0.1],  # noqa: E241, E201
                [-0.3,  0.4,  0.0],  # noqa: E241
                [ 1.1, -0.5,  0.2],  # noqa: E241, E201
                [ 2.0,  2.0,  0.3],  # noqa: E241, E201
                [ 0.0,  0.0, -0.1],  # noqa: E241, E201
            ],
            device=device
        )
        # fmt: on
        image_size = 16
        radius = 0.2
        bin_size = 8
        max_points_per_bin = 5

        # Unused slots of every bin are padded with -1.
        bin_points_expected = torch.full(
            (1, 2, 2, max_points_per_bin),
            fill_value=-1,
            dtype=torch.int32,
            device=device,
        )
        # Note that the order is only deterministic here for CUDA if all points
        # fit in one chunk. This will be the case for this small example, but
        # to properly exercise coordinated writes among multiple chunks we need
        # to use a bigger test case.
        bin_points_expected[0, 1, 0, :2] = torch.tensor([0, 3], device=device)
        bin_points_expected[0, 0, 1, 0] = 2
        bin_points_expected[0, 0, 0, :2] = torch.tensor([0, 1], device=device)

        pointclouds = Pointclouds(points=[points])
        args = (
            pointclouds.points_packed(),
            pointclouds.cloud_to_packed_first_idx(),
            pointclouds.num_points_per_cloud(),
            image_size,
            radius,
            bin_size,
            max_points_per_bin,
        )
        bin_points = _C._rasterize_points_coarse(*args)
        bin_points_same = (bin_points == bin_points_expected).all()
        self.assertTrue(bin_points_same.item())