mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 07:10:34 +08:00
Support variable size radius for points in rasterizer
Summary: Support variable size pointclouds in the renderer API to allow compatibility with Pulsar rasterizer. If radius is provided as a float, it is converted to a tensor of shape (P). Otherwise radius is expected to be an (N, P_padded) dimensional tensor where P_padded is the max number of points in the batch (following the convention from pulsar: https://our.intern.facebook.com/intern/diffusion/FBS/browse/master/fbcode/frl/gemini/pulsar/pulsar/renderer.py?commit=ee0342850210e5df441e14fd97162675c70d147c&lines=50) Reviewed By: jcjohnson, gkioxari Differential Revision: D21429400 fbshipit-source-id: 65de7d9cd2472b27fc29f96160c33687e88098a2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e40c2167ae
commit
ebe2693b11
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -18,7 +18,7 @@ kMaxPointsPerBin = 22
|
||||
def rasterize_points(
|
||||
pointclouds,
|
||||
image_size: int = 256,
|
||||
radius: float = 0.01,
|
||||
radius: Union[float, List, Tuple, torch.Tensor] = 0.01,
|
||||
points_per_pixel: int = 8,
|
||||
bin_size: Optional[int] = None,
|
||||
max_points_per_bin: Optional[int] = None,
|
||||
@@ -35,8 +35,10 @@ def rasterize_points(
|
||||
(0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left,
|
||||
the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front.
|
||||
image_size: Integer giving the resolution of the rasterized image
|
||||
radius (Optional): Float giving the radius (in NDC units) of the disk to
|
||||
be rasterized for each point.
|
||||
radius (Optional): The radius (in NDC units) of the disk to
|
||||
be rasterized. This can either be a float in which case the same radius is used
|
||||
for each point, or a torch.Tensor of shape (N, P) giving a radius per point
|
||||
in the batch.
|
||||
points_per_pixel (Optional): We will keep track of this many points per
|
||||
pixel, returning the nearest points_per_pixel points along the z-axis
|
||||
bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
|
||||
@@ -74,6 +76,8 @@ def rasterize_points(
|
||||
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
|
||||
num_points_per_cloud = pointclouds.num_points_per_cloud()
|
||||
|
||||
radius = _format_radius(radius, pointclouds)
|
||||
|
||||
if bin_size is None:
|
||||
if not points_packed.is_cuda:
|
||||
# Binned CPU rasterization not fully implemented
|
||||
@@ -117,6 +121,48 @@ def rasterize_points(
|
||||
)
|
||||
|
||||
|
||||
def _format_radius(
|
||||
radius: Union[float, List, Tuple, torch.Tensor], pointclouds
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Format the radius as a torch tensor of shape (P_packed,)
|
||||
where P_packed is the total number of points in the
|
||||
batch (i.e. pointclouds.points_packed().shape[0]).
|
||||
|
||||
This will enable support for a different size radius
|
||||
for each point in the batch.
|
||||
|
||||
Args:
|
||||
radius: can be a float, List, Tuple or tensor of
|
||||
shape (N, P_padded) where P_padded is the
|
||||
maximum number of points for each pointcloud
|
||||
in the batch.
|
||||
|
||||
Returns:
|
||||
radius: torch.Tensor of shape (P_packed)
|
||||
"""
|
||||
N, P_padded = pointclouds._N, pointclouds._P
|
||||
points_packed = pointclouds.points_packed()
|
||||
P_packed = points_packed.shape[0]
|
||||
if isinstance(radius, (list, tuple)):
|
||||
radius = torch.tensor(radius).type_as(points_packed)
|
||||
if isinstance(radius, torch.Tensor):
|
||||
if N == 1 and radius.ndim == 1:
|
||||
radius = radius[None, ...]
|
||||
if radius.shape != (N, P_padded):
|
||||
msg = "radius must be of shape (N, P): got %s"
|
||||
raise ValueError(msg % (repr(radius.shape)))
|
||||
else:
|
||||
padded_to_packed_idx = pointclouds.padded_to_packed_idx()
|
||||
radius = radius.view(-1)[padded_to_packed_idx]
|
||||
elif isinstance(radius, float):
|
||||
radius = torch.full((P_packed,), fill_value=radius).type_as(points_packed)
|
||||
else:
|
||||
msg = "radius must be a float, list, tuple or tensor; got %s"
|
||||
raise ValueError(msg % type(radius))
|
||||
return radius
|
||||
|
||||
|
||||
class _RasterizePoints(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
@@ -125,7 +171,7 @@ class _RasterizePoints(torch.autograd.Function):
|
||||
cloud_to_packed_first_idx,
|
||||
num_points_per_cloud,
|
||||
image_size: int = 256,
|
||||
radius: float = 0.01,
|
||||
radius: Union[float, torch.Tensor] = 0.01,
|
||||
points_per_pixel: int = 8,
|
||||
bin_size: int = 0,
|
||||
max_points_per_bin: int = 0,
|
||||
@@ -175,7 +221,10 @@ class _RasterizePoints(torch.autograd.Function):
|
||||
|
||||
|
||||
def rasterize_points_python(
|
||||
pointclouds, image_size: int = 256, radius: float = 0.01, points_per_pixel: int = 8
|
||||
pointclouds,
|
||||
image_size: int = 256,
|
||||
radius: Union[float, torch.Tensor] = 0.01,
|
||||
points_per_pixel: int = 8,
|
||||
):
|
||||
"""
|
||||
Naive pure PyTorch implementation of pointcloud rasterization.
|
||||
@@ -190,6 +239,9 @@ def rasterize_points_python(
|
||||
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
|
||||
num_points_per_cloud = pointclouds.num_points_per_cloud()
|
||||
|
||||
# Support variable size radius for each point in the batch
|
||||
radius = _format_radius(radius, pointclouds)
|
||||
|
||||
# Intialize output tensors.
|
||||
point_idxs = torch.full(
|
||||
(N, S, S, K), fill_value=-1, dtype=torch.int32, device=device
|
||||
@@ -225,12 +277,13 @@ def rasterize_points_python(
|
||||
# Check whether each point in the batch affects this pixel.
|
||||
for p in range(point_start_idx, point_stop_idx):
|
||||
px, py, pz = points_packed[p, :]
|
||||
r = radius2[p]
|
||||
if pz < 0:
|
||||
continue
|
||||
dx = px - xf
|
||||
dy = py - yf
|
||||
dist2 = dx * dx + dy * dy
|
||||
if dist2 < radius2:
|
||||
if dist2 < r:
|
||||
top_k_points.append((pz, p, dist2))
|
||||
top_k_points.sort()
|
||||
if len(top_k_points) > K:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from typing import NamedTuple, Optional
|
||||
from typing import NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -30,7 +30,7 @@ class PointsRasterizationSettings:
|
||||
def __init__(
|
||||
self,
|
||||
image_size: int = 256,
|
||||
radius: float = 0.01,
|
||||
radius: Union[float, torch.Tensor] = 0.01,
|
||||
points_per_pixel: int = 8,
|
||||
bin_size: Optional[int] = None,
|
||||
max_points_per_bin: Optional[int] = None,
|
||||
|
||||
Reference in New Issue
Block a user