mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-10 14:15:58 +08:00
Non Square image rasterization for pointclouds
Summary: Similar to non square image rasterization for meshes, apply the same updates to the pointcloud rasterizer. Main API Change: - PointRasterizationSettings now accepts a tuple/list of (H, W) for the image size. Reviewed By: jcjohnson Differential Revision: D25465206 fbshipit-source-id: 7370d83c431af1b972158cecae19d82364623380
This commit is contained in:
committed by
Facebook GitHub Bot
parent
569e5229a9
commit
3d769a66cb
@@ -30,15 +30,15 @@ __global__ void alphaCompositeCudaForwardKernel(
|
||||
// Get the batch and index
|
||||
const int batch = blockIdx.x;
|
||||
|
||||
const int num_pixels = C * W * H;
|
||||
const int num_pixels = C * H * W;
|
||||
const int num_threads = gridDim.y * blockDim.x;
|
||||
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
|
||||
// Iterate over each feature in each pixel
|
||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||
int ch = pid / (W * H);
|
||||
int j = (pid % (W * H)) / H;
|
||||
int i = (pid % (W * H)) % H;
|
||||
int ch = pid / (H * W);
|
||||
int j = (pid % (H * W)) / W;
|
||||
int i = (pid % (H * W)) % W;
|
||||
|
||||
// alphacomposite the different values
|
||||
float cum_alpha = 1.;
|
||||
@@ -81,16 +81,16 @@ __global__ void alphaCompositeCudaBackwardKernel(
|
||||
// Get the batch and index
|
||||
const int batch = blockIdx.x;
|
||||
|
||||
const int num_pixels = C * W * H;
|
||||
const int num_pixels = C * H * W;
|
||||
const int num_threads = gridDim.y * blockDim.x;
|
||||
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
|
||||
// Parallelize over each feature in each pixel in images of size H * W,
|
||||
// for each image in the batch of size batch_size
|
||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||
int ch = pid / (W * H);
|
||||
int j = (pid % (W * H)) / H;
|
||||
int i = (pid % (W * H)) % H;
|
||||
int ch = pid / (H * W);
|
||||
int j = (pid % (H * W)) / W;
|
||||
int i = (pid % (H * W)) % W;
|
||||
|
||||
// alphacomposite the different values
|
||||
float cum_alpha = 1.;
|
||||
|
||||
Reference in New Issue
Block a user