mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Merge pull request #114 from nikhilaravi/fixup-T64213310-master
Re-sync with internal repository
This commit is contained in:
		
						commit
						eeb6bd3b09
					
				
							
								
								
									
										1
									
								
								pytorch3d/renderer/points/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								pytorch3d/renderer/points/__init__.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
							
								
								
									
										227
									
								
								pytorch3d/renderer/points/rasterize_points.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										227
									
								
								pytorch3d/renderer/points/rasterize_points.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,227 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from pytorch3d import _C
 | 
			
		||||
from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO(jcjohns): Support non-square images
 | 
			
		||||
def rasterize_points(
 | 
			
		||||
    pointclouds,
 | 
			
		||||
    image_size: int = 256,
 | 
			
		||||
    radius: float = 0.01,
 | 
			
		||||
    points_per_pixel: int = 8,
 | 
			
		||||
    bin_size: Optional[int] = None,
 | 
			
		||||
    max_points_per_bin: Optional[int] = None,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Pointcloud rasterization
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        pointclouds: A Pointclouds object representing a batch of point clouds to be
 | 
			
		||||
            rasterized. This is a batch of N pointclouds, where each point cloud
 | 
			
		||||
            can have a different number of points; the coordinates of each point
 | 
			
		||||
            are (x, y, z). The coordinates are expected to
 | 
			
		||||
            be in normalized device coordinates (NDC): [-1, 1]^3 with the camera at
 | 
			
		||||
            (0, 0, 0); the x-axis goes from left-to-right, the y-axis goes from
 | 
			
		||||
            top-to-bottom, and the z-axis goes from back-to-front.
 | 
			
		||||
        image_size: Integer giving the resolution of the rasterized image
 | 
			
		||||
        radius (Optional): Float giving the radius (in NDC units) of the disk to
 | 
			
		||||
            be rasterized for each point.
 | 
			
		||||
        points_per_pixel (Optional): We will keep track of this many points per
 | 
			
		||||
            pixel, returning the nearest points_per_pixel points along the z-axis
 | 
			
		||||
        bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
 | 
			
		||||
            bin_size=0 uses naive rasterization; setting bin_size=None attempts to
 | 
			
		||||
            set it heuristically based on the shape of the input. This should not
 | 
			
		||||
            affect the output, but can affect the speed of the forward pass.
 | 
			
		||||
        points_per_bin: Only applicable when using coarse-to-fine rasterization
 | 
			
		||||
            (bin_size > 0); this is the maxiumum number of points allowed within each
 | 
			
		||||
            bin. If more than this many points actually fall into a bin, an error
 | 
			
		||||
            will be raised. This should not affect the output values, but can affect
 | 
			
		||||
            the memory usage in the forward pass.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        3-element tuple containing
 | 
			
		||||
 | 
			
		||||
        - **idx**: int32 Tensor of shape (N, image_size, image_size, points_per_pixel)
 | 
			
		||||
          giving the indices of the nearest points at each pixel, in ascending
 | 
			
		||||
          z-order. Concretely `idx[n, y, x, k] = p` means that `points[p]` is the kth
 | 
			
		||||
          closest point (along the z-direction) to pixel (y, x) - note that points
 | 
			
		||||
          represents the packed points of shape (P, 3).
 | 
			
		||||
          Pixels that are hit by fewer than points_per_pixel are padded with -1.
 | 
			
		||||
        - **zbuf**: Tensor of shape (N, image_size, image_size, points_per_pixel)
 | 
			
		||||
          giving the z-coordinates of the nearest points at each pixel, sorted in
 | 
			
		||||
          z-order. Concretely, if `idx[n, y, x, k] = p` then
 | 
			
		||||
          `zbuf[n, y, x, k] = points[n, p, 2]`. Pixels hit by fewer than
 | 
			
		||||
          points_per_pixel are padded with -1
 | 
			
		||||
        - **dists2**: Tensor of shape (N, image_size, image_size, points_per_pixel)
 | 
			
		||||
          giving the squared Euclidean distance (in NDC units) in the x/y plane
 | 
			
		||||
          for each point closest to the pixel. Concretely if `idx[n, y, x, k] = p`
 | 
			
		||||
          then `dists[n, y, x, k]` is the squared distance between the pixel (y, x)
 | 
			
		||||
          and the point `(points[n, p, 0], points[n, p, 1])`. Pixels hit with fewer
 | 
			
		||||
          than points_per_pixel are padded with -1.
 | 
			
		||||
    """
 | 
			
		||||
    points_packed = pointclouds.points_packed()
 | 
			
		||||
    cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
 | 
			
		||||
    num_points_per_cloud = pointclouds.num_points_per_cloud()
 | 
			
		||||
 | 
			
		||||
    if bin_size is None:
 | 
			
		||||
        if not points_packed.is_cuda:
 | 
			
		||||
            # Binned CPU rasterization not fully implemented
 | 
			
		||||
            bin_size = 0
 | 
			
		||||
        else:
 | 
			
		||||
            # TODO: These heuristics are not well-thought out!
 | 
			
		||||
            if image_size <= 64:
 | 
			
		||||
                bin_size = 8
 | 
			
		||||
            elif image_size <= 256:
 | 
			
		||||
                bin_size = 16
 | 
			
		||||
            elif image_size <= 512:
 | 
			
		||||
                bin_size = 32
 | 
			
		||||
            elif image_size <= 1024:
 | 
			
		||||
                bin_size = 64
 | 
			
		||||
 | 
			
		||||
    if max_points_per_bin is None:
 | 
			
		||||
        max_points_per_bin = int(max(10000, points_packed.shape[0] / 5))
 | 
			
		||||
 | 
			
		||||
    # Function.apply cannot take keyword args, so we handle defaults in this
 | 
			
		||||
    # wrapper and call apply with positional args only
 | 
			
		||||
    return _RasterizePoints.apply(
 | 
			
		||||
        points_packed,
 | 
			
		||||
        cloud_to_packed_first_idx,
 | 
			
		||||
        num_points_per_cloud,
 | 
			
		||||
        image_size,
 | 
			
		||||
        radius,
 | 
			
		||||
        points_per_pixel,
 | 
			
		||||
        bin_size,
 | 
			
		||||
        max_points_per_bin,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _RasterizePoints(torch.autograd.Function):
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def forward(
 | 
			
		||||
        ctx,
 | 
			
		||||
        points,  # (P, 3)
 | 
			
		||||
        cloud_to_packed_first_idx,
 | 
			
		||||
        num_points_per_cloud,
 | 
			
		||||
        image_size: int = 256,
 | 
			
		||||
        radius: float = 0.01,
 | 
			
		||||
        points_per_pixel: int = 8,
 | 
			
		||||
        bin_size: int = 0,
 | 
			
		||||
        max_points_per_bin: int = 0,
 | 
			
		||||
    ):
 | 
			
		||||
        # TODO: Add better error handling for when there are more than
 | 
			
		||||
        # max_points_per_bin in any bin.
 | 
			
		||||
        args = (
 | 
			
		||||
            points,
 | 
			
		||||
            cloud_to_packed_first_idx,
 | 
			
		||||
            num_points_per_cloud,
 | 
			
		||||
            image_size,
 | 
			
		||||
            radius,
 | 
			
		||||
            points_per_pixel,
 | 
			
		||||
            bin_size,
 | 
			
		||||
            max_points_per_bin,
 | 
			
		||||
        )
 | 
			
		||||
        idx, zbuf, dists = _C.rasterize_points(*args)
 | 
			
		||||
        ctx.save_for_backward(points, idx)
 | 
			
		||||
        return idx, zbuf, dists
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def backward(ctx, grad_idx, grad_zbuf, grad_dists):
 | 
			
		||||
        grad_points = None
 | 
			
		||||
        grad_cloud_to_packed_first_idx = None
 | 
			
		||||
        grad_num_points_per_cloud = None
 | 
			
		||||
        grad_image_size = None
 | 
			
		||||
        grad_radius = None
 | 
			
		||||
        grad_points_per_pixel = None
 | 
			
		||||
        grad_bin_size = None
 | 
			
		||||
        grad_max_points_per_bin = None
 | 
			
		||||
        points, idx = ctx.saved_tensors
 | 
			
		||||
        args = (points, idx, grad_zbuf, grad_dists)
 | 
			
		||||
        grad_points = _C.rasterize_points_backward(*args)
 | 
			
		||||
        grads = (
 | 
			
		||||
            grad_points,
 | 
			
		||||
            grad_cloud_to_packed_first_idx,
 | 
			
		||||
            grad_num_points_per_cloud,
 | 
			
		||||
            grad_image_size,
 | 
			
		||||
            grad_radius,
 | 
			
		||||
            grad_points_per_pixel,
 | 
			
		||||
            grad_bin_size,
 | 
			
		||||
            grad_max_points_per_bin,
 | 
			
		||||
        )
 | 
			
		||||
        return grads
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rasterize_points_python(
 | 
			
		||||
    pointclouds,
 | 
			
		||||
    image_size: int = 256,
 | 
			
		||||
    radius: float = 0.01,
 | 
			
		||||
    points_per_pixel: int = 8,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Naive pure PyTorch implementation of pointcloud rasterization.
 | 
			
		||||
 | 
			
		||||
    Inputs / Outputs: Same as above
 | 
			
		||||
    """
 | 
			
		||||
    N = len(pointclouds)
 | 
			
		||||
    S, K = image_size, points_per_pixel
 | 
			
		||||
    device = pointclouds.device
 | 
			
		||||
 | 
			
		||||
    points_packed = pointclouds.points_packed()
 | 
			
		||||
    cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
 | 
			
		||||
    num_points_per_cloud = pointclouds.num_points_per_cloud()
 | 
			
		||||
 | 
			
		||||
    # Intialize output tensors.
 | 
			
		||||
    point_idxs = torch.full(
 | 
			
		||||
        (N, S, S, K), fill_value=-1, dtype=torch.int32, device=device
 | 
			
		||||
    )
 | 
			
		||||
    zbuf = torch.full(
 | 
			
		||||
        (N, S, S, K), fill_value=-1, dtype=torch.float32, device=device
 | 
			
		||||
    )
 | 
			
		||||
    pix_dists = torch.full(
 | 
			
		||||
        (N, S, S, K), fill_value=-1, dtype=torch.float32, device=device
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # NDC is from [-1, 1]. Get pixel size using specified image size.
 | 
			
		||||
    radius2 = radius * radius
 | 
			
		||||
 | 
			
		||||
    # Iterate through the batch of point clouds.
 | 
			
		||||
    for n in range(N):
 | 
			
		||||
        point_start_idx = cloud_to_packed_first_idx[n]
 | 
			
		||||
        point_stop_idx = point_start_idx + num_points_per_cloud[n]
 | 
			
		||||
 | 
			
		||||
        # Iterate through the horizontal lines of the image from top to bottom.
 | 
			
		||||
        for yi in range(S):
 | 
			
		||||
            # Y coordinate of one end of the image. Reverse the ordering
 | 
			
		||||
            # of yi so that +Y is pointing up in the image.
 | 
			
		||||
            yfix = S - 1 - yi
 | 
			
		||||
            yf = pix_to_ndc(yfix, S)
 | 
			
		||||
 | 
			
		||||
            # Iterate through pixels on this horizontal line, left to right.
 | 
			
		||||
            for xi in range(S):
 | 
			
		||||
                # X coordinate of one end of the image. Reverse the ordering
 | 
			
		||||
                # of xi so that +X is pointing to the left in the image.
 | 
			
		||||
                xfix = S - 1 - xi
 | 
			
		||||
                xf = pix_to_ndc(xfix, S)
 | 
			
		||||
 | 
			
		||||
                top_k_points = []
 | 
			
		||||
                # Check whether each point in the batch affects this pixel.
 | 
			
		||||
                for p in range(point_start_idx, point_stop_idx):
 | 
			
		||||
                    px, py, pz = points_packed[p, :]
 | 
			
		||||
                    if pz < 0:
 | 
			
		||||
                        continue
 | 
			
		||||
                    dx = px - xf
 | 
			
		||||
                    dy = py - yf
 | 
			
		||||
                    dist2 = dx * dx + dy * dy
 | 
			
		||||
                    if dist2 < radius2:
 | 
			
		||||
                        top_k_points.append((pz, p, dist2))
 | 
			
		||||
                        top_k_points.sort()
 | 
			
		||||
                        if len(top_k_points) > K:
 | 
			
		||||
                            top_k_points = top_k_points[:K]
 | 
			
		||||
                for k, (pz, p, dist2) in enumerate(top_k_points):
 | 
			
		||||
                    zbuf[n, yi, xi, k] = pz
 | 
			
		||||
                    point_idxs[n, yi, xi, k] = p
 | 
			
		||||
                    pix_dists[n, yi, xi, k] = dist2
 | 
			
		||||
    return point_idxs, zbuf, pix_dists
 | 
			
		||||
							
								
								
									
										992
									
								
								pytorch3d/structures/pointclouds.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										992
									
								
								pytorch3d/structures/pointclouds.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,992 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from . import utils as struct_utils
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Pointclouds(object):
 | 
			
		||||
    """
 | 
			
		||||
    This class provides functions for working with batches of 3d point clouds,
 | 
			
		||||
    and converting between representations.
 | 
			
		||||
 | 
			
		||||
    Within Pointclouds, there are three different representations of the data.
 | 
			
		||||
 | 
			
		||||
    List
 | 
			
		||||
       - only used for input as a starting point to convert to other representations.
 | 
			
		||||
    Padded
 | 
			
		||||
       - has specific batch dimension.
 | 
			
		||||
    Packed
 | 
			
		||||
       - no batch dimension.
 | 
			
		||||
       - has auxillary variables used to index into the padded representation.
 | 
			
		||||
 | 
			
		||||
    Example
 | 
			
		||||
 | 
			
		||||
    Input list of points = [[P_1], [P_2], ... , [P_N]]
 | 
			
		||||
    where P_1, ... , P_N are the number of points in each cloud and N is the
 | 
			
		||||
    number of clouds.
 | 
			
		||||
 | 
			
		||||
    # SPHINX IGNORE
 | 
			
		||||
     List                      | Padded                  | Packed
 | 
			
		||||
    ---------------------------|-------------------------|------------------------
 | 
			
		||||
    [[P_1], ... , [P_N]]       | size = (N, max(P_n), 3) |  size = (sum(P_n), 3)
 | 
			
		||||
                               |                         |
 | 
			
		||||
    Example for locations      |                         |
 | 
			
		||||
    or colors:                 |                         |
 | 
			
		||||
                               |                         |
 | 
			
		||||
    P_1 = 3, P_2 = 4, P_3 = 5  | size = (3, 5, 3)        |  size = (12, 3)
 | 
			
		||||
                               |                         |
 | 
			
		||||
    List([                     | tensor([                |  tensor([
 | 
			
		||||
      [                        |     [                   |    [0.1, 0.3, 0.5],
 | 
			
		||||
        [0.1, 0.3, 0.5],       |       [0.1, 0.3, 0.5],  |    [0.5, 0.2, 0.1],
 | 
			
		||||
        [0.5, 0.2, 0.1],       |       [0.5, 0.2, 0.1],  |    [0.6, 0.8, 0.7],
 | 
			
		||||
        [0.6, 0.8, 0.7]        |       [0.6, 0.8, 0.7],  |    [0.1, 0.3, 0.3],
 | 
			
		||||
      ],                       |       [0,    0,    0],  |    [0.6, 0.7, 0.8],
 | 
			
		||||
      [                        |       [0,    0,    0]   |    [0.2, 0.3, 0.4],
 | 
			
		||||
        [0.1, 0.3, 0.3],       |     ],                  |    [0.1, 0.5, 0.3],
 | 
			
		||||
        [0.6, 0.7, 0.8],       |     [                   |    [0.7, 0.3, 0.6],
 | 
			
		||||
        [0.2, 0.3, 0.4],       |       [0.1, 0.3, 0.3],  |    [0.2, 0.4, 0.8],
 | 
			
		||||
        [0.1, 0.5, 0.3]        |       [0.6, 0.7, 0.8],  |    [0.9, 0.5, 0.2],
 | 
			
		||||
      ],                       |       [0.2, 0.3, 0.4],  |    [0.2, 0.3, 0.4],
 | 
			
		||||
      [                        |       [0.1, 0.5, 0.3],  |    [0.9, 0.3, 0.8],
 | 
			
		||||
        [0.7, 0.3, 0.6],       |       [0,    0,    0]   |  ])
 | 
			
		||||
        [0.2, 0.4, 0.8],       |     ],                  |
 | 
			
		||||
        [0.9, 0.5, 0.2],       |     [                   |
 | 
			
		||||
        [0.2, 0.3, 0.4],       |       [0.7, 0.3, 0.6],  |
 | 
			
		||||
        [0.9, 0.3, 0.8],       |       [0.2, 0.4, 0.8],  |
 | 
			
		||||
      ]                        |       [0.9, 0.5, 0.2],  |
 | 
			
		||||
   ])                          |       [0.2, 0.3, 0.4],  |
 | 
			
		||||
                               |       [0.9, 0.3, 0.8]   |
 | 
			
		||||
                               |     ]                   |
 | 
			
		||||
                               |  ])                     |
 | 
			
		||||
    -----------------------------------------------------------------------------
 | 
			
		||||
 | 
			
		||||
    Auxillary variables for packed representation
 | 
			
		||||
 | 
			
		||||
    Name                           |   Size              |  Example from above
 | 
			
		||||
    -------------------------------|---------------------|-----------------------
 | 
			
		||||
                                   |                     |
 | 
			
		||||
    packed_to_cloud_idx            |  size = (sum(P_n))  |   tensor([
 | 
			
		||||
                                   |                     |     0, 0, 0, 1, 1, 1,
 | 
			
		||||
                                   |                     |     1, 2, 2, 2, 2, 2
 | 
			
		||||
                                   |                     |   )]
 | 
			
		||||
                                   |                     |   size = (12)
 | 
			
		||||
                                   |                     |
 | 
			
		||||
    cloud_to_packed_first_idx      |  size = (N)         |   tensor([0, 3, 7])
 | 
			
		||||
                                   |                     |   size = (3)
 | 
			
		||||
                                   |                     |
 | 
			
		||||
    num_points_per_cloud           |  size = (N)         |   tensor([3, 4, 5])
 | 
			
		||||
                                   |                     |   size = (3)
 | 
			
		||||
                                   |                     |
 | 
			
		||||
    padded_to_packed_idx           |  size = (sum(P_n))  |  tensor([
 | 
			
		||||
                                   |                     |     0, 1, 2, 5, 6, 7,
 | 
			
		||||
                                   |                     |     8, 10, 11, 12, 13,
 | 
			
		||||
                                   |                     |     14
 | 
			
		||||
                                   |                     |  )]
 | 
			
		||||
                                   |                     |  size = (12)
 | 
			
		||||
    -----------------------------------------------------------------------------
 | 
			
		||||
    # SPHINX IGNORE
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    _INTERNAL_TENSORS = [
 | 
			
		||||
        "_points_packed",
 | 
			
		||||
        "_points_padded",
 | 
			
		||||
        "_normals_packed",
 | 
			
		||||
        "_normals_padded",
 | 
			
		||||
        "_features_packed",
 | 
			
		||||
        "_features_padded",
 | 
			
		||||
        "_packed_to_cloud_idx",
 | 
			
		||||
        "_cloud_to_packed_first_idx",
 | 
			
		||||
        "_num_points_per_cloud",
 | 
			
		||||
        "_padded_to_packed_idx",
 | 
			
		||||
        "valid",
 | 
			
		||||
        "equisized",
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    def __init__(self, points, normals=None, features=None):
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            points:
 | 
			
		||||
                Can be either
 | 
			
		||||
 | 
			
		||||
                - List where each element is a tensor of shape (num_points, 3)
 | 
			
		||||
                  containing the (x, y, z) coordinates of each point.
 | 
			
		||||
                - Padded float tensor with shape (num_clouds, num_points, 3).
 | 
			
		||||
            normals:
 | 
			
		||||
                Can be either
 | 
			
		||||
 | 
			
		||||
                - List where each element is a tensor of shape (num_points, 3)
 | 
			
		||||
                  containing the normal vector for each point.
 | 
			
		||||
                - Padded float tensor of shape (num_clouds, num_points, 3).
 | 
			
		||||
            features:
 | 
			
		||||
                Can be either
 | 
			
		||||
 | 
			
		||||
                - List where each element is a tensor of shape (num_points, C)
 | 
			
		||||
                  containing the features for the points in the cloud.
 | 
			
		||||
                - Padded float tensor of shape (num_clouds, num_points, C).
 | 
			
		||||
                where C is the number of channels in the features.
 | 
			
		||||
                For example 3 for RGB color.
 | 
			
		||||
 | 
			
		||||
        Refer to comments above for descriptions of List and Padded
 | 
			
		||||
        representations.
 | 
			
		||||
        """
 | 
			
		||||
        self.device = None
 | 
			
		||||
 | 
			
		||||
        # Indicates whether the clouds in the list/batch have the same number
 | 
			
		||||
        # of points.
 | 
			
		||||
        self.equisized = False
 | 
			
		||||
 | 
			
		||||
        # Boolean indicator for each cloud in the batch.
 | 
			
		||||
        # True if cloud has non zero number of points, False otherwise.
 | 
			
		||||
        self.valid = None
 | 
			
		||||
 | 
			
		||||
        self._N = 0  # batch size (number of clouds)
 | 
			
		||||
        self._P = 0  # (max) number of points per cloud
 | 
			
		||||
        self._C = None  # number of channels in the features
 | 
			
		||||
 | 
			
		||||
        # List of Tensors of points and features.
 | 
			
		||||
        self._points_list = None
 | 
			
		||||
        self._normals_list = None
 | 
			
		||||
        self._features_list = None
 | 
			
		||||
 | 
			
		||||
        # Number of points per cloud.
 | 
			
		||||
        self._num_points_per_cloud = None  # N
 | 
			
		||||
 | 
			
		||||
        # Packed representation.
 | 
			
		||||
        self._points_packed = None  # (sum(P_n), 3)
 | 
			
		||||
        self._normals_packed = None  # (sum(P_n), 3)
 | 
			
		||||
        self._features_packed = None  # (sum(P_n), C)
 | 
			
		||||
 | 
			
		||||
        self._packed_to_cloud_idx = None  # sum(P_n)
 | 
			
		||||
 | 
			
		||||
        # Index of each cloud's first point in the packed points.
 | 
			
		||||
        # Assumes packing is sequential.
 | 
			
		||||
        self._cloud_to_packed_first_idx = None  # N
 | 
			
		||||
 | 
			
		||||
        # Padded representation.
 | 
			
		||||
        self._points_padded = None  # (N, max(P_n), 3)
 | 
			
		||||
        self._normals_padded = None  # (N, max(P_n), 3)
 | 
			
		||||
        self._features_padded = None  # (N, max(P_n), C)
 | 
			
		||||
 | 
			
		||||
        # Index to convert points from flattened padded to packed.
 | 
			
		||||
        self._padded_to_packed_idx = None  # N * max_P
 | 
			
		||||
 | 
			
		||||
        # Identify type of points.
 | 
			
		||||
        if isinstance(points, list):
 | 
			
		||||
            self._points_list = points
 | 
			
		||||
            self._N = len(self._points_list)
 | 
			
		||||
            self.device = torch.device("cpu")
 | 
			
		||||
            self.valid = torch.zeros(
 | 
			
		||||
                (self._N,), dtype=torch.bool, device=self.device
 | 
			
		||||
            )
 | 
			
		||||
            self._num_points_per_cloud = []
 | 
			
		||||
 | 
			
		||||
            if self._N > 0:
 | 
			
		||||
                for p in self._points_list:
 | 
			
		||||
                    if len(p) > 0 and (p.dim() != 2 or p.shape[1] != 3):
 | 
			
		||||
                        raise ValueError(
 | 
			
		||||
                            "Clouds in list must be of shape Px3 or empty"
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                self.device = self._points_list[0].device
 | 
			
		||||
                num_points_per_cloud = torch.tensor(
 | 
			
		||||
                    [len(p) for p in self._points_list], device=self.device
 | 
			
		||||
                )
 | 
			
		||||
                self._P = num_points_per_cloud.max()
 | 
			
		||||
                self.valid = torch.tensor(
 | 
			
		||||
                    [len(p) > 0 for p in self._points_list],
 | 
			
		||||
                    dtype=torch.bool,
 | 
			
		||||
                    device=self.device,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                if len(num_points_per_cloud.unique()) == 1:
 | 
			
		||||
                    self.equisized = True
 | 
			
		||||
                self._num_points_per_cloud = num_points_per_cloud
 | 
			
		||||
 | 
			
		||||
        elif torch.is_tensor(points):
 | 
			
		||||
            if points.dim() != 3 or points.shape[2] != 3:
 | 
			
		||||
                raise ValueError("Points tensor has incorrect dimensions.")
 | 
			
		||||
            self._points_padded = points
 | 
			
		||||
            self._N = self._points_padded.shape[0]
 | 
			
		||||
            self._P = self._points_padded.shape[1]
 | 
			
		||||
            self.device = self._points_padded.device
 | 
			
		||||
            self.valid = torch.ones(
 | 
			
		||||
                (self._N,), dtype=torch.bool, device=self.device
 | 
			
		||||
            )
 | 
			
		||||
            self._num_points_per_cloud = torch.tensor(
 | 
			
		||||
                [self._P] * self._N, device=self.device
 | 
			
		||||
            )
 | 
			
		||||
            self.equisized = True
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "Points must be either a list or a tensor with \
 | 
			
		||||
                    shape (batch_size, P, 3) where P is the maximum number of \
 | 
			
		||||
                    points in a cloud."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # parse normals
 | 
			
		||||
        normals_parsed = self._parse_auxiliary_input(normals)
 | 
			
		||||
        self._normals_list, self._normals_padded, normals_C = normals_parsed
 | 
			
		||||
        if normals_C is not None and normals_C != 3:
 | 
			
		||||
            raise ValueError("Normals are expected to be 3-dimensional")
 | 
			
		||||
 | 
			
		||||
        # parse features
 | 
			
		||||
        features_parsed = self._parse_auxiliary_input(features)
 | 
			
		||||
        self._features_list, self._features_padded, features_C = features_parsed
 | 
			
		||||
        if features_C is not None:
 | 
			
		||||
            self._C = features_C
 | 
			
		||||
 | 
			
		||||
    def _parse_auxiliary_input(self, aux_input):
 | 
			
		||||
        """
 | 
			
		||||
        Interpret the auxiliary inputs (normals, features) given to __init__.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            aux_input:
 | 
			
		||||
              Can be either
 | 
			
		||||
 | 
			
		||||
                - List where each element is a tensor of shape (num_points, C)
 | 
			
		||||
                  containing the features for the points in the cloud.
 | 
			
		||||
                - Padded float tensor of shape (num_clouds, num_points, C).
 | 
			
		||||
              For normals, C = 3
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            3-element tuple of list, padded, num_channels.
 | 
			
		||||
            If aux_input is list, then padded is None. If aux_input is a tensor, then list is None.
 | 
			
		||||
        """
 | 
			
		||||
        if aux_input is None or self._N == 0:
 | 
			
		||||
            return None, None, None
 | 
			
		||||
 | 
			
		||||
        aux_input_C = None
 | 
			
		||||
 | 
			
		||||
        if isinstance(aux_input, list):
 | 
			
		||||
            if len(aux_input) != self._N:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "Points and auxiliary input must be the same length."
 | 
			
		||||
                )
 | 
			
		||||
            for p, d in zip(self._num_points_per_cloud, aux_input):
 | 
			
		||||
                if p != d.shape[0]:
 | 
			
		||||
                    raise ValueError(
 | 
			
		||||
                        "A cloud has mismatched numbers of points and inputs"
 | 
			
		||||
                    )
 | 
			
		||||
                if p > 0:
 | 
			
		||||
                    if d.dim() != 2:
 | 
			
		||||
                        raise ValueError(
 | 
			
		||||
                            "A cloud auxiliary input must be of shape PxC or empty"
 | 
			
		||||
                        )
 | 
			
		||||
                    if aux_input_C is None:
 | 
			
		||||
                        aux_input_C = d.shape[1]
 | 
			
		||||
                    if aux_input_C != d.shape[1]:
 | 
			
		||||
                        raise ValueError(
 | 
			
		||||
                            "The clouds must have the same number of channels"
 | 
			
		||||
                        )
 | 
			
		||||
            return aux_input, None, aux_input_C
 | 
			
		||||
        elif torch.is_tensor(aux_input):
 | 
			
		||||
            if aux_input.dim() != 3:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "Auxiliary input tensor has incorrect dimensions."
 | 
			
		||||
                )
 | 
			
		||||
            if self._N != aux_input.shape[0]:
 | 
			
		||||
                raise ValueError("Points and inputs must be the same length.")
 | 
			
		||||
            if self._P != aux_input.shape[1]:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "Inputs tensor must have the right maximum \
 | 
			
		||||
                    number of points in each cloud."
 | 
			
		||||
                )
 | 
			
		||||
            aux_input_C = aux_input.shape[2]
 | 
			
		||||
            return None, aux_input, aux_input_C
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "Auxiliary input must be either a list or a tensor with \
 | 
			
		||||
                    shape (batch_size, P, C) where P is the maximum number of \
 | 
			
		||||
                    points in a cloud."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return self._N
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, index):
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            index: Specifying the index of the cloud to retrieve.
 | 
			
		||||
                Can be an int, slice, list of ints or a boolean tensor.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Pointclouds object with selected clouds. The tensors are not cloned.
 | 
			
		||||
        """
 | 
			
		||||
        normals, features = None, None
 | 
			
		||||
        if isinstance(index, int):
 | 
			
		||||
            points = [self.points_list()[index]]
 | 
			
		||||
            if self.normals_list() is not None:
 | 
			
		||||
                normals = [self.normals_list()[index]]
 | 
			
		||||
            if self.features_list() is not None:
 | 
			
		||||
                features = [self.features_list()[index]]
 | 
			
		||||
        elif isinstance(index, slice):
 | 
			
		||||
            points = self.points_list()[index]
 | 
			
		||||
            if self.normals_list() is not None:
 | 
			
		||||
                normals = self.normals_list()[index]
 | 
			
		||||
            if self.features_list() is not None:
 | 
			
		||||
                features = self.features_list()[index]
 | 
			
		||||
        elif isinstance(index, list):
 | 
			
		||||
            points = [self.points_list()[i] for i in index]
 | 
			
		||||
            if self.normals_list() is not None:
 | 
			
		||||
                normals = [self.normals_list()[i] for i in index]
 | 
			
		||||
            if self.features_list() is not None:
 | 
			
		||||
                features = [self.features_list()[i] for i in index]
 | 
			
		||||
        elif isinstance(index, torch.Tensor):
 | 
			
		||||
            if index.dim() != 1 or index.dtype.is_floating_point:
 | 
			
		||||
                raise IndexError(index)
 | 
			
		||||
            # NOTE consider converting index to cpu for efficiency
 | 
			
		||||
            if index.dtype == torch.bool:
 | 
			
		||||
                # advanced indexing on a single dimension
 | 
			
		||||
                index = index.nonzero()
 | 
			
		||||
                index = index.squeeze(1) if index.numel() > 0 else index
 | 
			
		||||
                index = index.tolist()
 | 
			
		||||
            points = [self.points_list()[i] for i in index]
 | 
			
		||||
            if self.normals_list() is not None:
 | 
			
		||||
                normals = [self.normals_list()[i] for i in index]
 | 
			
		||||
            if self.features_list() is not None:
 | 
			
		||||
                features = [self.features_list()[i] for i in index]
 | 
			
		||||
        else:
 | 
			
		||||
            raise IndexError(index)
 | 
			
		||||
 | 
			
		||||
        return Pointclouds(points=points, normals=normals, features=features)
 | 
			
		||||
 | 
			
		||||
    def isempty(self) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        Checks whether any cloud is valid.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            bool indicating whether there is any data.
 | 
			
		||||
        """
 | 
			
		||||
        return self._N == 0 or self.valid.eq(False).all()
 | 
			
		||||
 | 
			
		||||
    def points_list(self):
 | 
			
		||||
        """
 | 
			
		||||
        Get the list representation of the points.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list of tensors of points of shape (P_n, 3).
 | 
			
		||||
        """
 | 
			
		||||
        if self._points_list is None:
 | 
			
		||||
            assert (
 | 
			
		||||
                self._points_padded is not None
 | 
			
		||||
            ), "points_padded is required to compute points_list."
 | 
			
		||||
            points_list = []
 | 
			
		||||
            for i in range(self._N):
 | 
			
		||||
                points_list.append(
 | 
			
		||||
                    self._points_padded[i, : self.num_points_per_cloud()[i]]
 | 
			
		||||
                )
 | 
			
		||||
            self._points_list = points_list
 | 
			
		||||
        return self._points_list
 | 
			
		||||
 | 
			
		||||
    def normals_list(self):
 | 
			
		||||
        """
 | 
			
		||||
        Get the list representation of the normals.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list of tensors of normals of shape (P_n, 3).
 | 
			
		||||
        """
 | 
			
		||||
        if self._normals_list is None:
 | 
			
		||||
            if self._normals_padded is None:
 | 
			
		||||
                # No normals provided so return None
 | 
			
		||||
                return None
 | 
			
		||||
            self._normals_list = []
 | 
			
		||||
            for i in range(self._N):
 | 
			
		||||
                self._normals_list.append(
 | 
			
		||||
                    self._normals_padded[i, : self.num_points_per_cloud()[i]]
 | 
			
		||||
                )
 | 
			
		||||
        return self._normals_list
 | 
			
		||||
 | 
			
		||||
    def features_list(self):
 | 
			
		||||
        """
 | 
			
		||||
        Get the list representation of the features.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list of tensors of features of shape (P_n, C).
 | 
			
		||||
        """
 | 
			
		||||
        if self._features_list is None:
 | 
			
		||||
            if self._features_padded is None:
 | 
			
		||||
                # No features provided so return None
 | 
			
		||||
                return None
 | 
			
		||||
            self._features_list = []
 | 
			
		||||
            for i in range(self._N):
 | 
			
		||||
                self._features_list.append(
 | 
			
		||||
                    self._features_padded[i, : self.num_points_per_cloud()[i]]
 | 
			
		||||
                )
 | 
			
		||||
        return self._features_list
 | 
			
		||||
 | 
			
		||||
    def points_packed(self):
 | 
			
		||||
        """
 | 
			
		||||
        Get the packed representation of the points.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            tensor of points of shape (sum(P_n), 3).
 | 
			
		||||
        """
 | 
			
		||||
        self._compute_packed()
 | 
			
		||||
        return self._points_packed
 | 
			
		||||
 | 
			
		||||
    def normals_packed(self):
 | 
			
		||||
        """
 | 
			
		||||
        Get the packed representation of the normals.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            tensor of normals of shape (sum(P_n), 3).
 | 
			
		||||
        """
 | 
			
		||||
        self._compute_packed()
 | 
			
		||||
        return self._normals_packed
 | 
			
		||||
 | 
			
		||||
    def features_packed(self):
 | 
			
		||||
        """
 | 
			
		||||
        Get the packed representation of the features.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            tensor of features of shape (sum(P_n), C).
 | 
			
		||||
        """
 | 
			
		||||
        self._compute_packed()
 | 
			
		||||
        return self._features_packed
 | 
			
		||||
 | 
			
		||||
    def packed_to_cloud_idx(self):
 | 
			
		||||
        """
 | 
			
		||||
        Return a 1D tensor x with length equal to the total number of points.
 | 
			
		||||
        packed_to_cloud_idx()[i] gives the index of the cloud which contains
 | 
			
		||||
        points_packed()[i].
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            1D tensor of indices.
 | 
			
		||||
        """
 | 
			
		||||
        self._compute_packed()
 | 
			
		||||
        return self._packed_to_cloud_idx
 | 
			
		||||
 | 
			
		||||
    def cloud_to_packed_first_idx(self):
 | 
			
		||||
        """
 | 
			
		||||
        Return a 1D tensor x with length equal to the number of clouds such that
 | 
			
		||||
        the first point of the ith cloud is points_packed[x[i]].
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            1D tensor of indices of first items.
 | 
			
		||||
        """
 | 
			
		||||
        self._compute_packed()
 | 
			
		||||
        return self._cloud_to_packed_first_idx
 | 
			
		||||
 | 
			
		||||
    def num_points_per_cloud(self):
 | 
			
		||||
        """
 | 
			
		||||
        Return a 1D tensor x with length equal to the number of clouds giving
 | 
			
		||||
        the number of points in each cloud.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            1D tensor of sizes.
 | 
			
		||||
        """
 | 
			
		||||
        return self._num_points_per_cloud
 | 
			
		||||
 | 
			
		||||
    def points_padded(self):
 | 
			
		||||
        """
 | 
			
		||||
        Get the padded representation of the points.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            tensor of points of shape (N, max(P_n), 3).
 | 
			
		||||
        """
 | 
			
		||||
        self._compute_padded()
 | 
			
		||||
        return self._points_padded
 | 
			
		||||
 | 
			
		||||
    def normals_padded(self):
 | 
			
		||||
        """
 | 
			
		||||
        Get the padded representation of the normals.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            tensor of normals of shape (N, max(P_n), 3).
 | 
			
		||||
        """
 | 
			
		||||
        self._compute_padded()
 | 
			
		||||
        return self._normals_padded
 | 
			
		||||
 | 
			
		||||
    def features_padded(self):
 | 
			
		||||
        """
 | 
			
		||||
        Get the padded representation of the features.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            tensor of features of shape (N, max(P_n), 3).
 | 
			
		||||
        """
 | 
			
		||||
        self._compute_padded()
 | 
			
		||||
        return self._features_padded
 | 
			
		||||
 | 
			
		||||
    def padded_to_packed_idx(self):
 | 
			
		||||
        """
 | 
			
		||||
        Return a 1D tensor x with length equal to the total number of points
 | 
			
		||||
        such that points_packed()[i] is element x[i] of the flattened padded
 | 
			
		||||
        representation.
 | 
			
		||||
        The packed representation can be calculated as follows.
 | 
			
		||||
 | 
			
		||||
        .. code-block:: python
 | 
			
		||||
 | 
			
		||||
            p = points_padded().reshape(-1, 3)
 | 
			
		||||
            points_packed = p[x]
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            1D tensor of indices.
 | 
			
		||||
        """
 | 
			
		||||
        self._compute_packed()
 | 
			
		||||
        if self._padded_to_packed_idx is not None:
 | 
			
		||||
            return self._padded_to_packed_idx
 | 
			
		||||
        if self._N == 0:
 | 
			
		||||
            self._padded_to_packed_idx = []
 | 
			
		||||
        else:
 | 
			
		||||
            self._padded_to_packed_idx = torch.cat(
 | 
			
		||||
                [
 | 
			
		||||
                    torch.arange(v, dtype=torch.int64, device=self.device)
 | 
			
		||||
                    + i * self._P
 | 
			
		||||
                    for (i, v) in enumerate(self._num_points_per_cloud)
 | 
			
		||||
                ],
 | 
			
		||||
                dim=0,
 | 
			
		||||
            )
 | 
			
		||||
        return self._padded_to_packed_idx
 | 
			
		||||
 | 
			
		||||
    def _compute_padded(self, refresh: bool = False):
 | 
			
		||||
        """
 | 
			
		||||
        Computes the padded version from points_list, normals_list and features_list.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            refresh: whether to force the recalculation.
 | 
			
		||||
        """
 | 
			
		||||
        if not (refresh or self._points_padded is None):
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        self._normals_padded, self._features_padded = None, None
 | 
			
		||||
        if self.isempty():
 | 
			
		||||
            self._points_padded = torch.zeros(
 | 
			
		||||
                (self._N, 0, 3), device=self.device
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            self._points_padded = struct_utils.list_to_padded(
 | 
			
		||||
                self.points_list(),
 | 
			
		||||
                (self._P, 3),
 | 
			
		||||
                pad_value=0.0,
 | 
			
		||||
                equisized=self.equisized,
 | 
			
		||||
            )
 | 
			
		||||
            if self.normals_list() is not None:
 | 
			
		||||
                self._normals_padded = struct_utils.list_to_padded(
 | 
			
		||||
                    self.normals_list(),
 | 
			
		||||
                    (self._P, 3),
 | 
			
		||||
                    pad_value=0.0,
 | 
			
		||||
                    equisized=self.equisized,
 | 
			
		||||
                )
 | 
			
		||||
            if self.features_list() is not None:
 | 
			
		||||
                self._features_padded = struct_utils.list_to_padded(
 | 
			
		||||
                    self.features_list(),
 | 
			
		||||
                    (self._P, self._C),
 | 
			
		||||
                    pad_value=0.0,
 | 
			
		||||
                    equisized=self.equisized,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    # TODO(nikhilar) Improve performance of _compute_packed.
 | 
			
		||||
    def _compute_packed(self, refresh: bool = False):
 | 
			
		||||
        """
 | 
			
		||||
        Computes the packed version from points_list, normals_list and
 | 
			
		||||
        features_list and sets the values of auxillary tensors.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            refresh: Set to True to force recomputation of packed
 | 
			
		||||
                representations. Default: False.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        if not (
 | 
			
		||||
            refresh
 | 
			
		||||
            or any(
 | 
			
		||||
                v is None
 | 
			
		||||
                for v in [
 | 
			
		||||
                    self._points_packed,
 | 
			
		||||
                    self._packed_to_cloud_idx,
 | 
			
		||||
                    self._cloud_to_packed_first_idx,
 | 
			
		||||
                ]
 | 
			
		||||
            )
 | 
			
		||||
        ):
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        # Packed can be calculated from padded or list, so can call the
 | 
			
		||||
        # accessor function for the lists.
 | 
			
		||||
        points_list = self.points_list()
 | 
			
		||||
        normals_list = self.normals_list()
 | 
			
		||||
        features_list = self.features_list()
 | 
			
		||||
        if self.isempty():
 | 
			
		||||
            self._points_packed = torch.zeros(
 | 
			
		||||
                (0, 3), dtype=torch.float32, device=self.device
 | 
			
		||||
            )
 | 
			
		||||
            self._packed_to_cloud_idx = torch.zeros(
 | 
			
		||||
                (0,), dtype=torch.int64, device=self.device
 | 
			
		||||
            )
 | 
			
		||||
            self._cloud_to_packed_first_idx = torch.zeros(
 | 
			
		||||
                (0,), dtype=torch.int64, device=self.device
 | 
			
		||||
            )
 | 
			
		||||
            self._normals_packed = None
 | 
			
		||||
            self._features_packed = None
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
        points_list_to_packed = struct_utils.list_to_packed(points_list)
 | 
			
		||||
        self._points_packed = points_list_to_packed[0]
 | 
			
		||||
        if not torch.allclose(
 | 
			
		||||
            self._num_points_per_cloud, points_list_to_packed[1]
 | 
			
		||||
        ):
 | 
			
		||||
            raise ValueError("Inconsistent list to packed conversion")
 | 
			
		||||
        self._cloud_to_packed_first_idx = points_list_to_packed[2]
 | 
			
		||||
        self._packed_to_cloud_idx = points_list_to_packed[3]
 | 
			
		||||
 | 
			
		||||
        self._normals_packed, self._features_packed = None, None
 | 
			
		||||
        if normals_list is not None:
 | 
			
		||||
            normals_list_to_packed = struct_utils.list_to_packed(normals_list)
 | 
			
		||||
            self._normals_packed = normals_list_to_packed[0]
 | 
			
		||||
 | 
			
		||||
        if features_list is not None:
 | 
			
		||||
            features_list_to_packed = struct_utils.list_to_packed(features_list)
 | 
			
		||||
            self._features_packed = features_list_to_packed[0]
 | 
			
		||||
 | 
			
		||||
    def clone(self):
 | 
			
		||||
        """
 | 
			
		||||
        Deep copy of Pointclouds object. All internal tensors are cloned
 | 
			
		||||
        individually.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            new Pointclouds object.
 | 
			
		||||
        """
 | 
			
		||||
        # instantiate new pointcloud with the representation which is not None
 | 
			
		||||
        # (either list or tensor) to save compute.
 | 
			
		||||
        new_points, new_normals, new_features = None, None, None
 | 
			
		||||
        if self._points_list is not None:
 | 
			
		||||
            new_points = [v.clone() for v in self.points_list()]
 | 
			
		||||
            normals_list = self.normals_list()
 | 
			
		||||
            features_list = self.features_list()
 | 
			
		||||
            if normals_list is not None:
 | 
			
		||||
                new_normals = [n.clone() for n in normals_list]
 | 
			
		||||
            if features_list is not None:
 | 
			
		||||
                new_features = [f.clone() for f in features_list]
 | 
			
		||||
        elif self._points_padded is not None:
 | 
			
		||||
            new_points = self.points_padded().clone()
 | 
			
		||||
            normals_padded = self.normals_padded()
 | 
			
		||||
            features_padded = self.features_padded()
 | 
			
		||||
            if normals_padded is not None:
 | 
			
		||||
                new_normals = self.normals_padded().clone()
 | 
			
		||||
            if features_padded is not None:
 | 
			
		||||
                new_features = self.features_padded().clone()
 | 
			
		||||
        other = Pointclouds(
 | 
			
		||||
            points=new_points, normals=new_normals, features=new_features
 | 
			
		||||
        )
 | 
			
		||||
        for k in self._INTERNAL_TENSORS:
 | 
			
		||||
            v = getattr(self, k)
 | 
			
		||||
            if torch.is_tensor(v):
 | 
			
		||||
                setattr(other, k, v.clone())
 | 
			
		||||
        return other
 | 
			
		||||
 | 
			
		||||
    def to(self, device, copy: bool = False):
 | 
			
		||||
        """
 | 
			
		||||
        Match functionality of torch.Tensor.to()
 | 
			
		||||
        If copy = True or the self Tensor is on a different device, the
 | 
			
		||||
        returned tensor is a copy of self with the desired torch.device.
 | 
			
		||||
        If copy = False and the self Tensor already has the correct torch.device,
 | 
			
		||||
        then self is returned.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
          device: Device id for the new tensor.
 | 
			
		||||
          copy: Boolean indicator whether or not to clone self. Default False.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
          Pointclouds object.
 | 
			
		||||
        """
 | 
			
		||||
        if not copy and self.device == device:
 | 
			
		||||
            return self
 | 
			
		||||
        other = self.clone()
 | 
			
		||||
        if self.device != device:
 | 
			
		||||
            other.device = device
 | 
			
		||||
            if other._N > 0:
 | 
			
		||||
                other._points_list = [v.to(device) for v in other.points_list()]
 | 
			
		||||
                if other._normals_list is not None:
 | 
			
		||||
                    other._normals_list = [
 | 
			
		||||
                        n.to(device) for n in other.normals_list()
 | 
			
		||||
                    ]
 | 
			
		||||
                if other._features_list is not None:
 | 
			
		||||
                    other._features_list = [
 | 
			
		||||
                        f.to(device) for f in other.features_list()
 | 
			
		||||
                    ]
 | 
			
		||||
            for k in self._INTERNAL_TENSORS:
 | 
			
		||||
                v = getattr(self, k)
 | 
			
		||||
                if torch.is_tensor(v):
 | 
			
		||||
                    setattr(other, k, v.to(device))
 | 
			
		||||
        return other
 | 
			
		||||
 | 
			
		||||
    def cpu(self):
 | 
			
		||||
        return self.to(torch.device("cpu"))
 | 
			
		||||
 | 
			
		||||
    def cuda(self):
 | 
			
		||||
        return self.to(torch.device("cuda"))
 | 
			
		||||
 | 
			
		||||
    def get_cloud(self, index: int):
 | 
			
		||||
        """
 | 
			
		||||
        Get tensors for a single cloud from the list representation.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            index: Integer in the range [0, N).
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            points: Tensor of shape (P, 3).
 | 
			
		||||
            normals: Tensor of shape (P, 3)
 | 
			
		||||
            features: LongTensor of shape (P, C).
 | 
			
		||||
        """
 | 
			
		||||
        if not isinstance(index, int):
 | 
			
		||||
            raise ValueError("Cloud index must be an integer.")
 | 
			
		||||
        if index < 0 or index > self._N:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "Cloud index must be in the range [0, N) where \
 | 
			
		||||
            N is the number of clouds in the batch."
 | 
			
		||||
            )
 | 
			
		||||
        points = self.points_list()[index]
 | 
			
		||||
        normals, features = None, None
 | 
			
		||||
        if self.normals_list() is not None:
 | 
			
		||||
            normals = self.normals_list()[index]
 | 
			
		||||
        if self.features_list() is not None:
 | 
			
		||||
            features = self.features_list()[index]
 | 
			
		||||
        return points, normals, features
 | 
			
		||||
 | 
			
		||||
    # TODO(nikhilar) Move function to a utils file.
 | 
			
		||||
    def split(self, split_sizes: list):
 | 
			
		||||
        """
 | 
			
		||||
        Splits Pointclouds object of size N into a list of Pointclouds objects
 | 
			
		||||
        of size len(split_sizes), where the i-th Pointclouds object is of size
 | 
			
		||||
        split_sizes[i]. Similar to torch.split().
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            split_sizes: List of integer sizes of Pointclouds objects to be
 | 
			
		||||
            returned.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            list[PointClouds].
 | 
			
		||||
        """
 | 
			
		||||
        if not all(isinstance(x, int) for x in split_sizes):
 | 
			
		||||
            raise ValueError("Value of split_sizes must be a list of integers.")
 | 
			
		||||
        cloudlist = []
 | 
			
		||||
        curi = 0
 | 
			
		||||
        for i in split_sizes:
 | 
			
		||||
            cloudlist.append(self[curi : curi + i])
 | 
			
		||||
            curi += i
 | 
			
		||||
        return cloudlist
 | 
			
		||||
 | 
			
		||||
    def offset_(self, offsets_packed):
 | 
			
		||||
        """
 | 
			
		||||
        Translate the point clouds by an offset. In place operation.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            offsets_packed: A Tensor of the same shape as self.points_packed
 | 
			
		||||
                giving offsets to be added to all points.
 | 
			
		||||
        Returns:
 | 
			
		||||
            self.
 | 
			
		||||
        """
 | 
			
		||||
        points_packed = self.points_packed()
 | 
			
		||||
        if offsets_packed.shape != points_packed.shape:
 | 
			
		||||
            raise ValueError("Offsets must have dimension (all_p, 3).")
 | 
			
		||||
        self._points_packed = points_packed + offsets_packed
 | 
			
		||||
        new_points_list = list(
 | 
			
		||||
            self._points_packed.split(self.num_points_per_cloud().tolist(), 0)
 | 
			
		||||
        )
 | 
			
		||||
        # Note that since _compute_packed() has been executed, points_list
 | 
			
		||||
        # cannot be None even if not provided during construction.
 | 
			
		||||
        self._points_list = new_points_list
 | 
			
		||||
        if self._points_padded is not None:
 | 
			
		||||
            for i, points in enumerate(new_points_list):
 | 
			
		||||
                if len(points) > 0:
 | 
			
		||||
                    self._points_padded[i, : points.shape[0], :] = points
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    # TODO(nikhilar) Move out of place operator to a utils file.
 | 
			
		||||
    def offset(self, offsets_packed):
 | 
			
		||||
        """
 | 
			
		||||
        Out of place offset.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            offsets_packed: A Tensor of the same shape as self.points_packed
 | 
			
		||||
                giving offsets to be added to all points.
 | 
			
		||||
        Returns:
 | 
			
		||||
            new Pointclouds object.
 | 
			
		||||
        """
 | 
			
		||||
        new_clouds = self.clone()
 | 
			
		||||
        return new_clouds.offset_(offsets_packed)
 | 
			
		||||
 | 
			
		||||
    def scale_(self, scale):
 | 
			
		||||
        """
 | 
			
		||||
        Multiply the coordinates of this object by a scalar value.
 | 
			
		||||
        - i.e. enlarge/dilate
 | 
			
		||||
        In place operation.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            scale: A scalar, or a Tensor of shape (N,).
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            self.
 | 
			
		||||
        """
 | 
			
		||||
        if not torch.is_tensor(scale):
 | 
			
		||||
            scale = torch.full(len(self), scale)
 | 
			
		||||
        new_points_list = []
 | 
			
		||||
        points_list = self.points_list()
 | 
			
		||||
        for i, old_points in enumerate(points_list):
 | 
			
		||||
            new_points_list.append(scale[i] * old_points)
 | 
			
		||||
        self._points_list = new_points_list
 | 
			
		||||
        if self._points_packed is not None:
 | 
			
		||||
            self._points_packed = torch.cat(new_points_list, dim=0)
 | 
			
		||||
        if self._points_padded is not None:
 | 
			
		||||
            for i, points in enumerate(new_points_list):
 | 
			
		||||
                if len(points) > 0:
 | 
			
		||||
                    self._points_padded[i, : points.shape[0], :] = points
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def scale(self, scale):
 | 
			
		||||
        """
 | 
			
		||||
        Out of place scale_.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            scale: A scalar, or a Tensor of shape (N,).
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            new Pointclouds object.
 | 
			
		||||
        """
 | 
			
		||||
        new_clouds = self.clone()
 | 
			
		||||
        return new_clouds.scale_(scale)
 | 
			
		||||
 | 
			
		||||
    # TODO(nikhilar) Move function to utils file.
 | 
			
		||||
    def get_bounding_boxes(self):
 | 
			
		||||
        """
 | 
			
		||||
        Compute an axis-aligned bounding box for each cloud.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            bboxes: Tensor of shape (N, 3, 2) where bbox[i, j] gives the
 | 
			
		||||
            min and max values of cloud i along the jth coordinate axis.
 | 
			
		||||
        """
 | 
			
		||||
        all_mins, all_maxes = [], []
 | 
			
		||||
        for points in self.points_list():
 | 
			
		||||
            cur_mins = points.min(dim=0)[0]  # (3,)
 | 
			
		||||
            cur_maxes = points.max(dim=0)[0]  # (3,)
 | 
			
		||||
            all_mins.append(cur_mins)
 | 
			
		||||
            all_maxes.append(cur_maxes)
 | 
			
		||||
        all_mins = torch.stack(all_mins, dim=0)  # (N, 3)
 | 
			
		||||
        all_maxes = torch.stack(all_maxes, dim=0)  # (N, 3)
 | 
			
		||||
        bboxes = torch.stack([all_mins, all_maxes], dim=2)
 | 
			
		||||
        return bboxes
 | 
			
		||||
 | 
			
		||||
    def extend(self, N: int):
 | 
			
		||||
        """
 | 
			
		||||
        Create new Pointclouds which contains each cloud N times.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            N: number of new copies of each cloud.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            new Pointclouds object.
 | 
			
		||||
        """
 | 
			
		||||
        if not isinstance(N, int):
 | 
			
		||||
            raise ValueError("N must be an integer.")
 | 
			
		||||
        if N <= 0:
 | 
			
		||||
            raise ValueError("N must be > 0.")
 | 
			
		||||
 | 
			
		||||
        new_points_list, new_normals_list, new_features_list = [], None, None
 | 
			
		||||
        for points in self.points_list():
 | 
			
		||||
            new_points_list.extend(points.clone() for _ in range(N))
 | 
			
		||||
        if self.normals_list() is not None:
 | 
			
		||||
            new_normals_list = []
 | 
			
		||||
            for normals in self.normals_list():
 | 
			
		||||
                new_normals_list.extend(normals.clone() for _ in range(N))
 | 
			
		||||
        if self.features_list() is not None:
 | 
			
		||||
            new_features_list = []
 | 
			
		||||
            for features in self.features_list():
 | 
			
		||||
                new_features_list.extend(features.clone() for _ in range(N))
 | 
			
		||||
        return Pointclouds(
 | 
			
		||||
            points=new_points_list,
 | 
			
		||||
            normals=new_normals_list,
 | 
			
		||||
            features=new_features_list,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def update_padded(
 | 
			
		||||
        self,
 | 
			
		||||
        new_points_padded,
 | 
			
		||||
        new_normals_padded=None,
 | 
			
		||||
        new_features_padded=None,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Returns a Pointcloud structure with updated padded tensors and copies of
 | 
			
		||||
        the auxiliary tensors. This function allows for an update of
 | 
			
		||||
        points_padded (and normals and features) without having to explicitly
 | 
			
		||||
        convert it to the list representation for heterogeneous batches.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            new_points_padded: FloatTensor of shape (N, P, 3)
 | 
			
		||||
            new_normals_padded: (optional) FloatTensor of shape (N, P, 3)
 | 
			
		||||
            new_features_padded: (optional) FloatTensors of shape (N, P, C)
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Pointcloud with updated padded representations
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        def check_shapes(x, size):
 | 
			
		||||
            if x.shape[0] != size[0]:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "new values must have the same batch dimension."
 | 
			
		||||
                )
 | 
			
		||||
            if x.shape[1] != size[1]:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "new values must have the same number of points."
 | 
			
		||||
                )
 | 
			
		||||
            if size[2] is not None:
 | 
			
		||||
                if x.shape[2] != size[2]:
 | 
			
		||||
                    raise ValueError(
 | 
			
		||||
                        "new values must have the same number of channels."
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
        check_shapes(new_points_padded, [self._N, self._P, 3])
 | 
			
		||||
        if new_normals_padded is not None:
 | 
			
		||||
            check_shapes(new_normals_padded, [self._N, self._P, 3])
 | 
			
		||||
        if new_features_padded is not None:
 | 
			
		||||
            check_shapes(new_features_padded, [self._N, self._P, self._C])
 | 
			
		||||
 | 
			
		||||
        new = Pointclouds(
 | 
			
		||||
            points=new_points_padded,
 | 
			
		||||
            normals=new_normals_padded,
 | 
			
		||||
            features=new_features_padded,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # overwrite the equisized flag
 | 
			
		||||
        new.equisized = self.equisized
 | 
			
		||||
 | 
			
		||||
        # copy normals
 | 
			
		||||
        if new_normals_padded is None:
 | 
			
		||||
            # If no normals are provided, keep old ones (shallow copy)
 | 
			
		||||
            new._normals_list = self._normals_list
 | 
			
		||||
            new._normals_padded = self._normals_padded
 | 
			
		||||
            new._normals_packed = self._normals_packed
 | 
			
		||||
 | 
			
		||||
        # copy features
 | 
			
		||||
        if new_features_padded is None:
 | 
			
		||||
            # If no features are provided, keep old ones (shallow copy)
 | 
			
		||||
            new._features_list = self._features_list
 | 
			
		||||
            new._features_padded = self._features_padded
 | 
			
		||||
            new._features_packed = self._features_packed
 | 
			
		||||
 | 
			
		||||
        # copy auxiliary tensors
 | 
			
		||||
        copy_tensors = [
 | 
			
		||||
            "_packed_to_cloud_idx",
 | 
			
		||||
            "_cloud_to_packed_first_idx",
 | 
			
		||||
            "_num_points_per_cloud",
 | 
			
		||||
            "_padded_to_packed_idx",
 | 
			
		||||
            "valid",
 | 
			
		||||
        ]
 | 
			
		||||
        for k in copy_tensors:
 | 
			
		||||
            v = getattr(self, k)
 | 
			
		||||
            if torch.is_tensor(v):
 | 
			
		||||
                setattr(new, k, v)  # shallow copy
 | 
			
		||||
 | 
			
		||||
        # update points
 | 
			
		||||
        new._points_padded = new_points_padded
 | 
			
		||||
        assert new._points_list is None
 | 
			
		||||
        assert new._points_packed is None
 | 
			
		||||
 | 
			
		||||
        # update normals and features if provided
 | 
			
		||||
        if new_normals_padded is not None:
 | 
			
		||||
            new._normals_padded = new_normals_padded
 | 
			
		||||
            new._normals_list = None
 | 
			
		||||
            new._normals_packed = None
 | 
			
		||||
        if new_features_padded is not None:
 | 
			
		||||
            new._features_padded = new_features_padded
 | 
			
		||||
            new._features_list = None
 | 
			
		||||
            new._features_packed = None
 | 
			
		||||
        return new
 | 
			
		||||
							
								
								
									
										30
									
								
								tests/bm_pointclouds.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								tests/bm_pointclouds.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,30 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from itertools import product
 | 
			
		||||
from fvcore.common.benchmark import benchmark
 | 
			
		||||
 | 
			
		||||
from test_pointclouds import TestPointclouds
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def bm_compute_packed_padded_pointclouds() -> None:
 | 
			
		||||
    kwargs_list = []
 | 
			
		||||
    num_clouds = [32, 128]
 | 
			
		||||
    max_p = [100, 10000]
 | 
			
		||||
    feats = [1, 10, 300]
 | 
			
		||||
    test_cases = product(num_clouds, max_p, feats)
 | 
			
		||||
    for case in test_cases:
 | 
			
		||||
        n, p, f = case
 | 
			
		||||
        kwargs_list.append({"num_clouds": n, "max_p": p, "features": f})
 | 
			
		||||
    benchmark(
 | 
			
		||||
        TestPointclouds.compute_packed_with_init,
 | 
			
		||||
        "COMPUTE_PACKED",
 | 
			
		||||
        kwargs_list,
 | 
			
		||||
        warmup_iters=1,
 | 
			
		||||
    )
 | 
			
		||||
    benchmark(
 | 
			
		||||
        TestPointclouds.compute_padded_with_init,
 | 
			
		||||
        "COMPUTE_PADDED",
 | 
			
		||||
        kwargs_list,
 | 
			
		||||
        warmup_iters=1,
 | 
			
		||||
    )
 | 
			
		||||
							
								
								
									
										52
									
								
								tests/bm_rasterize_points.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								tests/bm_rasterize_points.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,52 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from fvcore.common.benchmark import benchmark
 | 
			
		||||
 | 
			
		||||
from pytorch3d.renderer.points.rasterize_points import (
 | 
			
		||||
    rasterize_points,
 | 
			
		||||
    rasterize_points_python,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.structures.pointclouds import Pointclouds
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _bm_python_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
 | 
			
		||||
    torch.manual_seed(231)
 | 
			
		||||
    points = torch.randn(N, P, 3)
 | 
			
		||||
    pointclouds = Pointclouds(points=points)
 | 
			
		||||
    args = (pointclouds, img_size, radius, pts_per_pxl)
 | 
			
		||||
    return lambda: rasterize_points_python(*args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _bm_cpu_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
 | 
			
		||||
    torch.manual_seed(231)
 | 
			
		||||
    points = torch.randn(N, P, 3)
 | 
			
		||||
    pointclouds = Pointclouds(points=points)
 | 
			
		||||
    args = (pointclouds, img_size, radius, pts_per_pxl)
 | 
			
		||||
    return lambda: rasterize_points(*args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _bm_cuda_with_init(N, P, img_size=32, radius=0.1, pts_per_pxl=3):
 | 
			
		||||
    torch.manual_seed(231)
 | 
			
		||||
    points = torch.randn(N, P, 3, device=torch.device("cuda"))
 | 
			
		||||
    pointclouds = Pointclouds(points=points)
 | 
			
		||||
    args = (pointclouds, img_size, radius, pts_per_pxl)
 | 
			
		||||
    return lambda: rasterize_points(*args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def bm_python_vs_cpu() -> None:
 | 
			
		||||
    kwargs_list = [
 | 
			
		||||
        {"N": 1, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
 | 
			
		||||
        {"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
 | 
			
		||||
    ]
 | 
			
		||||
    benchmark(
 | 
			
		||||
        _bm_python_with_init, "RASTERIZE_PYTHON", kwargs_list, warmup_iters=1
 | 
			
		||||
    )
 | 
			
		||||
    benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1)
 | 
			
		||||
    kwargs_list = [
 | 
			
		||||
        {"N": 2, "P": 32, "img_size": 32, "radius": 0.1, "pts_per_pxl": 3},
 | 
			
		||||
        {"N": 4, "P": 1024, "img_size": 128, "radius": 0.05, "pts_per_pxl": 5},
 | 
			
		||||
    ]
 | 
			
		||||
    benchmark(_bm_cpu_with_init, "RASTERIZE_CPU", kwargs_list, warmup_iters=1)
 | 
			
		||||
    benchmark(_bm_cuda_with_init, "RASTERIZE_CUDA", kwargs_list, warmup_iters=1)
 | 
			
		||||
							
								
								
									
										978
									
								
								tests/test_pointclouds.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										978
									
								
								tests/test_pointclouds.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,978 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import unittest
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from pytorch3d.structures.pointclouds import Pointclouds
 | 
			
		||||
 | 
			
		||||
from common_testing import TestCaseMixin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestPointclouds(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def setUp(self) -> None:
 | 
			
		||||
        np.random.seed(42)
 | 
			
		||||
        torch.manual_seed(42)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def init_cloud(
 | 
			
		||||
        num_clouds: int = 3,
 | 
			
		||||
        max_points: int = 100,
 | 
			
		||||
        channels: int = 4,
 | 
			
		||||
        lists_to_tensors: bool = False,
 | 
			
		||||
        with_normals: bool = True,
 | 
			
		||||
        with_features: bool = True,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Function to generate a Pointclouds object of N meshes with
 | 
			
		||||
        random number of points.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            num_clouds: Number of clouds to generate.
 | 
			
		||||
            channels: Number of features.
 | 
			
		||||
            max_points: Max number of points per cloud.
 | 
			
		||||
            lists_to_tensors: Determines whether the generated clouds should be
 | 
			
		||||
                              constructed from lists (=False) or
 | 
			
		||||
                              tensors (=True) of points/normals/features.
 | 
			
		||||
            with_normals: bool whether to include normals
 | 
			
		||||
            with_features: bool whether to include features
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Pointclouds object.
 | 
			
		||||
        """
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        p = torch.randint(max_points, size=(num_clouds,))
 | 
			
		||||
        if lists_to_tensors:
 | 
			
		||||
            p.fill_(p[0])
 | 
			
		||||
 | 
			
		||||
        points_list = [
 | 
			
		||||
            torch.rand((i, 3), device=device, dtype=torch.float32) for i in p
 | 
			
		||||
        ]
 | 
			
		||||
        normals_list, features_list = None, None
 | 
			
		||||
        if with_normals:
 | 
			
		||||
            normals_list = [
 | 
			
		||||
                torch.rand((i, 3), device=device, dtype=torch.float32)
 | 
			
		||||
                for i in p
 | 
			
		||||
            ]
 | 
			
		||||
        if with_features:
 | 
			
		||||
            features_list = [
 | 
			
		||||
                torch.rand((i, channels), device=device, dtype=torch.float32)
 | 
			
		||||
                for i in p
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
        if lists_to_tensors:
 | 
			
		||||
            points_list = torch.stack(points_list)
 | 
			
		||||
            if with_normals:
 | 
			
		||||
                normals_list = torch.stack(normals_list)
 | 
			
		||||
            if with_features:
 | 
			
		||||
                features_list = torch.stack(features_list)
 | 
			
		||||
 | 
			
		||||
        return Pointclouds(
 | 
			
		||||
            points_list, normals=normals_list, features=features_list
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_simple(self):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        points = [
 | 
			
		||||
            torch.tensor(
 | 
			
		||||
                [[0.1, 0.3, 0.5], [0.5, 0.2, 0.1], [0.6, 0.8, 0.7]],
 | 
			
		||||
                dtype=torch.float32,
 | 
			
		||||
                device=device,
 | 
			
		||||
            ),
 | 
			
		||||
            torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [0.1, 0.3, 0.3],
 | 
			
		||||
                    [0.6, 0.7, 0.8],
 | 
			
		||||
                    [0.2, 0.3, 0.4],
 | 
			
		||||
                    [0.1, 0.5, 0.3],
 | 
			
		||||
                ],
 | 
			
		||||
                dtype=torch.float32,
 | 
			
		||||
                device=device,
 | 
			
		||||
            ),
 | 
			
		||||
            torch.tensor(
 | 
			
		||||
                [
 | 
			
		||||
                    [0.7, 0.3, 0.6],
 | 
			
		||||
                    [0.2, 0.4, 0.8],
 | 
			
		||||
                    [0.9, 0.5, 0.2],
 | 
			
		||||
                    [0.2, 0.3, 0.4],
 | 
			
		||||
                    [0.9, 0.3, 0.8],
 | 
			
		||||
                ],
 | 
			
		||||
                dtype=torch.float32,
 | 
			
		||||
                device=device,
 | 
			
		||||
            ),
 | 
			
		||||
        ]
 | 
			
		||||
        clouds = Pointclouds(points)
 | 
			
		||||
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            (clouds.packed_to_cloud_idx()).cpu(),
 | 
			
		||||
            torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]),
 | 
			
		||||
        )
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            clouds.cloud_to_packed_first_idx().cpu(), torch.tensor([0, 3, 7])
 | 
			
		||||
        )
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            clouds.num_points_per_cloud().cpu(), torch.tensor([3, 4, 5])
 | 
			
		||||
        )
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            clouds.padded_to_packed_idx().cpu(),
 | 
			
		||||
            torch.tensor([0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 14]),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_all_constructions(self):
 | 
			
		||||
        public_getters = [
 | 
			
		||||
            "points_list",
 | 
			
		||||
            "points_packed",
 | 
			
		||||
            "packed_to_cloud_idx",
 | 
			
		||||
            "cloud_to_packed_first_idx",
 | 
			
		||||
            "num_points_per_cloud",
 | 
			
		||||
            "points_padded",
 | 
			
		||||
            "padded_to_packed_idx",
 | 
			
		||||
        ]
 | 
			
		||||
        public_normals_getters = [
 | 
			
		||||
            "normals_list",
 | 
			
		||||
            "normals_packed",
 | 
			
		||||
            "normals_padded",
 | 
			
		||||
        ]
 | 
			
		||||
        public_features_getters = [
 | 
			
		||||
            "features_list",
 | 
			
		||||
            "features_packed",
 | 
			
		||||
            "features_padded",
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        lengths = [3, 4, 2]
 | 
			
		||||
        max_len = max(lengths)
 | 
			
		||||
        C = 4
 | 
			
		||||
 | 
			
		||||
        points_data = [torch.zeros((max_len, 3)).uniform_() for i in lengths]
 | 
			
		||||
        normals_data = [torch.zeros((max_len, 3)).uniform_() for i in lengths]
 | 
			
		||||
        features_data = [torch.zeros((max_len, C)).uniform_() for i in lengths]
 | 
			
		||||
        for length, p, n, f in zip(
 | 
			
		||||
            lengths, points_data, normals_data, features_data
 | 
			
		||||
        ):
 | 
			
		||||
            p[length:] = 0.0
 | 
			
		||||
            n[length:] = 0.0
 | 
			
		||||
            f[length:] = 0.0
 | 
			
		||||
        points_list = [d[:length] for length, d in zip(lengths, points_data)]
 | 
			
		||||
        normals_list = [d[:length] for length, d in zip(lengths, normals_data)]
 | 
			
		||||
        features_list = [
 | 
			
		||||
            d[:length] for length, d in zip(lengths, features_data)
 | 
			
		||||
        ]
 | 
			
		||||
        points_packed = torch.cat(points_data)
 | 
			
		||||
        normals_packed = torch.cat(normals_data)
 | 
			
		||||
        features_packed = torch.cat(features_data)
 | 
			
		||||
        test_cases_inputs = [
 | 
			
		||||
            ("list_0_0", points_list, None, None),
 | 
			
		||||
            ("list_1_0", points_list, normals_list, None),
 | 
			
		||||
            ("list_0_1", points_list, None, features_list),
 | 
			
		||||
            ("list_1_1", points_list, normals_list, features_list),
 | 
			
		||||
            ("padded_0_0", points_data, None, None),
 | 
			
		||||
            ("padded_1_0", points_data, normals_data, None),
 | 
			
		||||
            ("padded_0_1", points_data, None, features_data),
 | 
			
		||||
            ("padded_1_1", points_data, normals_data, features_data),
 | 
			
		||||
            ("emptylist_emptylist_emptylist", [], [], []),
 | 
			
		||||
        ]
 | 
			
		||||
        false_cases_inputs = [
 | 
			
		||||
            (
 | 
			
		||||
                "list_packed",
 | 
			
		||||
                points_list,
 | 
			
		||||
                normals_packed,
 | 
			
		||||
                features_packed,
 | 
			
		||||
                ValueError,
 | 
			
		||||
            ),
 | 
			
		||||
            ("packed_0", points_packed, None, None, ValueError),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        for name, points, normals, features in test_cases_inputs:
 | 
			
		||||
            with self.subTest(name=name):
 | 
			
		||||
                p = Pointclouds(points, normals, features)
 | 
			
		||||
                for method in public_getters:
 | 
			
		||||
                    self.assertIsNotNone(getattr(p, method)())
 | 
			
		||||
                for method in public_normals_getters:
 | 
			
		||||
                    if normals is None or p.isempty():
 | 
			
		||||
                        self.assertIsNone(getattr(p, method)())
 | 
			
		||||
                for method in public_features_getters:
 | 
			
		||||
                    if features is None or p.isempty():
 | 
			
		||||
                        self.assertIsNone(getattr(p, method)())
 | 
			
		||||
 | 
			
		||||
        for name, points, normals, features, error in false_cases_inputs:
 | 
			
		||||
            with self.subTest(name=name):
 | 
			
		||||
                with self.assertRaises(error):
 | 
			
		||||
                    Pointclouds(points, normals, features)
 | 
			
		||||
 | 
			
		||||
    def test_simple_random_clouds(self):
 | 
			
		||||
        # Define the test object either from lists or tensors.
 | 
			
		||||
        for with_normals in (False, True):
 | 
			
		||||
            for with_features in (False, True):
 | 
			
		||||
                for lists_to_tensors in (False, True):
 | 
			
		||||
                    N = 10
 | 
			
		||||
                    cloud = self.init_cloud(
 | 
			
		||||
                        N,
 | 
			
		||||
                        lists_to_tensors=lists_to_tensors,
 | 
			
		||||
                        with_normals=with_normals,
 | 
			
		||||
                        with_features=with_features,
 | 
			
		||||
                    )
 | 
			
		||||
                    points_list = cloud.points_list()
 | 
			
		||||
                    normals_list = cloud.normals_list()
 | 
			
		||||
                    features_list = cloud.features_list()
 | 
			
		||||
 | 
			
		||||
                    # Check batch calculations.
 | 
			
		||||
                    points_padded = cloud.points_padded()
 | 
			
		||||
                    normals_padded = cloud.normals_padded()
 | 
			
		||||
                    features_padded = cloud.features_padded()
 | 
			
		||||
                    points_per_cloud = cloud.num_points_per_cloud()
 | 
			
		||||
 | 
			
		||||
                    if not with_normals:
 | 
			
		||||
                        self.assertIsNone(normals_list)
 | 
			
		||||
                        self.assertIsNone(normals_padded)
 | 
			
		||||
                    if not with_features:
 | 
			
		||||
                        self.assertIsNone(features_list)
 | 
			
		||||
                        self.assertIsNone(features_padded)
 | 
			
		||||
                    for n in range(N):
 | 
			
		||||
                        p = points_list[n].shape[0]
 | 
			
		||||
                        self.assertClose(
 | 
			
		||||
                            points_padded[n, :p, :], points_list[n]
 | 
			
		||||
                        )
 | 
			
		||||
                        if with_normals:
 | 
			
		||||
                            norms = normals_list[n].shape[0]
 | 
			
		||||
                            self.assertEqual(p, norms)
 | 
			
		||||
                            self.assertClose(
 | 
			
		||||
                                normals_padded[n, :p, :], normals_list[n]
 | 
			
		||||
                            )
 | 
			
		||||
                        if with_features:
 | 
			
		||||
                            f = features_list[n].shape[0]
 | 
			
		||||
                            self.assertEqual(p, f)
 | 
			
		||||
                            self.assertClose(
 | 
			
		||||
                                features_padded[n, :p, :], features_list[n]
 | 
			
		||||
                            )
 | 
			
		||||
                        if points_padded.shape[1] > p:
 | 
			
		||||
                            self.assertTrue(points_padded[n, p:, :].eq(0).all())
 | 
			
		||||
                            if with_features:
 | 
			
		||||
                                self.assertTrue(
 | 
			
		||||
                                    features_padded[n, p:, :].eq(0).all()
 | 
			
		||||
                                )
 | 
			
		||||
                        self.assertEqual(points_per_cloud[n], p)
 | 
			
		||||
 | 
			
		||||
                    # Check compute packed.
 | 
			
		||||
                    points_packed = cloud.points_packed()
 | 
			
		||||
                    packed_to_cloud = cloud.packed_to_cloud_idx()
 | 
			
		||||
                    cloud_to_packed = cloud.cloud_to_packed_first_idx()
 | 
			
		||||
                    normals_packed = cloud.normals_packed()
 | 
			
		||||
                    features_packed = cloud.features_packed()
 | 
			
		||||
                    if not with_normals:
 | 
			
		||||
                        self.assertIsNone(normals_packed)
 | 
			
		||||
                    if not with_features:
 | 
			
		||||
                        self.assertIsNone(features_packed)
 | 
			
		||||
 | 
			
		||||
                    cur = 0
 | 
			
		||||
                    for n in range(N):
 | 
			
		||||
                        p = points_list[n].shape[0]
 | 
			
		||||
                        self.assertClose(
 | 
			
		||||
                            points_packed[cur : cur + p, :], points_list[n]
 | 
			
		||||
                        )
 | 
			
		||||
                        if with_normals:
 | 
			
		||||
                            self.assertClose(
 | 
			
		||||
                                normals_packed[cur : cur + p, :],
 | 
			
		||||
                                normals_list[n],
 | 
			
		||||
                            )
 | 
			
		||||
                        if with_features:
 | 
			
		||||
                            self.assertClose(
 | 
			
		||||
                                features_packed[cur : cur + p, :],
 | 
			
		||||
                                features_list[n],
 | 
			
		||||
                            )
 | 
			
		||||
                        self.assertTrue(
 | 
			
		||||
                            packed_to_cloud[cur : cur + p].eq(n).all()
 | 
			
		||||
                        )
 | 
			
		||||
                        self.assertTrue(cloud_to_packed[n] == cur)
 | 
			
		||||
                        cur += p
 | 
			
		||||
 | 
			
		||||
    def test_allempty(self):
 | 
			
		||||
        clouds = Pointclouds([], [])
 | 
			
		||||
        self.assertEqual(len(clouds), 0)
 | 
			
		||||
        self.assertIsNone(clouds.normals_list())
 | 
			
		||||
        self.assertIsNone(clouds.features_list())
 | 
			
		||||
        self.assertEqual(clouds.points_padded().shape[0], 0)
 | 
			
		||||
        self.assertIsNone(clouds.normals_padded())
 | 
			
		||||
        self.assertIsNone(clouds.features_padded())
 | 
			
		||||
        self.assertEqual(clouds.points_packed().shape[0], 0)
 | 
			
		||||
        self.assertIsNone(clouds.normals_packed())
 | 
			
		||||
        self.assertIsNone(clouds.features_packed())
 | 
			
		||||
 | 
			
		||||
    def test_empty(self):
 | 
			
		||||
        N, P, C = 10, 100, 2
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        points_list = []
 | 
			
		||||
        normals_list = []
 | 
			
		||||
        features_list = []
 | 
			
		||||
        valid = torch.randint(2, size=(N,), dtype=torch.uint8, device=device)
 | 
			
		||||
        for n in range(N):
 | 
			
		||||
            if valid[n]:
 | 
			
		||||
                p = torch.randint(
 | 
			
		||||
                    3, high=P, size=(1,), dtype=torch.int32, device=device
 | 
			
		||||
                )[0]
 | 
			
		||||
                points = torch.rand((p, 3), dtype=torch.float32, device=device)
 | 
			
		||||
                normals = torch.rand((p, 3), dtype=torch.float32, device=device)
 | 
			
		||||
                features = torch.rand(
 | 
			
		||||
                    (p, C), dtype=torch.float32, device=device
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                points = torch.tensor([], dtype=torch.float32, device=device)
 | 
			
		||||
                normals = torch.tensor([], dtype=torch.float32, device=device)
 | 
			
		||||
                features = torch.tensor([], dtype=torch.int64, device=device)
 | 
			
		||||
            points_list.append(points)
 | 
			
		||||
            normals_list.append(normals)
 | 
			
		||||
            features_list.append(features)
 | 
			
		||||
 | 
			
		||||
        for with_normals in (False, True):
 | 
			
		||||
            for with_features in (False, True):
 | 
			
		||||
                this_features, this_normals = None, None
 | 
			
		||||
                if with_normals:
 | 
			
		||||
                    this_normals = normals_list
 | 
			
		||||
                if with_features:
 | 
			
		||||
                    this_features = features_list
 | 
			
		||||
                clouds = Pointclouds(
 | 
			
		||||
                    points=points_list,
 | 
			
		||||
                    normals=this_normals,
 | 
			
		||||
                    features=this_features,
 | 
			
		||||
                )
 | 
			
		||||
                points_padded = clouds.points_padded()
 | 
			
		||||
                normals_padded = clouds.normals_padded()
 | 
			
		||||
                features_padded = clouds.features_padded()
 | 
			
		||||
                if not with_normals:
 | 
			
		||||
                    self.assertIsNone(normals_padded)
 | 
			
		||||
                if not with_features:
 | 
			
		||||
                    self.assertIsNone(features_padded)
 | 
			
		||||
                points_per_cloud = clouds.num_points_per_cloud()
 | 
			
		||||
                for n in range(N):
 | 
			
		||||
                    p = len(points_list[n])
 | 
			
		||||
                    if p > 0:
 | 
			
		||||
                        self.assertClose(
 | 
			
		||||
                            points_padded[n, :p, :], points_list[n]
 | 
			
		||||
                        )
 | 
			
		||||
                        if with_normals:
 | 
			
		||||
                            self.assertClose(
 | 
			
		||||
                                normals_padded[n, :p, :], normals_list[n]
 | 
			
		||||
                            )
 | 
			
		||||
                        if with_features:
 | 
			
		||||
                            self.assertClose(
 | 
			
		||||
                                features_padded[n, :p, :], features_list[n]
 | 
			
		||||
                            )
 | 
			
		||||
                        if points_padded.shape[1] > p:
 | 
			
		||||
                            self.assertTrue(points_padded[n, p:, :].eq(0).all())
 | 
			
		||||
                            if with_normals:
 | 
			
		||||
                                self.assertTrue(
 | 
			
		||||
                                    normals_padded[n, p:, :].eq(0).all()
 | 
			
		||||
                                )
 | 
			
		||||
                            if with_features:
 | 
			
		||||
                                self.assertTrue(
 | 
			
		||||
                                    features_padded[n, p:, :].eq(0).all()
 | 
			
		||||
                                )
 | 
			
		||||
                    self.assertTrue(points_per_cloud[n] == p)
 | 
			
		||||
 | 
			
		||||
    def test_clone_list(self):
 | 
			
		||||
        N = 5
 | 
			
		||||
        clouds = self.init_cloud(N, 100, 5)
 | 
			
		||||
        for force in (False, True):
 | 
			
		||||
            if force:
 | 
			
		||||
                clouds.points_packed()
 | 
			
		||||
 | 
			
		||||
            new_clouds = clouds.clone()
 | 
			
		||||
 | 
			
		||||
            # Check cloned and original objects do not share tensors.
 | 
			
		||||
            self.assertSeparate(
 | 
			
		||||
                new_clouds.points_list()[0], clouds.points_list()[0]
 | 
			
		||||
            )
 | 
			
		||||
            self.assertSeparate(
 | 
			
		||||
                new_clouds.normals_list()[0], clouds.normals_list()[0]
 | 
			
		||||
            )
 | 
			
		||||
            self.assertSeparate(
 | 
			
		||||
                new_clouds.features_list()[0], clouds.features_list()[0]
 | 
			
		||||
            )
 | 
			
		||||
            for attrib in [
 | 
			
		||||
                "points_packed",
 | 
			
		||||
                "normals_packed",
 | 
			
		||||
                "features_packed",
 | 
			
		||||
                "points_padded",
 | 
			
		||||
                "normals_padded",
 | 
			
		||||
                "features_padded",
 | 
			
		||||
            ]:
 | 
			
		||||
                self.assertSeparate(
 | 
			
		||||
                    getattr(new_clouds, attrib)(), getattr(clouds, attrib)()
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            self.assertCloudsEqual(clouds, new_clouds)
 | 
			
		||||
 | 
			
		||||
    def test_clone_tensor(self):
 | 
			
		||||
        N = 5
 | 
			
		||||
        clouds = self.init_cloud(N, 100, 5, lists_to_tensors=True)
 | 
			
		||||
        for force in (False, True):
 | 
			
		||||
            if force:
 | 
			
		||||
                clouds.points_packed()
 | 
			
		||||
 | 
			
		||||
            new_clouds = clouds.clone()
 | 
			
		||||
 | 
			
		||||
            # Check cloned and original objects do not share tensors.
 | 
			
		||||
            self.assertSeparate(
 | 
			
		||||
                new_clouds.points_list()[0], clouds.points_list()[0]
 | 
			
		||||
            )
 | 
			
		||||
            self.assertSeparate(
 | 
			
		||||
                new_clouds.normals_list()[0], clouds.normals_list()[0]
 | 
			
		||||
            )
 | 
			
		||||
            self.assertSeparate(
 | 
			
		||||
                new_clouds.features_list()[0], clouds.features_list()[0]
 | 
			
		||||
            )
 | 
			
		||||
            for attrib in [
 | 
			
		||||
                "points_packed",
 | 
			
		||||
                "normals_packed",
 | 
			
		||||
                "features_packed",
 | 
			
		||||
                "points_padded",
 | 
			
		||||
                "normals_padded",
 | 
			
		||||
                "features_padded",
 | 
			
		||||
            ]:
 | 
			
		||||
                self.assertSeparate(
 | 
			
		||||
                    getattr(new_clouds, attrib)(), getattr(clouds, attrib)()
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            self.assertCloudsEqual(clouds, new_clouds)
 | 
			
		||||
 | 
			
		||||
    def assertCloudsEqual(self, cloud1, cloud2):
 | 
			
		||||
        N = len(cloud1)
 | 
			
		||||
        self.assertEqual(N, len(cloud2))
 | 
			
		||||
 | 
			
		||||
        for i in range(N):
 | 
			
		||||
            self.assertClose(cloud1.points_list()[i], cloud2.points_list()[i])
 | 
			
		||||
            self.assertClose(cloud1.normals_list()[i], cloud2.normals_list()[i])
 | 
			
		||||
            self.assertClose(
 | 
			
		||||
                cloud1.features_list()[i], cloud2.features_list()[i]
 | 
			
		||||
            )
 | 
			
		||||
        has_normals = cloud1.normals_list() is not None
 | 
			
		||||
        self.assertTrue(has_normals == (cloud2.normals_list() is not None))
 | 
			
		||||
        has_features = cloud1.features_list() is not None
 | 
			
		||||
        self.assertTrue(has_features == (cloud2.features_list() is not None))
 | 
			
		||||
 | 
			
		||||
        # check padded & packed
 | 
			
		||||
        self.assertClose(cloud1.points_padded(), cloud2.points_padded())
 | 
			
		||||
        self.assertClose(cloud1.points_packed(), cloud2.points_packed())
 | 
			
		||||
        if has_normals:
 | 
			
		||||
            self.assertClose(cloud1.normals_padded(), cloud2.normals_padded())
 | 
			
		||||
            self.assertClose(cloud1.normals_packed(), cloud2.normals_packed())
 | 
			
		||||
        if has_features:
 | 
			
		||||
            self.assertClose(cloud1.features_padded(), cloud2.features_padded())
 | 
			
		||||
            self.assertClose(cloud1.features_packed(), cloud2.features_packed())
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            cloud1.packed_to_cloud_idx(), cloud2.packed_to_cloud_idx()
 | 
			
		||||
        )
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            cloud1.cloud_to_packed_first_idx(),
 | 
			
		||||
            cloud2.cloud_to_packed_first_idx(),
 | 
			
		||||
        )
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            cloud1.num_points_per_cloud(), cloud2.num_points_per_cloud()
 | 
			
		||||
        )
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            cloud1.packed_to_cloud_idx(), cloud2.packed_to_cloud_idx()
 | 
			
		||||
        )
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            cloud1.padded_to_packed_idx(), cloud2.padded_to_packed_idx()
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(all(cloud1.valid == cloud2.valid))
 | 
			
		||||
        self.assertTrue(cloud1.equisized == cloud2.equisized)
 | 
			
		||||
 | 
			
		||||
    def test_offset(self):
 | 
			
		||||
        def naive_offset(clouds, offsets_packed):
 | 
			
		||||
            new_points_packed = clouds.points_packed() + offsets_packed
 | 
			
		||||
            new_points_list = list(
 | 
			
		||||
                new_points_packed.split(
 | 
			
		||||
                    clouds.num_points_per_cloud().tolist(), 0
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            return Pointclouds(
 | 
			
		||||
                points=new_points_list,
 | 
			
		||||
                normals=clouds.normals_list(),
 | 
			
		||||
                features=clouds.features_list(),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        N = 5
 | 
			
		||||
        clouds = self.init_cloud(N, 100, 10)
 | 
			
		||||
        all_p = clouds.points_packed().size(0)
 | 
			
		||||
        points_per_cloud = clouds.num_points_per_cloud()
 | 
			
		||||
        for force in (False, True):
 | 
			
		||||
            if force:
 | 
			
		||||
                clouds._compute_packed(refresh=True)
 | 
			
		||||
                clouds._compute_padded()
 | 
			
		||||
                clouds.padded_to_packed_idx()
 | 
			
		||||
 | 
			
		||||
            deform = torch.rand(
 | 
			
		||||
                (all_p, 3), dtype=torch.float32, device=clouds.device
 | 
			
		||||
            )
 | 
			
		||||
            new_clouds_naive = naive_offset(clouds, deform)
 | 
			
		||||
 | 
			
		||||
            new_clouds = clouds.offset(deform)
 | 
			
		||||
 | 
			
		||||
            points_cumsum = torch.cumsum(points_per_cloud, 0).tolist()
 | 
			
		||||
            points_cumsum.insert(0, 0)
 | 
			
		||||
            for i in range(N):
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    new_clouds.points_list()[i],
 | 
			
		||||
                    clouds.points_list()[i]
 | 
			
		||||
                    + deform[points_cumsum[i] : points_cumsum[i + 1]],
 | 
			
		||||
                )
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    clouds.normals_list()[i], new_clouds_naive.normals_list()[i]
 | 
			
		||||
                )
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    clouds.features_list()[i],
 | 
			
		||||
                    new_clouds_naive.features_list()[i],
 | 
			
		||||
                )
 | 
			
		||||
            self.assertCloudsEqual(new_clouds, new_clouds_naive)
 | 
			
		||||
 | 
			
		||||
    def test_scale(self):
 | 
			
		||||
        def naive_scale(cloud, scale):
 | 
			
		||||
            if not torch.is_tensor(scale):
 | 
			
		||||
                scale = torch.full(len(cloud), scale)
 | 
			
		||||
            new_points_list = [
 | 
			
		||||
                scale[i] * points.clone()
 | 
			
		||||
                for (i, points) in enumerate(cloud.points_list())
 | 
			
		||||
            ]
 | 
			
		||||
            return Pointclouds(
 | 
			
		||||
                new_points_list, cloud.normals_list(), cloud.features_list()
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        N = 5
 | 
			
		||||
        clouds = self.init_cloud(N, 100, 10)
 | 
			
		||||
        for force in (False, True):
 | 
			
		||||
            if force:
 | 
			
		||||
                clouds._compute_packed(refresh=True)
 | 
			
		||||
                clouds._compute_padded()
 | 
			
		||||
                clouds.padded_to_packed_idx()
 | 
			
		||||
            scales = torch.rand(N)
 | 
			
		||||
            new_clouds_naive = naive_scale(clouds, scales)
 | 
			
		||||
            new_clouds = clouds.scale(scales)
 | 
			
		||||
            for i in range(N):
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    scales[i] * clouds.points_list()[i],
 | 
			
		||||
                    new_clouds.points_list()[i],
 | 
			
		||||
                )
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    clouds.normals_list()[i], new_clouds_naive.normals_list()[i]
 | 
			
		||||
                )
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    clouds.features_list()[i],
 | 
			
		||||
                    new_clouds_naive.features_list()[i],
 | 
			
		||||
                )
 | 
			
		||||
            self.assertCloudsEqual(new_clouds, new_clouds_naive)
 | 
			
		||||
 | 
			
		||||
    def test_extend_list(self):
 | 
			
		||||
        N = 10
 | 
			
		||||
        clouds = self.init_cloud(N, 100, 10)
 | 
			
		||||
        for force in (False, True):
 | 
			
		||||
            if force:
 | 
			
		||||
                # force some computes to happen
 | 
			
		||||
                clouds._compute_packed(refresh=True)
 | 
			
		||||
                clouds._compute_padded()
 | 
			
		||||
                clouds.padded_to_packed_idx()
 | 
			
		||||
            new_clouds = clouds.extend(N)
 | 
			
		||||
            self.assertEqual(len(clouds) * 10, len(new_clouds))
 | 
			
		||||
            for i in range(len(clouds)):
 | 
			
		||||
                for n in range(N):
 | 
			
		||||
                    self.assertClose(
 | 
			
		||||
                        clouds.points_list()[i],
 | 
			
		||||
                        new_clouds.points_list()[i * N + n],
 | 
			
		||||
                    )
 | 
			
		||||
                    self.assertClose(
 | 
			
		||||
                        clouds.normals_list()[i],
 | 
			
		||||
                        new_clouds.normals_list()[i * N + n],
 | 
			
		||||
                    )
 | 
			
		||||
                    self.assertClose(
 | 
			
		||||
                        clouds.features_list()[i],
 | 
			
		||||
                        new_clouds.features_list()[i * N + n],
 | 
			
		||||
                    )
 | 
			
		||||
                    self.assertTrue(
 | 
			
		||||
                        clouds.valid[i] == new_clouds.valid[i * N + n]
 | 
			
		||||
                    )
 | 
			
		||||
            self.assertAllSeparate(
 | 
			
		||||
                clouds.points_list()
 | 
			
		||||
                + new_clouds.points_list()
 | 
			
		||||
                + clouds.normals_list()
 | 
			
		||||
                + new_clouds.normals_list()
 | 
			
		||||
                + clouds.features_list()
 | 
			
		||||
                + new_clouds.features_list()
 | 
			
		||||
            )
 | 
			
		||||
            self.assertIsNone(new_clouds._points_packed)
 | 
			
		||||
            self.assertIsNone(new_clouds._normals_packed)
 | 
			
		||||
            self.assertIsNone(new_clouds._features_packed)
 | 
			
		||||
            self.assertIsNone(new_clouds._points_padded)
 | 
			
		||||
            self.assertIsNone(new_clouds._normals_padded)
 | 
			
		||||
            self.assertIsNone(new_clouds._features_padded)
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            clouds.extend(N=-1)
 | 
			
		||||
 | 
			
		||||
    def test_to_list(self):
 | 
			
		||||
        cloud = self.init_cloud(5, 100, 10)
 | 
			
		||||
        device = torch.device("cuda:1")
 | 
			
		||||
 | 
			
		||||
        new_cloud = cloud.to(device)
 | 
			
		||||
        self.assertTrue(new_cloud.device == device)
 | 
			
		||||
        self.assertTrue(cloud.device == torch.device("cuda:0"))
 | 
			
		||||
        for attrib in [
 | 
			
		||||
            "points_padded",
 | 
			
		||||
            "points_packed",
 | 
			
		||||
            "normals_padded",
 | 
			
		||||
            "normals_packed",
 | 
			
		||||
            "features_padded",
 | 
			
		||||
            "features_packed",
 | 
			
		||||
            "num_points_per_cloud",
 | 
			
		||||
            "cloud_to_packed_first_idx",
 | 
			
		||||
            "padded_to_packed_idx",
 | 
			
		||||
        ]:
 | 
			
		||||
            self.assertClose(
 | 
			
		||||
                getattr(new_cloud, attrib)().cpu(),
 | 
			
		||||
                getattr(cloud, attrib)().cpu(),
 | 
			
		||||
            )
 | 
			
		||||
        for i in range(len(cloud)):
 | 
			
		||||
            self.assertClose(
 | 
			
		||||
                cloud.points_list()[i].cpu(), new_cloud.points_list()[i].cpu()
 | 
			
		||||
            )
 | 
			
		||||
            self.assertClose(
 | 
			
		||||
                cloud.normals_list()[i].cpu(), new_cloud.normals_list()[i].cpu()
 | 
			
		||||
            )
 | 
			
		||||
            self.assertClose(
 | 
			
		||||
                cloud.features_list()[i].cpu(),
 | 
			
		||||
                new_cloud.features_list()[i].cpu(),
 | 
			
		||||
            )
 | 
			
		||||
        self.assertTrue(all(cloud.valid.cpu() == new_cloud.valid.cpu()))
 | 
			
		||||
        self.assertTrue(cloud.equisized == new_cloud.equisized)
 | 
			
		||||
        self.assertTrue(cloud._N == new_cloud._N)
 | 
			
		||||
        self.assertTrue(cloud._P == new_cloud._P)
 | 
			
		||||
        self.assertTrue(cloud._C == new_cloud._C)
 | 
			
		||||
 | 
			
		||||
    def test_to_tensor(self):
 | 
			
		||||
        cloud = self.init_cloud(5, 100, 10, lists_to_tensors=True)
 | 
			
		||||
        device = torch.device("cuda:1")
 | 
			
		||||
 | 
			
		||||
        new_cloud = cloud.to(device)
 | 
			
		||||
        self.assertTrue(new_cloud.device == device)
 | 
			
		||||
        self.assertTrue(cloud.device == torch.device("cuda:0"))
 | 
			
		||||
        for attrib in [
 | 
			
		||||
            "points_padded",
 | 
			
		||||
            "points_packed",
 | 
			
		||||
            "normals_padded",
 | 
			
		||||
            "normals_packed",
 | 
			
		||||
            "features_padded",
 | 
			
		||||
            "features_packed",
 | 
			
		||||
            "num_points_per_cloud",
 | 
			
		||||
            "cloud_to_packed_first_idx",
 | 
			
		||||
            "padded_to_packed_idx",
 | 
			
		||||
        ]:
 | 
			
		||||
            self.assertClose(
 | 
			
		||||
                getattr(new_cloud, attrib)().cpu(),
 | 
			
		||||
                getattr(cloud, attrib)().cpu(),
 | 
			
		||||
            )
 | 
			
		||||
        for i in range(len(cloud)):
 | 
			
		||||
            self.assertClose(
 | 
			
		||||
                cloud.points_list()[i].cpu(), new_cloud.points_list()[i].cpu()
 | 
			
		||||
            )
 | 
			
		||||
            self.assertClose(
 | 
			
		||||
                cloud.normals_list()[i].cpu(), new_cloud.normals_list()[i].cpu()
 | 
			
		||||
            )
 | 
			
		||||
            self.assertClose(
 | 
			
		||||
                cloud.features_list()[i].cpu(),
 | 
			
		||||
                new_cloud.features_list()[i].cpu(),
 | 
			
		||||
            )
 | 
			
		||||
        self.assertTrue(all(cloud.valid.cpu() == new_cloud.valid.cpu()))
 | 
			
		||||
        self.assertTrue(cloud.equisized == new_cloud.equisized)
 | 
			
		||||
        self.assertTrue(cloud._N == new_cloud._N)
 | 
			
		||||
        self.assertTrue(cloud._P == new_cloud._P)
 | 
			
		||||
        self.assertTrue(cloud._C == new_cloud._C)
 | 
			
		||||
 | 
			
		||||
    def test_split(self):
 | 
			
		||||
        clouds = self.init_cloud(5, 100, 10)
 | 
			
		||||
        split_sizes = [2, 3]
 | 
			
		||||
        split_clouds = clouds.split(split_sizes)
 | 
			
		||||
        self.assertEqual(len(split_clouds[0]), 2)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            split_clouds[0].points_list()
 | 
			
		||||
            == [clouds.get_cloud(0)[0], clouds.get_cloud(1)[0]]
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(len(split_clouds[1]), 3)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            split_clouds[1].points_list()
 | 
			
		||||
            == [
 | 
			
		||||
                clouds.get_cloud(2)[0],
 | 
			
		||||
                clouds.get_cloud(3)[0],
 | 
			
		||||
                clouds.get_cloud(4)[0],
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        split_sizes = [2, 0.3]
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            clouds.split(split_sizes)
 | 
			
		||||
 | 
			
		||||
    def test_get_cloud(self):
 | 
			
		||||
        clouds = self.init_cloud(2, 100, 10)
 | 
			
		||||
        for i in range(len(clouds)):
 | 
			
		||||
            points, normals, features = clouds.get_cloud(i)
 | 
			
		||||
            self.assertClose(points, clouds.points_list()[i])
 | 
			
		||||
            self.assertClose(normals, clouds.normals_list()[i])
 | 
			
		||||
            self.assertClose(features, clouds.features_list()[i])
 | 
			
		||||
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            clouds.get_cloud(5)
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            clouds.get_cloud(0.2)
 | 
			
		||||
 | 
			
		||||
    def test_get_bounding_boxes(self):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        points_list = []
 | 
			
		||||
        for size in [10]:
 | 
			
		||||
            points = torch.rand((size, 3), dtype=torch.float32, device=device)
 | 
			
		||||
            points_list.append(points)
 | 
			
		||||
 | 
			
		||||
        mins = torch.min(points, dim=0)[0]
 | 
			
		||||
        maxs = torch.max(points, dim=0)[0]
 | 
			
		||||
        bboxes_gt = torch.stack([mins, maxs], dim=1).unsqueeze(0)
 | 
			
		||||
        clouds = Pointclouds(points_list)
 | 
			
		||||
        bboxes = clouds.get_bounding_boxes()
 | 
			
		||||
        self.assertClose(bboxes_gt, bboxes)
 | 
			
		||||
 | 
			
		||||
    def test_padded_to_packed_idx(self):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        points_list = []
 | 
			
		||||
        npoints = [10, 20, 30]
 | 
			
		||||
        for p in npoints:
 | 
			
		||||
            points = torch.rand((p, 3), dtype=torch.float32, device=device)
 | 
			
		||||
            points_list.append(points)
 | 
			
		||||
 | 
			
		||||
        clouds = Pointclouds(points_list)
 | 
			
		||||
 | 
			
		||||
        padded_to_packed_idx = clouds.padded_to_packed_idx()
 | 
			
		||||
        points_packed = clouds.points_packed()
 | 
			
		||||
        points_padded = clouds.points_padded()
 | 
			
		||||
        points_padded_flat = points_padded.view(-1, 3)
 | 
			
		||||
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            points_padded_flat[padded_to_packed_idx], points_packed
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        idx = padded_to_packed_idx.view(-1, 1).expand(-1, 3)
 | 
			
		||||
        self.assertClose(points_padded_flat.gather(0, idx), points_packed)
 | 
			
		||||
 | 
			
		||||
    def test_getitem(self):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        clouds = self.init_cloud(3, 10, 100)
 | 
			
		||||
 | 
			
		||||
        def check_equal(selected, indices):
 | 
			
		||||
            for selectedIdx, index in indices:
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    selected.points_list()[selectedIdx],
 | 
			
		||||
                    clouds.points_list()[index],
 | 
			
		||||
                )
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    selected.normals_list()[selectedIdx],
 | 
			
		||||
                    clouds.normals_list()[index],
 | 
			
		||||
                )
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    selected.features_list()[selectedIdx],
 | 
			
		||||
                    clouds.features_list()[index],
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        # int index
 | 
			
		||||
        index = 1
 | 
			
		||||
        clouds_selected = clouds[index]
 | 
			
		||||
        self.assertEqual(len(clouds_selected), 1)
 | 
			
		||||
        check_equal(clouds_selected, [(0, 1)])
 | 
			
		||||
 | 
			
		||||
        # list index
 | 
			
		||||
        index = [1, 2]
 | 
			
		||||
        clouds_selected = clouds[index]
 | 
			
		||||
        self.assertEqual(len(clouds_selected), len(index))
 | 
			
		||||
        check_equal(clouds_selected, enumerate(index))
 | 
			
		||||
 | 
			
		||||
        # slice index
 | 
			
		||||
        index = slice(0, 2, 1)
 | 
			
		||||
        clouds_selected = clouds[index]
 | 
			
		||||
        self.assertEqual(len(clouds_selected), 2)
 | 
			
		||||
        check_equal(clouds_selected, [(0, 0), (1, 1)])
 | 
			
		||||
 | 
			
		||||
        # bool tensor
 | 
			
		||||
        index = torch.tensor([1, 0, 1], dtype=torch.bool, device=device)
 | 
			
		||||
        clouds_selected = clouds[index]
 | 
			
		||||
        self.assertEqual(len(clouds_selected), index.sum())
 | 
			
		||||
        check_equal(clouds_selected, [(0, 0), (1, 2)])
 | 
			
		||||
 | 
			
		||||
        # int tensor
 | 
			
		||||
        index = torch.tensor([1, 2], dtype=torch.int64, device=device)
 | 
			
		||||
        clouds_selected = clouds[index]
 | 
			
		||||
        self.assertEqual(len(clouds_selected), index.numel())
 | 
			
		||||
        check_equal(clouds_selected, enumerate(index.tolist()))
 | 
			
		||||
 | 
			
		||||
        # invalid index
 | 
			
		||||
        index = torch.tensor([1, 0, 1], dtype=torch.float32, device=device)
 | 
			
		||||
        with self.assertRaises(IndexError):
 | 
			
		||||
            clouds_selected = clouds[index]
 | 
			
		||||
        index = 1.2
 | 
			
		||||
        with self.assertRaises(IndexError):
 | 
			
		||||
            clouds_selected = clouds[index]
 | 
			
		||||
 | 
			
		||||
    def test_update_padded(self):
 | 
			
		||||
        N, P, C = 5, 100, 4
 | 
			
		||||
        for with_normfeat in (True, False):
 | 
			
		||||
            for with_new_normfeat in (True, False):
 | 
			
		||||
                clouds = self.init_cloud(
 | 
			
		||||
                    N,
 | 
			
		||||
                    P,
 | 
			
		||||
                    C,
 | 
			
		||||
                    with_normals=with_normfeat,
 | 
			
		||||
                    with_features=with_normfeat,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                num_points_per_cloud = clouds.num_points_per_cloud()
 | 
			
		||||
 | 
			
		||||
                # initialize new points, normals, features
 | 
			
		||||
                new_points = torch.rand(
 | 
			
		||||
                    clouds.points_padded().shape, device=clouds.device
 | 
			
		||||
                )
 | 
			
		||||
                new_points_list = [
 | 
			
		||||
                    new_points[i, : num_points_per_cloud[i]] for i in range(N)
 | 
			
		||||
                ]
 | 
			
		||||
                new_normals, new_normals_list = None, None
 | 
			
		||||
                new_features, new_features_list = None, None
 | 
			
		||||
                if with_new_normfeat:
 | 
			
		||||
                    new_normals = torch.rand(
 | 
			
		||||
                        clouds.points_padded().shape, device=clouds.device
 | 
			
		||||
                    )
 | 
			
		||||
                    new_normals_list = [
 | 
			
		||||
                        new_normals[i, : num_points_per_cloud[i]]
 | 
			
		||||
                        for i in range(N)
 | 
			
		||||
                    ]
 | 
			
		||||
                    feat_shape = [
 | 
			
		||||
                        clouds.points_padded().shape[0],
 | 
			
		||||
                        clouds.points_padded().shape[1],
 | 
			
		||||
                        C,
 | 
			
		||||
                    ]
 | 
			
		||||
                    new_features = torch.rand(feat_shape, device=clouds.device)
 | 
			
		||||
                    new_features_list = [
 | 
			
		||||
                        new_features[i, : num_points_per_cloud[i]]
 | 
			
		||||
                        for i in range(N)
 | 
			
		||||
                    ]
 | 
			
		||||
 | 
			
		||||
                # update
 | 
			
		||||
                new_clouds = clouds.update_padded(
 | 
			
		||||
                    new_points, new_normals, new_features
 | 
			
		||||
                )
 | 
			
		||||
                self.assertIsNone(new_clouds._points_list)
 | 
			
		||||
                self.assertIsNone(new_clouds._points_packed)
 | 
			
		||||
 | 
			
		||||
                self.assertEqual(new_clouds.equisized, clouds.equisized)
 | 
			
		||||
                self.assertTrue(all(new_clouds.valid == clouds.valid))
 | 
			
		||||
 | 
			
		||||
                self.assertClose(new_clouds.points_padded(), new_points)
 | 
			
		||||
                self.assertClose(
 | 
			
		||||
                    new_clouds.points_packed(), torch.cat(new_points_list)
 | 
			
		||||
                )
 | 
			
		||||
                for i in range(N):
 | 
			
		||||
                    self.assertClose(
 | 
			
		||||
                        new_clouds.points_list()[i], new_points_list[i]
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                if with_new_normfeat:
 | 
			
		||||
                    for i in range(N):
 | 
			
		||||
                        self.assertClose(
 | 
			
		||||
                            new_clouds.normals_list()[i], new_normals_list[i]
 | 
			
		||||
                        )
 | 
			
		||||
                        self.assertClose(
 | 
			
		||||
                            new_clouds.features_list()[i], new_features_list[i]
 | 
			
		||||
                        )
 | 
			
		||||
                    self.assertClose(new_clouds.normals_padded(), new_normals)
 | 
			
		||||
                    self.assertClose(
 | 
			
		||||
                        new_clouds.normals_packed(), torch.cat(new_normals_list)
 | 
			
		||||
                    )
 | 
			
		||||
                    self.assertClose(new_clouds.features_padded(), new_features)
 | 
			
		||||
                    self.assertClose(
 | 
			
		||||
                        new_clouds.features_packed(),
 | 
			
		||||
                        torch.cat(new_features_list),
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    if with_normfeat:
 | 
			
		||||
                        for i in range(N):
 | 
			
		||||
                            self.assertClose(
 | 
			
		||||
                                new_clouds.normals_list()[i],
 | 
			
		||||
                                clouds.normals_list()[i],
 | 
			
		||||
                            )
 | 
			
		||||
                            self.assertClose(
 | 
			
		||||
                                new_clouds.features_list()[i],
 | 
			
		||||
                                clouds.features_list()[i],
 | 
			
		||||
                            )
 | 
			
		||||
                            self.assertNotSeparate(
 | 
			
		||||
                                new_clouds.normals_list()[i],
 | 
			
		||||
                                clouds.normals_list()[i],
 | 
			
		||||
                            )
 | 
			
		||||
                            self.assertNotSeparate(
 | 
			
		||||
                                new_clouds.features_list()[i],
 | 
			
		||||
                                clouds.features_list()[i],
 | 
			
		||||
                            )
 | 
			
		||||
 | 
			
		||||
                        self.assertClose(
 | 
			
		||||
                            new_clouds.normals_padded(), clouds.normals_padded()
 | 
			
		||||
                        )
 | 
			
		||||
                        self.assertClose(
 | 
			
		||||
                            new_clouds.normals_packed(), clouds.normals_packed()
 | 
			
		||||
                        )
 | 
			
		||||
                        self.assertClose(
 | 
			
		||||
                            new_clouds.features_padded(),
 | 
			
		||||
                            clouds.features_padded(),
 | 
			
		||||
                        )
 | 
			
		||||
                        self.assertClose(
 | 
			
		||||
                            new_clouds.features_packed(),
 | 
			
		||||
                            clouds.features_packed(),
 | 
			
		||||
                        )
 | 
			
		||||
                        self.assertNotSeparate(
 | 
			
		||||
                            new_clouds.normals_padded(), clouds.normals_padded()
 | 
			
		||||
                        )
 | 
			
		||||
                        self.assertNotSeparate(
 | 
			
		||||
                            new_clouds.features_padded(),
 | 
			
		||||
                            clouds.features_padded(),
 | 
			
		||||
                        )
 | 
			
		||||
                    else:
 | 
			
		||||
                        self.assertIsNone(new_clouds.normals_list())
 | 
			
		||||
                        self.assertIsNone(new_clouds.features_list())
 | 
			
		||||
                        self.assertIsNone(new_clouds.normals_padded())
 | 
			
		||||
                        self.assertIsNone(new_clouds.features_padded())
 | 
			
		||||
                        self.assertIsNone(new_clouds.normals_packed())
 | 
			
		||||
                        self.assertIsNone(new_clouds.features_packed())
 | 
			
		||||
 | 
			
		||||
                for attrib in [
 | 
			
		||||
                    "num_points_per_cloud",
 | 
			
		||||
                    "cloud_to_packed_first_idx",
 | 
			
		||||
                    "padded_to_packed_idx",
 | 
			
		||||
                ]:
 | 
			
		||||
                    self.assertClose(
 | 
			
		||||
                        getattr(new_clouds, attrib)(), getattr(clouds, attrib)()
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def compute_packed_with_init(
 | 
			
		||||
        num_clouds: int = 10, max_p: int = 100, features: int = 300
 | 
			
		||||
    ):
 | 
			
		||||
        clouds = TestPointclouds.init_cloud(num_clouds, max_p, features)
 | 
			
		||||
        torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
        def compute_packed():
 | 
			
		||||
            clouds._compute_packed(refresh=True)
 | 
			
		||||
            torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
        return compute_packed
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def compute_padded_with_init(
 | 
			
		||||
        num_clouds: int = 10, max_p: int = 100, features: int = 300
 | 
			
		||||
    ):
 | 
			
		||||
        clouds = TestPointclouds.init_cloud(num_clouds, max_p, features)
 | 
			
		||||
        torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
        def compute_padded():
 | 
			
		||||
            clouds._compute_padded(refresh=True)
 | 
			
		||||
            torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
        return compute_padded
 | 
			
		||||
							
								
								
									
										525
									
								
								tests/test_rasterize_points.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										525
									
								
								tests/test_rasterize_points.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,525 @@
 | 
			
		||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import numpy as np
 | 
			
		||||
import unittest
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from pytorch3d import _C
 | 
			
		||||
from pytorch3d.renderer.points.rasterize_points import (
 | 
			
		||||
    rasterize_points,
 | 
			
		||||
    rasterize_points_python,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.structures.pointclouds import Pointclouds
 | 
			
		||||
 | 
			
		||||
from common_testing import TestCaseMixin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def test_python_simple_cpu(self):
 | 
			
		||||
        self._simple_test_case(
 | 
			
		||||
            rasterize_points_python, torch.device("cpu"), bin_size=-1
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_naive_simple_cpu(self):
 | 
			
		||||
        device = torch.device("cpu")
 | 
			
		||||
        self._simple_test_case(rasterize_points, device)
 | 
			
		||||
 | 
			
		||||
    def test_naive_simple_cuda(self):
 | 
			
		||||
        device = torch.device("cuda")
 | 
			
		||||
        self._simple_test_case(rasterize_points, device, bin_size=0)
 | 
			
		||||
 | 
			
		||||
    def test_python_behind_camera(self):
 | 
			
		||||
        self._test_behind_camera(
 | 
			
		||||
            rasterize_points_python, torch.device("cpu"), bin_size=-1
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_cpu_behind_camera(self):
 | 
			
		||||
        self._test_behind_camera(rasterize_points, torch.device("cpu"))
 | 
			
		||||
 | 
			
		||||
    def test_cuda_behind_camera(self):
 | 
			
		||||
        self._test_behind_camera(
 | 
			
		||||
            rasterize_points, torch.device("cuda"), bin_size=0
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_cpp_vs_naive_vs_binned(self):
 | 
			
		||||
        # Make sure that the backward pass runs for all pathways
 | 
			
		||||
        N = 2
 | 
			
		||||
        P = 1000
 | 
			
		||||
        image_size = 32
 | 
			
		||||
        radius = 0.1
 | 
			
		||||
        points_per_pixel = 3
 | 
			
		||||
        points1 = torch.randn(P, 3, requires_grad=True)
 | 
			
		||||
        points2 = torch.randn(int(P / 2), 3, requires_grad=True)
 | 
			
		||||
        pointclouds = Pointclouds(points=[points1, points2])
 | 
			
		||||
        grad_zbuf = torch.randn(N, image_size, image_size, points_per_pixel)
 | 
			
		||||
        grad_dists = torch.randn(N, image_size, image_size, points_per_pixel)
 | 
			
		||||
 | 
			
		||||
        # Option I: CPU, naive
 | 
			
		||||
        idx1, zbuf1, dists1 = rasterize_points(
 | 
			
		||||
            pointclouds, image_size, radius, points_per_pixel, bin_size=0
 | 
			
		||||
        )
 | 
			
		||||
        loss = (zbuf1 * grad_zbuf).sum() + (dists1 * grad_dists).sum()
 | 
			
		||||
        loss.backward()
 | 
			
		||||
        grad1 = points1.grad.data.clone()
 | 
			
		||||
 | 
			
		||||
        # Option II: CUDA, naive
 | 
			
		||||
        points1_cuda = points1.cuda().detach().clone().requires_grad_(True)
 | 
			
		||||
        points2_cuda = points2.cuda().detach().clone().requires_grad_(True)
 | 
			
		||||
        pointclouds = Pointclouds(points=[points1_cuda, points2_cuda])
 | 
			
		||||
        grad_zbuf = grad_zbuf.cuda()
 | 
			
		||||
        grad_dists = grad_dists.cuda()
 | 
			
		||||
        idx2, zbuf2, dists2 = rasterize_points(
 | 
			
		||||
            pointclouds, image_size, radius, points_per_pixel, bin_size=0
 | 
			
		||||
        )
 | 
			
		||||
        loss = (zbuf2 * grad_zbuf).sum() + (dists2 * grad_dists).sum()
 | 
			
		||||
        loss.backward()
 | 
			
		||||
        idx2 = idx2.data.cpu().clone()
 | 
			
		||||
        zbuf2 = zbuf2.data.cpu().clone()
 | 
			
		||||
        dists2 = dists2.data.cpu().clone()
 | 
			
		||||
        grad2 = points1_cuda.grad.data.cpu().clone()
 | 
			
		||||
 | 
			
		||||
        # Option III: CUDA, binned
 | 
			
		||||
        points1_cuda = points1.cuda().detach().clone().requires_grad_(True)
 | 
			
		||||
        points2_cuda = points2.cuda().detach().clone().requires_grad_(True)
 | 
			
		||||
        pointclouds = Pointclouds(points=[points1_cuda, points2_cuda])
 | 
			
		||||
        idx3, zbuf3, dists3 = rasterize_points(
 | 
			
		||||
            pointclouds, image_size, radius, points_per_pixel, bin_size=32
 | 
			
		||||
        )
 | 
			
		||||
        loss = (zbuf3 * grad_zbuf).sum() + (dists3 * grad_dists).sum()
 | 
			
		||||
        points1.grad.data.zero_()
 | 
			
		||||
        loss.backward()
 | 
			
		||||
        idx3 = idx3.data.cpu().clone()
 | 
			
		||||
        zbuf3 = zbuf3.data.cpu().clone()
 | 
			
		||||
        dists3 = dists3.data.cpu().clone()
 | 
			
		||||
        grad3 = points1_cuda.grad.data.cpu().clone()
 | 
			
		||||
 | 
			
		||||
        # Make sure everything was the same
 | 
			
		||||
        idx12_same = (idx1 == idx2).all().item()
 | 
			
		||||
        idx13_same = (idx1 == idx3).all().item()
 | 
			
		||||
        zbuf12_same = (zbuf1 == zbuf2).all().item()
 | 
			
		||||
        zbuf13_same = (zbuf1 == zbuf3).all().item()
 | 
			
		||||
        dists12_diff = (dists1 - dists2).abs().max().item()
 | 
			
		||||
        dists13_diff = (dists1 - dists3).abs().max().item()
 | 
			
		||||
        self.assertTrue(idx12_same)
 | 
			
		||||
        self.assertTrue(idx13_same)
 | 
			
		||||
        self.assertTrue(zbuf12_same)
 | 
			
		||||
        self.assertTrue(zbuf13_same)
 | 
			
		||||
        self.assertTrue(dists12_diff < 1e-6)
 | 
			
		||||
        self.assertTrue(dists13_diff < 1e-6)
 | 
			
		||||
 | 
			
		||||
        diff12 = (grad1 - grad2).abs().max().item()
 | 
			
		||||
        diff13 = (grad1 - grad3).abs().max().item()
 | 
			
		||||
        diff23 = (grad2 - grad3).abs().max().item()
 | 
			
		||||
        self.assertTrue(diff12 < 5e-6)
 | 
			
		||||
        self.assertTrue(diff13 < 5e-6)
 | 
			
		||||
        self.assertTrue(diff23 < 5e-6)
 | 
			
		||||
 | 
			
		||||
    def test_python_vs_cpu_naive(self):
 | 
			
		||||
        torch.manual_seed(231)
 | 
			
		||||
        image_size = 32
 | 
			
		||||
        radius = 0.1
 | 
			
		||||
        points_per_pixel = 3
 | 
			
		||||
 | 
			
		||||
        # Test a batch of homogeneous point clouds.
 | 
			
		||||
        N = 2
 | 
			
		||||
        P = 17
 | 
			
		||||
        points = torch.randn(N, P, 3, requires_grad=True)
 | 
			
		||||
        pointclouds = Pointclouds(points=points)
 | 
			
		||||
        args = (pointclouds, image_size, radius, points_per_pixel)
 | 
			
		||||
        self._compare_impls(
 | 
			
		||||
            rasterize_points_python,
 | 
			
		||||
            rasterize_points,
 | 
			
		||||
            args,
 | 
			
		||||
            args,
 | 
			
		||||
            points,
 | 
			
		||||
            points,
 | 
			
		||||
            compare_grads=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Test a batch of heterogeneous point clouds.
 | 
			
		||||
        P2 = 10
 | 
			
		||||
        points1 = torch.randn(P, 3, requires_grad=True)
 | 
			
		||||
        points2 = torch.randn(P2, 3)
 | 
			
		||||
        pointclouds = Pointclouds(points=[points1, points2])
 | 
			
		||||
        args = (pointclouds, image_size, radius, points_per_pixel)
 | 
			
		||||
        self._compare_impls(
 | 
			
		||||
            rasterize_points_python,
 | 
			
		||||
            rasterize_points,
 | 
			
		||||
            args,
 | 
			
		||||
            args,
 | 
			
		||||
            points1,  # check gradients for first element in batch
 | 
			
		||||
            points1,
 | 
			
		||||
            compare_grads=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_cpu_vs_cuda_naive(self):
 | 
			
		||||
        torch.manual_seed(231)
 | 
			
		||||
        image_size = 64
 | 
			
		||||
        radius = 0.1
 | 
			
		||||
        points_per_pixel = 5
 | 
			
		||||
 | 
			
		||||
        # Test homogeneous point cloud batch.
 | 
			
		||||
        N = 2
 | 
			
		||||
        P = 1000
 | 
			
		||||
        bin_size = 0
 | 
			
		||||
        points_cpu = torch.rand(N, P, 3, requires_grad=True)
 | 
			
		||||
        points_cuda = points_cpu.cuda().detach().requires_grad_(True)
 | 
			
		||||
        pointclouds_cpu = Pointclouds(points=points_cpu)
 | 
			
		||||
        pointclouds_cuda = Pointclouds(points=points_cuda)
 | 
			
		||||
        args_cpu = (
 | 
			
		||||
            pointclouds_cpu,
 | 
			
		||||
            image_size,
 | 
			
		||||
            radius,
 | 
			
		||||
            points_per_pixel,
 | 
			
		||||
            bin_size,
 | 
			
		||||
        )
 | 
			
		||||
        args_cuda = (
 | 
			
		||||
            pointclouds_cuda,
 | 
			
		||||
            image_size,
 | 
			
		||||
            radius,
 | 
			
		||||
            points_per_pixel,
 | 
			
		||||
            bin_size,
 | 
			
		||||
        )
 | 
			
		||||
        self._compare_impls(
 | 
			
		||||
            rasterize_points,
 | 
			
		||||
            rasterize_points,
 | 
			
		||||
            args_cpu,
 | 
			
		||||
            args_cuda,
 | 
			
		||||
            points_cpu,
 | 
			
		||||
            points_cuda,
 | 
			
		||||
            compare_grads=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _compare_impls(
 | 
			
		||||
        self,
 | 
			
		||||
        fn1,
 | 
			
		||||
        fn2,
 | 
			
		||||
        args1,
 | 
			
		||||
        args2,
 | 
			
		||||
        grad_var1=None,
 | 
			
		||||
        grad_var2=None,
 | 
			
		||||
        compare_grads=False,
 | 
			
		||||
    ):
 | 
			
		||||
        idx1, zbuf1, dist1 = fn1(*args1)
 | 
			
		||||
        torch.manual_seed(231)
 | 
			
		||||
        grad_zbuf = torch.randn_like(zbuf1)
 | 
			
		||||
        grad_dist = torch.randn_like(dist1)
 | 
			
		||||
        loss = (zbuf1 * grad_zbuf).sum() + (dist1 * grad_dist).sum()
 | 
			
		||||
        if compare_grads:
 | 
			
		||||
            loss.backward()
 | 
			
		||||
            grad_points1 = grad_var1.grad.data.clone().cpu()
 | 
			
		||||
 | 
			
		||||
        idx2, zbuf2, dist2 = fn2(*args2)
 | 
			
		||||
        grad_zbuf = grad_zbuf.to(zbuf2)
 | 
			
		||||
        grad_dist = grad_dist.to(dist2)
 | 
			
		||||
        loss = (zbuf2 * grad_zbuf).sum() + (dist2 * grad_dist).sum()
 | 
			
		||||
        if compare_grads:
 | 
			
		||||
            # clear points1.grad in case args1 and args2 reused the same tensor
 | 
			
		||||
            grad_var1.grad.data.zero_()
 | 
			
		||||
            loss.backward()
 | 
			
		||||
            grad_points2 = grad_var2.grad.data.clone().cpu()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual((idx1.cpu() == idx2.cpu()).all().item(), 1)
 | 
			
		||||
        self.assertEqual((zbuf1.cpu() == zbuf2.cpu()).all().item(), 1)
 | 
			
		||||
        self.assertClose(dist1.cpu(), dist2.cpu())
 | 
			
		||||
        if compare_grads:
 | 
			
		||||
            self.assertTrue(
 | 
			
		||||
                torch.allclose(grad_points1, grad_points2, atol=2e-6)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def _test_behind_camera(self, rasterize_points_fn, device, bin_size=None):
 | 
			
		||||
        # Test case where all points are behind the camera -- nothing should
 | 
			
		||||
        # get rasterized
 | 
			
		||||
        N = 2
 | 
			
		||||
        P = 32
 | 
			
		||||
        xy = torch.randn(N, P, 2)
 | 
			
		||||
        z = torch.randn(N, P, 1).abs().mul(-1)  # Make them all negative
 | 
			
		||||
        points = torch.cat([xy, z], dim=2).to(device)
 | 
			
		||||
        image_size = 16
 | 
			
		||||
        points_per_pixel = 3
 | 
			
		||||
        radius = 0.2
 | 
			
		||||
        idx_expected = torch.full(
 | 
			
		||||
            (N, 16, 16, 3), fill_value=-1, dtype=torch.int32, device=device
 | 
			
		||||
        )
 | 
			
		||||
        zbuf_expected = torch.full(
 | 
			
		||||
            (N, 16, 16, 3), fill_value=-1, dtype=torch.float32, device=device
 | 
			
		||||
        )
 | 
			
		||||
        dists_expected = zbuf_expected.clone()
 | 
			
		||||
        pointclouds = Pointclouds(points=points)
 | 
			
		||||
        if bin_size == -1:
 | 
			
		||||
            # simple python case with no binning
 | 
			
		||||
            idx, zbuf, dists = rasterize_points_fn(
 | 
			
		||||
                pointclouds, image_size, radius, points_per_pixel
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            idx, zbuf, dists = rasterize_points_fn(
 | 
			
		||||
                pointclouds, image_size, radius, points_per_pixel, bin_size
 | 
			
		||||
            )
 | 
			
		||||
        idx_same = (idx == idx_expected).all().item() == 1
 | 
			
		||||
        zbuf_same = (zbuf == zbuf_expected).all().item() == 1
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(idx_same)
 | 
			
		||||
        self.assertTrue(zbuf_same)
 | 
			
		||||
        self.assertTrue(torch.allclose(dists, dists_expected))
 | 
			
		||||
 | 
			
		||||
    def _simple_test_case(self, rasterize_points_fn, device, bin_size=0):
 | 
			
		||||
        # Create two pointclouds with different numbers of points.
 | 
			
		||||
        # fmt: off
 | 
			
		||||
        points1 = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.0, 0.0,  0.0],  # noqa: E241
 | 
			
		||||
                [0.4, 0.0,  0.1],  # noqa: E241
 | 
			
		||||
                [0.0, 0.4,  0.2],  # noqa: E241
 | 
			
		||||
                [0.0, 0.0, -0.1],  # noqa: E241 Points with negative z should be skippped
 | 
			
		||||
            ],
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
        points2 = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.0, 0.0,  0.0],  # noqa: E241
 | 
			
		||||
                [0.4, 0.0,  0.1],  # noqa: E241
 | 
			
		||||
                [0.0, 0.4,  0.2],  # noqa: E241
 | 
			
		||||
                [0.0, 0.0, -0.1],  # noqa: E241 Points with negative z should be skippped
 | 
			
		||||
                [0.0, 0.0, -0.7],  # noqa: E241 Points with negative z should be skippped
 | 
			
		||||
            ],
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
        # fmt: on
 | 
			
		||||
        pointclouds = Pointclouds(points=[points1, points2])
 | 
			
		||||
 | 
			
		||||
        image_size = 5
 | 
			
		||||
        points_per_pixel = 2
 | 
			
		||||
        radius = 0.5
 | 
			
		||||
 | 
			
		||||
        # The expected output values. Note that in the outputs, the world space
 | 
			
		||||
        # +Y is up, and the world space +X is left.
 | 
			
		||||
        idx1_expected = torch.full(
 | 
			
		||||
            (1, 5, 5, 2), fill_value=-1, dtype=torch.int32, device=device
 | 
			
		||||
        )
 | 
			
		||||
        # fmt: off
 | 
			
		||||
        idx1_expected[0, :, :, 0] = torch.tensor([
 | 
			
		||||
            [-1, -1,  2, -1, -1],  # noqa: E241
 | 
			
		||||
            [-1,  1,  0,  2, -1],  # noqa: E241
 | 
			
		||||
            [ 1,  0,  0,  0, -1],  # noqa: E241 E201
 | 
			
		||||
            [-1,  1,  0, -1, -1],  # noqa: E241
 | 
			
		||||
            [-1, -1, -1, -1, -1],  # noqa: E241
 | 
			
		||||
        ], device=device)
 | 
			
		||||
        idx1_expected[0, :, :, 1] = torch.tensor([
 | 
			
		||||
            [-1, -1, -1, -1, -1],  # noqa: E241
 | 
			
		||||
            [-1,  2,  2, -1, -1],  # noqa: E241
 | 
			
		||||
            [-1,  1,  1, -1, -1],  # noqa: E241
 | 
			
		||||
            [-1, -1, -1, -1, -1],  # noqa: E241
 | 
			
		||||
            [-1, -1, -1, -1, -1],  # noqa: E241
 | 
			
		||||
        ], device=device)
 | 
			
		||||
        # fmt: on
 | 
			
		||||
 | 
			
		||||
        zbuf1_expected = torch.full(
 | 
			
		||||
            (1, 5, 5, 2), fill_value=100, dtype=torch.float32, device=device
 | 
			
		||||
        )
 | 
			
		||||
        # fmt: off
 | 
			
		||||
        zbuf1_expected[0, :, :, 0] = torch.tensor([
 | 
			
		||||
            [-1.0, -1.0,  0.2, -1.0, -1.0],  # noqa: E241
 | 
			
		||||
            [-1.0,  0.1,  0.0,  0.2, -1.0],  # noqa: E241
 | 
			
		||||
            [ 0.1,  0.0,  0.0,  0.0, -1.0],  # noqa: E241 E201
 | 
			
		||||
            [-1.0,  0.1,  0.0, -1.0, -1.0],  # noqa: E241
 | 
			
		||||
            [-1.0, -1.0, -1.0, -1.0, -1.0]   # noqa: E241
 | 
			
		||||
        ], device=device)
 | 
			
		||||
        zbuf1_expected[0, :, :, 1] = torch.tensor([
 | 
			
		||||
            [-1.0, -1.0, -1.0, -1.0, -1.0],  # noqa: E241
 | 
			
		||||
            [-1.0,  0.2,  0.2, -1.0, -1.0],  # noqa: E241
 | 
			
		||||
            [-1.0,  0.1,  0.1, -1.0, -1.0],  # noqa: E241
 | 
			
		||||
            [-1.0, -1.0, -1.0, -1.0, -1.0],  # noqa: E241
 | 
			
		||||
            [-1.0, -1.0, -1.0, -1.0, -1.0],  # noqa: E241
 | 
			
		||||
        ], device=device)
 | 
			
		||||
        # fmt: on
 | 
			
		||||
 | 
			
		||||
        dists1_expected = torch.full(
 | 
			
		||||
            (1, 5, 5, 2), fill_value=0.0, dtype=torch.float32, device=device
 | 
			
		||||
        )
 | 
			
		||||
        # fmt: off
 | 
			
		||||
        dists1_expected[0, :, :, 0] = torch.tensor([
 | 
			
		||||
            [-1.00, -1.00,  0.16, -1.00, -1.00],  # noqa: E241
 | 
			
		||||
            [-1.00,  0.16,  0.16,  0.16, -1.00],  # noqa: E241
 | 
			
		||||
            [ 0.16,  0.16,  0.00,  0.16, -1.00],  # noqa: E241 E201
 | 
			
		||||
            [-1.00,  0.16,  0.16, -1.00, -1.00],  # noqa: E241
 | 
			
		||||
            [-1.00, -1.00, -1.00, -1.00, -1.00],  # noqa: E241
 | 
			
		||||
        ], device=device)
 | 
			
		||||
        dists1_expected[0, :, :, 1] = torch.tensor([
 | 
			
		||||
            [-1.00, -1.00, -1.00, -1.00, -1.00],  # noqa: E241
 | 
			
		||||
            [-1.00,  0.16,  0.00, -1.00, -1.00],  # noqa: E241
 | 
			
		||||
            [-1.00,  0.00,  0.16, -1.00, -1.00],  # noqa: E241
 | 
			
		||||
            [-1.00, -1.00, -1.00, -1.00, -1.00],  # noqa: E241
 | 
			
		||||
            [-1.00, -1.00, -1.00, -1.00, -1.00],  # noqa: E241
 | 
			
		||||
        ], device=device)
 | 
			
		||||
        # fmt: on
 | 
			
		||||
 | 
			
		||||
        if bin_size == -1:
 | 
			
		||||
            # simple python case with no binning
 | 
			
		||||
            idx, zbuf, dists = rasterize_points_fn(
 | 
			
		||||
                pointclouds, image_size, radius, points_per_pixel
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            idx, zbuf, dists = rasterize_points_fn(
 | 
			
		||||
                pointclouds, image_size, radius, points_per_pixel, bin_size
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # check first point cloud
 | 
			
		||||
        idx_same = (idx[0, ...] == idx1_expected).all().item() == 1
 | 
			
		||||
        if idx_same == 0:
 | 
			
		||||
            print(idx[0, :, :, 0])
 | 
			
		||||
            print(idx[0, :, :, 1])
 | 
			
		||||
        zbuf_same = (zbuf[0, ...] == zbuf1_expected).all().item() == 1
 | 
			
		||||
        dist_same = torch.allclose(dists[0, ...], dists1_expected)
 | 
			
		||||
        self.assertTrue(idx_same)
 | 
			
		||||
        self.assertTrue(zbuf_same)
 | 
			
		||||
        self.assertTrue(dist_same)
 | 
			
		||||
 | 
			
		||||
        # Check second point cloud - the indices in idx refer to points in the
 | 
			
		||||
        # pointclouds.points_packed() tensor. In the second point cloud,
 | 
			
		||||
        # two points are behind the screen - the expected indices are the same
 | 
			
		||||
        # the first pointcloud but offset by the number of points in the
 | 
			
		||||
        # first pointcloud.
 | 
			
		||||
        num_points_per_cloud = pointclouds.num_points_per_cloud()
 | 
			
		||||
        idx1_expected[idx1_expected >= 0] += num_points_per_cloud[0]
 | 
			
		||||
 | 
			
		||||
        idx_same = (idx[1, ...] == idx1_expected).all().item() == 1
 | 
			
		||||
        zbuf_same = (zbuf[1, ...] == zbuf1_expected).all().item() == 1
 | 
			
		||||
        self.assertTrue(idx_same)
 | 
			
		||||
        self.assertTrue(zbuf_same)
 | 
			
		||||
        self.assertTrue(torch.allclose(dists[1, ...], dists1_expected))
 | 
			
		||||
 | 
			
		||||
    def test_coarse_cpu(self):
 | 
			
		||||
        return self._test_coarse_rasterize(torch.device("cpu"))
 | 
			
		||||
 | 
			
		||||
    def test_coarse_cuda(self):
 | 
			
		||||
        return self._test_coarse_rasterize(torch.device("cuda"))
 | 
			
		||||
 | 
			
		||||
    def test_compare_coarse_cpu_vs_cuda(self):
 | 
			
		||||
        torch.manual_seed(231)
 | 
			
		||||
        N = 3
 | 
			
		||||
        max_P = 1000
 | 
			
		||||
        image_size = 64
 | 
			
		||||
        radius = 0.1
 | 
			
		||||
        bin_size = 16
 | 
			
		||||
        max_points_per_bin = 500
 | 
			
		||||
 | 
			
		||||
        # create heterogeneous point clouds
 | 
			
		||||
        points = []
 | 
			
		||||
        for _ in range(N):
 | 
			
		||||
            p = np.random.choice(max_P)
 | 
			
		||||
            points.append(torch.randn(p, 3))
 | 
			
		||||
 | 
			
		||||
        pointclouds = Pointclouds(points=points)
 | 
			
		||||
        points_packed = pointclouds.points_packed()
 | 
			
		||||
        cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
 | 
			
		||||
        num_points_per_cloud = pointclouds.num_points_per_cloud()
 | 
			
		||||
        args = (
 | 
			
		||||
            points_packed,
 | 
			
		||||
            cloud_to_packed_first_idx,
 | 
			
		||||
            num_points_per_cloud,
 | 
			
		||||
            image_size,
 | 
			
		||||
            radius,
 | 
			
		||||
            bin_size,
 | 
			
		||||
            max_points_per_bin,
 | 
			
		||||
        )
 | 
			
		||||
        bp_cpu = _C._rasterize_points_coarse(*args)
 | 
			
		||||
 | 
			
		||||
        pointclouds_cuda = pointclouds.to("cuda:0")
 | 
			
		||||
        points_packed = pointclouds_cuda.points_packed()
 | 
			
		||||
        cloud_to_packed_first_idx = pointclouds_cuda.cloud_to_packed_first_idx()
 | 
			
		||||
        num_points_per_cloud = pointclouds_cuda.num_points_per_cloud()
 | 
			
		||||
        args = (
 | 
			
		||||
            points_packed,
 | 
			
		||||
            cloud_to_packed_first_idx,
 | 
			
		||||
            num_points_per_cloud,
 | 
			
		||||
            image_size,
 | 
			
		||||
            radius,
 | 
			
		||||
            bin_size,
 | 
			
		||||
            max_points_per_bin,
 | 
			
		||||
        )
 | 
			
		||||
        bp_cuda = _C._rasterize_points_coarse(*args)
 | 
			
		||||
 | 
			
		||||
        # Bin points might not be the same: CUDA version might write them in
 | 
			
		||||
        # any order. But if we sort the non-(-1) elements of the CUDA output
 | 
			
		||||
        # then they should be the same.
 | 
			
		||||
        for n in range(N):
 | 
			
		||||
            for by in range(bp_cpu.shape[1]):
 | 
			
		||||
                for bx in range(bp_cpu.shape[2]):
 | 
			
		||||
                    K = (bp_cpu[n, by, bx] != -1).sum().item()
 | 
			
		||||
                    idxs_cpu = bp_cpu[n, by, bx].tolist()
 | 
			
		||||
                    idxs_cuda = bp_cuda[n, by, bx].tolist()
 | 
			
		||||
                    idxs_cuda[:K] = sorted(idxs_cuda[:K])
 | 
			
		||||
                    self.assertEqual(idxs_cpu, idxs_cuda)
 | 
			
		||||
 | 
			
		||||
    def _test_coarse_rasterize(self, device):
 | 
			
		||||
        #
 | 
			
		||||
        #  Note that +Y is up and +X is left in the diagram below.
 | 
			
		||||
        #
 | 
			
		||||
        #  (4)              |2
 | 
			
		||||
        #                   |
 | 
			
		||||
        #                   |
 | 
			
		||||
        #                   |
 | 
			
		||||
        #                   |1
 | 
			
		||||
        #                   |
 | 
			
		||||
        #             (1)   |
 | 
			
		||||
        #                   | (2)
 | 
			
		||||
        # ____________(0)__(5)___________________
 | 
			
		||||
        # 2        1        |          -1      -2
 | 
			
		||||
        #                   |
 | 
			
		||||
        #       (3)         |
 | 
			
		||||
        #                   |
 | 
			
		||||
        #                   |-1
 | 
			
		||||
        #                   |
 | 
			
		||||
        #
 | 
			
		||||
        # Locations of the points are shown by o. The screen bounding box
 | 
			
		||||
        # is between [-1, 1] in both the x and y directions.
 | 
			
		||||
        #
 | 
			
		||||
        # These points are interesting because:
 | 
			
		||||
        # (0) Falls into two bins;
 | 
			
		||||
        # (1) and (2) fall into one bin;
 | 
			
		||||
        # (3) is out-of-bounds, but its disk is in-bounds;
 | 
			
		||||
        # (4) is out-of-bounds, and its entire disk is also out-of-bounds
 | 
			
		||||
        # (5) has a negative z-value, so it should be skipped
 | 
			
		||||
        # fmt: off
 | 
			
		||||
        points = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [ 0.5,  0.0,  0.0],  # noqa: E241, E201
 | 
			
		||||
                [ 0.5,  0.5,  0.1],  # noqa: E241, E201
 | 
			
		||||
                [-0.3,  0.4,  0.0],  # noqa: E241
 | 
			
		||||
                [ 1.1, -0.5,  0.2],  # noqa: E241, E201
 | 
			
		||||
                [ 2.0,  2.0,  0.3],  # noqa: E241, E201
 | 
			
		||||
                [ 0.0,  0.0, -0.1],  # noqa: E241, E201
 | 
			
		||||
            ],
 | 
			
		||||
            device=device
 | 
			
		||||
        )
 | 
			
		||||
        # fmt: on
 | 
			
		||||
        image_size = 16
 | 
			
		||||
        radius = 0.2
 | 
			
		||||
        bin_size = 8
 | 
			
		||||
        max_points_per_bin = 5
 | 
			
		||||
 | 
			
		||||
        bin_points_expected = -1 * torch.ones(
 | 
			
		||||
            1, 2, 2, 5, dtype=torch.int32, device=device
 | 
			
		||||
        )
 | 
			
		||||
        # Note that the order is only deterministic here for CUDA if all points
 | 
			
		||||
        # fit in one chunk. This will the the case for this small example, but
 | 
			
		||||
        # to properly exercise coordianted writes among multiple chunks we need
 | 
			
		||||
        # to use a bigger test case.
 | 
			
		||||
        bin_points_expected[0, 1, 0, :2] = torch.tensor([0, 3])
 | 
			
		||||
        bin_points_expected[0, 0, 1, 0] = torch.tensor([2])
 | 
			
		||||
        bin_points_expected[0, 0, 0, :2] = torch.tensor([0, 1])
 | 
			
		||||
 | 
			
		||||
        pointclouds = Pointclouds(points=[points])
 | 
			
		||||
        args = (
 | 
			
		||||
            pointclouds.points_packed(),
 | 
			
		||||
            pointclouds.cloud_to_packed_first_idx(),
 | 
			
		||||
            pointclouds.num_points_per_cloud(),
 | 
			
		||||
            image_size,
 | 
			
		||||
            radius,
 | 
			
		||||
            bin_size,
 | 
			
		||||
            max_points_per_bin,
 | 
			
		||||
        )
 | 
			
		||||
        bin_points = _C._rasterize_points_coarse(*args)
 | 
			
		||||
        bin_points_same = (bin_points == bin_points_expected).all()
 | 
			
		||||
        self.assertTrue(bin_points_same.item() == 1)
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user