mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-11 14:55:59 +08:00
Non Square image rasterization for pointclouds
Summary: Similar to non square image rasterization for meshes, apply the same updates to the pointcloud rasterizer. Main API Change: - PointRasterizationSettings now accepts a tuple/list of (H, W) for the image size. Reviewed By: jcjohnson Differential Revision: D25465206 fbshipit-source-id: 7370d83c431af1b972158cecae19d82364623380
This commit is contained in:
committed by
Facebook GitHub Bot
parent
569e5229a9
commit
3d769a66cb
@@ -452,7 +452,6 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
const bool inside = b_pp.x > 0.0f && b_pp.y > 0.0f && b_pp.z > 0.0f;
|
||||
const float sign = inside ? -1.0f : 1.0f;
|
||||
|
||||
// TODO(T52813608) Add support for non-square images.
|
||||
auto grad_dist_f = PointTriangleDistanceBackward(
|
||||
pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream);
|
||||
const float2 ddist_d_v0 = thrust::get<1>(grad_dist_f);
|
||||
@@ -606,7 +605,7 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
|
||||
const float half_pix_x = NDC_x_half_range / W;
|
||||
const float half_pix_y = NDC_y_half_range / H;
|
||||
|
||||
// This is a boolean array of shape (num_bins, num_bins, chunk_size)
|
||||
// This is a boolean array of shape (num_bins_y, num_bins_x, chunk_size)
|
||||
// stored in shared memory that will track whether each point in the chunk
|
||||
// falls into each bin of the image.
|
||||
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
|
||||
@@ -755,7 +754,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
|
||||
const int num_bins_y = 1 + (H - 1) / bin_size;
|
||||
const int num_bins_x = 1 + (W - 1) / bin_size;
|
||||
|
||||
if (num_bins_y >= kMaxFacesPerBin || num_bins_x >= kMaxFacesPerBin) {
|
||||
if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
|
||||
std::stringstream ss;
|
||||
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
|
||||
<< ", num_bins_x: " << num_bins_x << ", "
|
||||
@@ -800,7 +799,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
|
||||
// ****************************************************************************
|
||||
__global__ void RasterizeMeshesFineCudaKernel(
|
||||
const float* face_verts, // (F, 3, 3)
|
||||
const int32_t* bin_faces, // (N, B, B, T)
|
||||
const int32_t* bin_faces, // (N, BH, BW, T)
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
const bool perspective_correct,
|
||||
@@ -813,12 +812,12 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
||||
const int H,
|
||||
const int W,
|
||||
const int K,
|
||||
int64_t* face_idxs, // (N, S, S, K)
|
||||
float* zbuf, // (N, S, S, K)
|
||||
float* pix_dists, // (N, S, S, K)
|
||||
float* bary // (N, S, S, K, 3)
|
||||
int64_t* face_idxs, // (N, H, W, K)
|
||||
float* zbuf, // (N, H, W, K)
|
||||
float* pix_dists, // (N, H, W, K)
|
||||
float* bary // (N, H, W, K, 3)
|
||||
) {
|
||||
// This can be more than S^2 if S % bin_size != 0
|
||||
// This can be more than H * W if H or W are not divisible by bin_size.
|
||||
int num_pixels = N * BH * BW * bin_size * bin_size;
|
||||
int num_threads = gridDim.x * blockDim.x;
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
Reference in New Issue
Block a user