Update point cloud rasterizer to support heterogeneous point clouds

Summary:
Update the point cloud rasterizer to:
- use the pointcloud datastructure (rebased on top of D19791851.)
- support rasterization of heterogeneous point clouds in the same way as with Meshes.

The main changes to the API will be as follows:
- The input to `rasterize_points` will be a `Pointclouds` object instead of a tensor. This will be easy to update e.g.
```
points = torch.randn(N, P, 3)
idx2, zbuf2, dists2 = rasterize_points(points, image_size, radius, points_per_pixel)

points = torch.randn(N, P, 3)
pointclouds = Pointclouds(points=points)
idx2, zbuf2, dists2 = rasterize_points(pointclouds, image_size, radius, points_per_pixel)
```

- The indices output from rasterization will now refer to points in `poinclouds.points_packed()`.
This may require some changes to the functions which consume the outputs of rasterization if they were previously
assuming that the indices ranged from 0 to P where P is the number of points in each pointcloud.

Making this change now so that Olivia can update her PR accordingly.

Reviewed By: gkioxari

Differential Revision: D20088651

fbshipit-source-id: 833ed659909712bcbbb6a50e2ec0189839f0413a
This commit is contained in:
Nikhila Ravi
2020-03-12 07:46:15 -07:00
committed by Facebook GitHub Bot
parent cae325718e
commit 32ad869dea
4 changed files with 241 additions and 102 deletions

View File

@@ -102,16 +102,16 @@ __device__ bool CheckPointOutsideBoundingBox(
// RasterizeMeshesFineCudaKernel.
template <typename FaceQ>
__device__ void CheckPixelInsideFace(
const float* face_verts, // (N, P, 3)
int face_idx,
const float* face_verts, // (F, 3, 3)
const int face_idx,
int& q_size,
float& q_max_z,
int& q_max_idx,
FaceQ& q,
float blur_radius,
float2 pxy, // Coordinates of the pixel
int K,
bool perspective_correct) {
const float blur_radius,
const float2 pxy, // Coordinates of the pixel
const int K,
const bool perspective_correct) {
const auto v012 = GetSingleFaceVerts(face_verts, face_idx);
const float3 v0 = thrust::get<0>(v012);
const float3 v1 = thrust::get<1>(v012);
@@ -335,7 +335,6 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
const int64_t* pix_to_face, // (N, H, W, K)
const bool perspective_correct,
const int N,
const int F,
const int H,
const int W,
const int K,
@@ -472,7 +471,6 @@ torch::Tensor RasterizeMeshesBackwardCuda(
pix_to_face.contiguous().data<int64_t>(),
perspective_correct,
N,
F,
H,
W,
K,
@@ -671,7 +669,6 @@ __global__ void RasterizeMeshesFineCudaKernel(
const int bin_size,
const bool perspective_correct,
const int N,
const int F,
const int B,
const int M,
const int H,
@@ -774,7 +771,6 @@ RasterizeMeshesFineCuda(
if (bin_faces.ndimension() != 4) {
AT_ERROR("bin_faces must have 4 dimensions");
}
const int F = face_verts.size(0);
const int N = bin_faces.size(0);
const int B = bin_faces.size(1);
const int M = bin_faces.size(3);
@@ -803,7 +799,6 @@ RasterizeMeshesFineCuda(
bin_size,
perspective_correct,
N,
F,
B,
M,
H,