mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
Support variable size radius for points in rasterizer
Summary: Support variable size pointclouds in the renderer API to allow compatibility with Pulsar rasterizer. If radius is provided as a float, it is converted to a tensor of shape (P). Otherwise radius is expected to be an (N, P_padded) dimensional tensor where P_padded is the max number of points in the batch (following the convention from pulsar: https://our.intern.facebook.com/intern/diffusion/FBS/browse/master/fbcode/frl/gemini/pulsar/pulsar/renderer.py?commit=ee0342850210e5df441e14fd97162675c70d147c&lines=50) Reviewed By: jcjohnson, gkioxari Differential Revision: D21429400 fbshipit-source-id: 65de7d9cd2472b27fc29f96160c33687e88098a2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e40c2167ae
commit
ebe2693b11
@@ -38,13 +38,15 @@ __device__ void CheckPixelInsidePoint(
|
||||
float& q_max_z,
|
||||
int& q_max_idx,
|
||||
PointQ& q,
|
||||
const float radius2,
|
||||
const float* radius,
|
||||
const float xf,
|
||||
const float yf,
|
||||
const int K) {
|
||||
const float px = points[p_idx * 3 + 0];
|
||||
const float py = points[p_idx * 3 + 1];
|
||||
const float pz = points[p_idx * 3 + 2];
|
||||
const float p_radius = radius[p_idx];
|
||||
const float radius2 = p_radius * p_radius;
|
||||
if (pz < 0)
|
||||
return; // Don't render points behind the camera
|
||||
const float dx = xf - px;
|
||||
@@ -81,7 +83,7 @@ __global__ void RasterizePointsNaiveCudaKernel(
|
||||
const float* points, // (P, 3)
|
||||
const int64_t* cloud_to_packed_first_idx, // (N)
|
||||
const int64_t* num_points_per_cloud, // (N)
|
||||
const float radius,
|
||||
const float* radius,
|
||||
const int N,
|
||||
const int S,
|
||||
const int K,
|
||||
@@ -91,7 +93,6 @@ __global__ void RasterizePointsNaiveCudaKernel(
|
||||
// Simple version: One thread per output pixel
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
const float radius2 = radius * radius;
|
||||
for (int i = tid; i < N * S * S; i += num_threads) {
|
||||
// Convert linear index to 3D index
|
||||
const int n = i / (S * S); // Batch index
|
||||
@@ -128,7 +129,7 @@ __global__ void RasterizePointsNaiveCudaKernel(
|
||||
|
||||
for (int p_idx = point_start_idx; p_idx < point_stop_idx; ++p_idx) {
|
||||
CheckPixelInsidePoint(
|
||||
points, p_idx, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K);
|
||||
points, p_idx, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K);
|
||||
}
|
||||
BubbleSort(q, q_size);
|
||||
int idx = n * S * S * K + pix_idx * K;
|
||||
@@ -145,7 +146,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
|
||||
const at::Tensor& cloud_to_packed_first_idx, // (N)
|
||||
const at::Tensor& num_points_per_cloud, // (N)
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const at::Tensor& radius,
|
||||
const int points_per_pixel) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg points_t{points, "points", 1},
|
||||
@@ -194,7 +195,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsNaiveCuda(
|
||||
points.contiguous().data_ptr<float>(),
|
||||
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
num_points_per_cloud.contiguous().data_ptr<int64_t>(),
|
||||
radius,
|
||||
radius.contiguous().data_ptr<float>(),
|
||||
N,
|
||||
S,
|
||||
K,
|
||||
@@ -214,7 +215,7 @@ __global__ void RasterizePointsCoarseCudaKernel(
|
||||
const float* points, // (P, 3)
|
||||
const int64_t* cloud_to_packed_first_idx, // (N)
|
||||
const int64_t* num_points_per_cloud, // (N)
|
||||
const float radius,
|
||||
const float* radius,
|
||||
const int N,
|
||||
const int P,
|
||||
const int S,
|
||||
@@ -266,12 +267,13 @@ __global__ void RasterizePointsCoarseCudaKernel(
|
||||
const float px = points[p_idx * 3 + 0];
|
||||
const float py = points[p_idx * 3 + 1];
|
||||
const float pz = points[p_idx * 3 + 2];
|
||||
const float p_radius = radius[p_idx];
|
||||
if (pz < 0)
|
||||
continue; // Don't render points behind the camera.
|
||||
const float px0 = px - radius;
|
||||
const float px1 = px + radius;
|
||||
const float py0 = py - radius;
|
||||
const float py1 = py + radius;
|
||||
const float px0 = px - p_radius;
|
||||
const float px1 = px + p_radius;
|
||||
const float py0 = py - p_radius;
|
||||
const float py1 = py + p_radius;
|
||||
|
||||
// Brute-force search over all bins; TODO something smarter?
|
||||
// For example we could compute the exact bin where the point falls,
|
||||
@@ -341,7 +343,7 @@ at::Tensor RasterizePointsCoarseCuda(
|
||||
const at::Tensor& cloud_to_packed_first_idx, // (N)
|
||||
const at::Tensor& num_points_per_cloud, // (N)
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const at::Tensor& radius,
|
||||
const int bin_size,
|
||||
const int max_points_per_bin) {
|
||||
TORCH_CHECK(
|
||||
@@ -390,7 +392,7 @@ at::Tensor RasterizePointsCoarseCuda(
|
||||
points.contiguous().data_ptr<float>(),
|
||||
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
num_points_per_cloud.contiguous().data_ptr<int64_t>(),
|
||||
radius,
|
||||
radius.contiguous().data_ptr<float>(),
|
||||
N,
|
||||
P,
|
||||
image_size,
|
||||
@@ -411,7 +413,7 @@ at::Tensor RasterizePointsCoarseCuda(
|
||||
__global__ void RasterizePointsFineCudaKernel(
|
||||
const float* points, // (P, 3)
|
||||
const int32_t* bin_points, // (N, B, B, T)
|
||||
const float radius,
|
||||
const float* radius,
|
||||
const int bin_size,
|
||||
const int N,
|
||||
const int B, // num_bins
|
||||
@@ -425,7 +427,6 @@ __global__ void RasterizePointsFineCudaKernel(
|
||||
const int num_pixels = N * B * B * bin_size * bin_size;
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const float radius2 = radius * radius;
|
||||
|
||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||
// Convert linear index into bin and pixel indices. We make the within
|
||||
@@ -464,7 +465,7 @@ __global__ void RasterizePointsFineCudaKernel(
|
||||
continue;
|
||||
}
|
||||
CheckPixelInsidePoint(
|
||||
points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, K);
|
||||
points, p, q_size, q_max_z, q_max_idx, q, radius, xf, yf, K);
|
||||
}
|
||||
// Now we've looked at all the points for this bin, so we can write
|
||||
// output for the current pixel.
|
||||
@@ -488,7 +489,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
|
||||
const at::Tensor& points, // (P, 3)
|
||||
const at::Tensor& bin_points,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const at::Tensor& radius,
|
||||
const int bin_size,
|
||||
const int points_per_pixel) {
|
||||
// Check inputs are on the same device
|
||||
@@ -525,7 +526,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RasterizePointsFineCuda(
|
||||
RasterizePointsFineCudaKernel<<<blocks, threads, 0, stream>>>(
|
||||
points.contiguous().data_ptr<float>(),
|
||||
bin_points.contiguous().data_ptr<int32_t>(),
|
||||
radius,
|
||||
radius.contiguous().data_ptr<float>(),
|
||||
bin_size,
|
||||
N,
|
||||
B,
|
||||
|
||||
@@ -15,7 +15,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const torch::Tensor& radius,
|
||||
const int points_per_pixel);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
@@ -25,7 +25,7 @@ RasterizePointsNaiveCuda(
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const torch::Tensor& radius,
|
||||
const int points_per_pixel);
|
||||
#endif
|
||||
// Naive (forward) pointcloud rasterization: For each pixel, for each point,
|
||||
@@ -41,7 +41,8 @@ RasterizePointsNaiveCuda(
|
||||
// in the batch where N is the batch size.
|
||||
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
|
||||
// for each pointcloud in the batch.
|
||||
// radius: Radius of each point (in NDC units)
|
||||
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
|
||||
// each point in points.
|
||||
// image_size: (S) Size of the image to return (in pixels)
|
||||
// points_per_pixel: (K) The number closest of points to return for each pixel
|
||||
//
|
||||
@@ -62,7 +63,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const torch::Tensor& radius,
|
||||
const int points_per_pixel) {
|
||||
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
|
||||
num_points_per_cloud.is_cuda()) {
|
||||
@@ -70,6 +71,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
|
||||
CHECK_CUDA(points);
|
||||
CHECK_CUDA(cloud_to_packed_first_idx);
|
||||
CHECK_CUDA(num_points_per_cloud);
|
||||
CHECK_CUDA(radius);
|
||||
return RasterizePointsNaiveCuda(
|
||||
points,
|
||||
cloud_to_packed_first_idx,
|
||||
@@ -100,7 +102,7 @@ torch::Tensor RasterizePointsCoarseCpu(
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const torch::Tensor& radius,
|
||||
const int bin_size,
|
||||
const int max_points_per_bin);
|
||||
|
||||
@@ -110,7 +112,7 @@ torch::Tensor RasterizePointsCoarseCuda(
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const torch::Tensor& radius,
|
||||
const int bin_size,
|
||||
const int max_points_per_bin);
|
||||
#endif
|
||||
@@ -124,7 +126,8 @@ torch::Tensor RasterizePointsCoarseCuda(
|
||||
// in the batch where N is the batch size.
|
||||
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
|
||||
// for each pointcloud in the batch.
|
||||
// radius: Radius of points to rasterize (in NDC units)
|
||||
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
|
||||
// each point in points.
|
||||
// image_size: Size of the image to generate (in pixels)
|
||||
// bin_size: Size of each bin within the image (in pixels)
|
||||
//
|
||||
@@ -138,7 +141,7 @@ torch::Tensor RasterizePointsCoarse(
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const torch::Tensor& radius,
|
||||
const int bin_size,
|
||||
const int max_points_per_bin) {
|
||||
if (points.is_cuda() && cloud_to_packed_first_idx.is_cuda() &&
|
||||
@@ -147,6 +150,7 @@ torch::Tensor RasterizePointsCoarse(
|
||||
CHECK_CUDA(points);
|
||||
CHECK_CUDA(cloud_to_packed_first_idx);
|
||||
CHECK_CUDA(num_points_per_cloud);
|
||||
CHECK_CUDA(radius);
|
||||
return RasterizePointsCoarseCuda(
|
||||
points,
|
||||
cloud_to_packed_first_idx,
|
||||
@@ -179,7 +183,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& bin_points,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const torch::Tensor& radius,
|
||||
const int bin_size,
|
||||
const int points_per_pixel);
|
||||
#endif
|
||||
@@ -191,7 +195,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
|
||||
// bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points
|
||||
// that fall into each bin (output from coarse rasterization)
|
||||
// image_size: Size of image to generate (in pixels)
|
||||
// radius: Radius of points to rasterize (NDC units)
|
||||
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
|
||||
// each point in points.
|
||||
// bin_size: Size of each bin (in pixels)
|
||||
// points_per_pixel: How many points to rasterize for each pixel
|
||||
//
|
||||
@@ -210,7 +215,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
|
||||
const torch::Tensor& points,
|
||||
const torch::Tensor& bin_points,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const torch::Tensor& radius,
|
||||
const int bin_size,
|
||||
const int points_per_pixel) {
|
||||
if (points.is_cuda()) {
|
||||
@@ -296,7 +301,8 @@ torch::Tensor RasterizePointsBackward(
|
||||
// in the batch where N is the batch size.
|
||||
// num_points_per_cloud: LongTensor of shape (N) giving the number of points
|
||||
// for each pointcloud in the batch.
|
||||
// radius: Radius of each point (in NDC units)
|
||||
// radius: FloatTensor of shape (P) giving the radius (in NDC units) of
|
||||
// each point in points.
|
||||
// image_size: (S) Size of the image to return (in pixels)
|
||||
// points_per_pixel: (K) The number of points to return for each pixel
|
||||
// bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting
|
||||
@@ -320,7 +326,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePoints(
|
||||
const torch::Tensor& cloud_to_packed_first_idx,
|
||||
const torch::Tensor& num_points_per_cloud,
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const torch::Tensor& radius,
|
||||
const int points_per_pixel,
|
||||
const int bin_size,
|
||||
const int max_points_per_bin) {
|
||||
|
||||
@@ -17,7 +17,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
|
||||
const torch::Tensor& cloud_to_packed_first_idx, // (N)
|
||||
const torch::Tensor& num_points_per_cloud, // (N)
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const torch::Tensor& radius,
|
||||
const int points_per_pixel) {
|
||||
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
|
||||
|
||||
@@ -35,8 +35,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
|
||||
auto point_idxs_a = point_idxs.accessor<int32_t, 4>();
|
||||
auto zbuf_a = zbuf.accessor<float, 4>();
|
||||
auto pix_dists_a = pix_dists.accessor<float, 4>();
|
||||
auto radius_a = radius.accessor<float, 1>();
|
||||
|
||||
const float radius2 = radius * radius;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
// Loop through each pointcloud in the batch.
|
||||
// Get the start index of the points in points_packed and the num points
|
||||
@@ -63,6 +63,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
|
||||
const float px = points_a[p][0];
|
||||
const float py = points_a[p][1];
|
||||
const float pz = points_a[p][2];
|
||||
const float p_radius = radius_a[p];
|
||||
const float radius2 = p_radius * p_radius;
|
||||
if (pz < 0) {
|
||||
continue;
|
||||
}
|
||||
@@ -98,7 +100,7 @@ torch::Tensor RasterizePointsCoarseCpu(
|
||||
const torch::Tensor& cloud_to_packed_first_idx, // (N)
|
||||
const torch::Tensor& num_points_per_cloud, // (N)
|
||||
const int image_size,
|
||||
const float radius,
|
||||
const torch::Tensor& radius,
|
||||
const int bin_size,
|
||||
const int max_points_per_bin) {
|
||||
const int32_t N = cloud_to_packed_first_idx.size(0); // batch_size.
|
||||
@@ -112,6 +114,7 @@ torch::Tensor RasterizePointsCoarseCpu(
|
||||
auto points_a = points.accessor<float, 2>();
|
||||
auto points_per_bin_a = points_per_bin.accessor<int32_t, 3>();
|
||||
auto bin_points_a = bin_points.accessor<int32_t, 4>();
|
||||
auto radius_a = radius.accessor<float, 1>();
|
||||
|
||||
const float pixel_width = 2.0f / image_size;
|
||||
const float bin_width = pixel_width * bin_size;
|
||||
@@ -140,13 +143,14 @@ torch::Tensor RasterizePointsCoarseCpu(
|
||||
float px = points_a[p][0];
|
||||
float py = points_a[p][1];
|
||||
float pz = points_a[p][2];
|
||||
const float p_radius = radius_a[p];
|
||||
if (pz < 0) {
|
||||
continue;
|
||||
}
|
||||
float point_x_min = px - radius;
|
||||
float point_x_max = px + radius;
|
||||
float point_y_min = py - radius;
|
||||
float point_y_max = py + radius;
|
||||
float point_x_min = px - p_radius;
|
||||
float point_x_max = px + p_radius;
|
||||
float point_y_min = py - p_radius;
|
||||
float point_y_max = py + p_radius;
|
||||
|
||||
// Use a half-open interval so that points exactly on the
|
||||
// boundary between bins will fall into exactly one bin.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -18,7 +18,7 @@ kMaxPointsPerBin = 22
|
||||
def rasterize_points(
|
||||
pointclouds,
|
||||
image_size: int = 256,
|
||||
radius: float = 0.01,
|
||||
radius: Union[float, List, Tuple, torch.Tensor] = 0.01,
|
||||
points_per_pixel: int = 8,
|
||||
bin_size: Optional[int] = None,
|
||||
max_points_per_bin: Optional[int] = None,
|
||||
@@ -35,8 +35,10 @@ def rasterize_points(
|
||||
(0, 0, 0); In the camera coordinate frame the x-axis goes from right-to-left,
|
||||
the y-axis goes from bottom-to-top, and the z-axis goes from back-to-front.
|
||||
image_size: Integer giving the resolution of the rasterized image
|
||||
radius (Optional): Float giving the radius (in NDC units) of the disk to
|
||||
be rasterized for each point.
|
||||
radius (Optional): The radius (in NDC units) of the disk to
|
||||
be rasterized. This can either be a float in which case the same radius is used
|
||||
for each point, or a torch.Tensor of shape (N, P) giving a radius per point
|
||||
in the batch.
|
||||
points_per_pixel (Optional): We will keep track of this many points per
|
||||
pixel, returning the nearest points_per_pixel points along the z-axis
|
||||
bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
|
||||
@@ -74,6 +76,8 @@ def rasterize_points(
|
||||
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
|
||||
num_points_per_cloud = pointclouds.num_points_per_cloud()
|
||||
|
||||
radius = _format_radius(radius, pointclouds)
|
||||
|
||||
if bin_size is None:
|
||||
if not points_packed.is_cuda:
|
||||
# Binned CPU rasterization not fully implemented
|
||||
@@ -117,6 +121,48 @@ def rasterize_points(
|
||||
)
|
||||
|
||||
|
||||
def _format_radius(
|
||||
radius: Union[float, List, Tuple, torch.Tensor], pointclouds
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Format the radius as a torch tensor of shape (P_packed,)
|
||||
where P_packed is the total number of points in the
|
||||
batch (i.e. pointclouds.points_packed().shape[0]).
|
||||
|
||||
This will enable support for a different size radius
|
||||
for each point in the batch.
|
||||
|
||||
Args:
|
||||
radius: can be a float, List, Tuple or tensor of
|
||||
shape (N, P_padded) where P_padded is the
|
||||
maximum number of points for each pointcloud
|
||||
in the batch.
|
||||
|
||||
Returns:
|
||||
radius: torch.Tensor of shape (P_packed)
|
||||
"""
|
||||
N, P_padded = pointclouds._N, pointclouds._P
|
||||
points_packed = pointclouds.points_packed()
|
||||
P_packed = points_packed.shape[0]
|
||||
if isinstance(radius, (list, tuple)):
|
||||
radius = torch.tensor(radius).type_as(points_packed)
|
||||
if isinstance(radius, torch.Tensor):
|
||||
if N == 1 and radius.ndim == 1:
|
||||
radius = radius[None, ...]
|
||||
if radius.shape != (N, P_padded):
|
||||
msg = "radius must be of shape (N, P): got %s"
|
||||
raise ValueError(msg % (repr(radius.shape)))
|
||||
else:
|
||||
padded_to_packed_idx = pointclouds.padded_to_packed_idx()
|
||||
radius = radius.view(-1)[padded_to_packed_idx]
|
||||
elif isinstance(radius, float):
|
||||
radius = torch.full((P_packed,), fill_value=radius).type_as(points_packed)
|
||||
else:
|
||||
msg = "radius must be a float, list, tuple or tensor; got %s"
|
||||
raise ValueError(msg % type(radius))
|
||||
return radius
|
||||
|
||||
|
||||
class _RasterizePoints(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
@@ -125,7 +171,7 @@ class _RasterizePoints(torch.autograd.Function):
|
||||
cloud_to_packed_first_idx,
|
||||
num_points_per_cloud,
|
||||
image_size: int = 256,
|
||||
radius: float = 0.01,
|
||||
radius: Union[float, torch.Tensor] = 0.01,
|
||||
points_per_pixel: int = 8,
|
||||
bin_size: int = 0,
|
||||
max_points_per_bin: int = 0,
|
||||
@@ -175,7 +221,10 @@ class _RasterizePoints(torch.autograd.Function):
|
||||
|
||||
|
||||
def rasterize_points_python(
|
||||
pointclouds, image_size: int = 256, radius: float = 0.01, points_per_pixel: int = 8
|
||||
pointclouds,
|
||||
image_size: int = 256,
|
||||
radius: Union[float, torch.Tensor] = 0.01,
|
||||
points_per_pixel: int = 8,
|
||||
):
|
||||
"""
|
||||
Naive pure PyTorch implementation of pointcloud rasterization.
|
||||
@@ -190,6 +239,9 @@ def rasterize_points_python(
|
||||
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
|
||||
num_points_per_cloud = pointclouds.num_points_per_cloud()
|
||||
|
||||
# Support variable size radius for each point in the batch
|
||||
radius = _format_radius(radius, pointclouds)
|
||||
|
||||
# Intialize output tensors.
|
||||
point_idxs = torch.full(
|
||||
(N, S, S, K), fill_value=-1, dtype=torch.int32, device=device
|
||||
@@ -225,12 +277,13 @@ def rasterize_points_python(
|
||||
# Check whether each point in the batch affects this pixel.
|
||||
for p in range(point_start_idx, point_stop_idx):
|
||||
px, py, pz = points_packed[p, :]
|
||||
r = radius2[p]
|
||||
if pz < 0:
|
||||
continue
|
||||
dx = px - xf
|
||||
dy = py - yf
|
||||
dist2 = dx * dx + dy * dy
|
||||
if dist2 < radius2:
|
||||
if dist2 < r:
|
||||
top_k_points.append((pz, p, dist2))
|
||||
top_k_points.sort()
|
||||
if len(top_k_points) > K:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from typing import NamedTuple, Optional
|
||||
from typing import NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -30,7 +30,7 @@ class PointsRasterizationSettings:
|
||||
def __init__(
|
||||
self,
|
||||
image_size: int = 256,
|
||||
radius: float = 0.01,
|
||||
radius: Union[float, torch.Tensor] = 0.01,
|
||||
points_per_pixel: int = 8,
|
||||
bin_size: Optional[int] = None,
|
||||
max_points_per_bin: Optional[int] = None,
|
||||
|
||||
Reference in New Issue
Block a user