Support variable size radius for points in rasterizer

Summary:
Support variable size pointclouds in the renderer API to allow compatibility with Pulsar rasterizer.

If radius is provided as a float, it is converted to a tensor of shape (P). Otherwise radius is expected to be an (N, P_padded) dimensional tensor where P_padded is the max number of points in the batch (following the convention from pulsar: https://our.intern.facebook.com/intern/diffusion/FBS/browse/master/fbcode/frl/gemini/pulsar/pulsar/renderer.py?commit=ee0342850210e5df441e14fd97162675c70d147c&lines=50)

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21429400

fbshipit-source-id: 65de7d9cd2472b27fc29f96160c33687e88098a2
This commit is contained in:
Nikhila Ravi
2020-09-18 18:46:45 -07:00
committed by Facebook GitHub Bot
parent e40c2167ae
commit ebe2693b11
8 changed files with 291 additions and 73 deletions

View File

@@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Optional
from typing import List, Optional, Tuple, Union
import torch
@@ -18,7 +18,7 @@ kMaxPointsPerBin = 22
def rasterize_points(
pointclouds,
image_size: int = 256,
radius: float = 0.01,
radius: Union[float, List, Tuple, torch.Tensor] = 0.01,
points_per_pixel: int = 8,
bin_size: Optional[int] = None,
max_points_per_bin: Optional[int] = None,
@@ -35,8 +35,10 @@ def rasterize_points(
(0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left,
the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front.
image_size: Integer giving the resolution of the rasterized image
radius (Optional): Float giving the radius (in NDC units) of the disk to
be rasterized for each point.
radius (Optional): The radius (in NDC units) of the disk to
be rasterized. This can either be a float in which case the same radius is used
for each point, or a torch.Tensor of shape (N, P) giving a radius per point
in the batch.
points_per_pixel (Optional): We will keep track of this many points per
pixel, returning the nearest points_per_pixel points along the z-axis
bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
@@ -74,6 +76,8 @@ def rasterize_points(
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds.num_points_per_cloud()
radius = _format_radius(radius, pointclouds)
if bin_size is None:
if not points_packed.is_cuda:
# Binned CPU rasterization not fully implemented
@@ -117,6 +121,48 @@ def rasterize_points(
)
def _format_radius(
radius: Union[float, List, Tuple, torch.Tensor], pointclouds
) -> torch.Tensor:
"""
Format the radius as a torch tensor of shape (P_packed,)
where P_packed is the total number of points in the
batch (i.e. pointclouds.points_packed().shape[0]).
This will enable support for a different size radius
for each point in the batch.
Args:
radius: can be a float, List, Tuple or tensor of
shape (N, P_padded) where P_padded is the
maximum number of points for each pointcloud
in the batch.
Returns:
radius: torch.Tensor of shape (P_packed)
"""
N, P_padded = pointclouds._N, pointclouds._P
points_packed = pointclouds.points_packed()
P_packed = points_packed.shape[0]
if isinstance(radius, (list, tuple)):
radius = torch.tensor(radius).type_as(points_packed)
if isinstance(radius, torch.Tensor):
if N == 1 and radius.ndim == 1:
radius = radius[None, ...]
if radius.shape != (N, P_padded):
msg = "radius must be of shape (N, P): got %s"
raise ValueError(msg % (repr(radius.shape)))
else:
padded_to_packed_idx = pointclouds.padded_to_packed_idx()
radius = radius.view(-1)[padded_to_packed_idx]
elif isinstance(radius, float):
radius = torch.full((P_packed,), fill_value=radius).type_as(points_packed)
else:
msg = "radius must be a float, list, tuple or tensor; got %s"
raise ValueError(msg % type(radius))
return radius
class _RasterizePoints(torch.autograd.Function):
@staticmethod
def forward(
@@ -125,7 +171,7 @@ class _RasterizePoints(torch.autograd.Function):
cloud_to_packed_first_idx,
num_points_per_cloud,
image_size: int = 256,
radius: float = 0.01,
radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8,
bin_size: int = 0,
max_points_per_bin: int = 0,
@@ -175,7 +221,10 @@ class _RasterizePoints(torch.autograd.Function):
def rasterize_points_python(
pointclouds, image_size: int = 256, radius: float = 0.01, points_per_pixel: int = 8
pointclouds,
image_size: int = 256,
radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8,
):
"""
Naive pure PyTorch implementation of pointcloud rasterization.
@@ -190,6 +239,9 @@ def rasterize_points_python(
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds.num_points_per_cloud()
# Support variable size radius for each point in the batch
radius = _format_radius(radius, pointclouds)
# Intialize output tensors.
point_idxs = torch.full(
(N, S, S, K), fill_value=-1, dtype=torch.int32, device=device
@@ -225,12 +277,13 @@ def rasterize_points_python(
# Check whether each point in the batch affects this pixel.
for p in range(point_start_idx, point_stop_idx):
px, py, pz = points_packed[p, :]
r = radius2[p]
if pz < 0:
continue
dx = px - xf
dy = py - yf
dist2 = dx * dx + dy * dy
if dist2 < radius2:
if dist2 < r:
top_k_points.append((pz, p, dist2))
top_k_points.sort()
if len(top_k_points) > K:

View File

@@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import NamedTuple, Optional
from typing import NamedTuple, Optional, Union
import torch
import torch.nn as nn
@@ -30,7 +30,7 @@ class PointsRasterizationSettings:
def __init__(
self,
image_size: int = 256,
radius: float = 0.01,
radius: Union[float, torch.Tensor] = 0.01,
points_per_pixel: int = 8,
bin_size: Optional[int] = None,
max_points_per_bin: Optional[int] = None,