Non square image rasterization for meshes

Summary:
There are a couple of options for supporting non square images:
1) NDC stays at [-1, 1] in both directions with the distance calculations all modified by (W/H). There are a lot of distance based calculations (e.g. triangle areas for barycentric coordinates etc) so this requires changes in many places.
2) NDC is scaled by (W/H) so the smallest side has [-1, 1]. In this case none of the distance calculations need to be updated and only the pixel to NDC calculation needs to be modified.

I decided to go with option 2 after trying option 1!

API Changes:
- Image size can now be specified optionally as a tuple

TODO:
- add a benchmark test for the non square case.

Reviewed By: jcjohnson

Differential Revision: D24404975

fbshipit-source-id: 545efb67c822d748ec35999b35762bce58db2cf4
This commit is contained in:
Nikhila Ravi
2020-12-09 09:16:57 -08:00
committed by Facebook GitHub Bot
parent 0216e4689a
commit d07307a451
13 changed files with 774 additions and 115 deletions

View File

@@ -234,8 +234,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
const int xi = W - 1 - pix_idx % W;
// screen coordinates to ndc coordiantes of pixel.
const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H);
const float xf = PixToNonSquareNdc(xi, W, H);
const float yf = PixToNonSquareNdc(yi, H, W);
const float2 pxy = make_float2(xf, yf);
// For keeping track of the K closest points we want a data structure
@@ -262,6 +262,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
for (int f = face_start_idx; f < face_stop_idx; ++f) {
// Check if the pixel pxy is inside the face bounding box and if it is,
// update q, q_size, q_max_z and q_max_idx in place.
CheckPixelInsideFace(
face_verts,
f,
@@ -280,6 +281,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
// TODO: make sorting an option as only top k is needed, not sorted values.
BubbleSort(q, q_size);
int idx = n * H * W * K + pix_idx * K;
for (int k = 0; k < q_size; ++k) {
face_idxs[idx + k] = q[k].idx;
zbuf[idx + k] = q[k].z;
@@ -296,7 +298,7 @@ RasterizeMeshesNaiveCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_faces_packed_first_idx,
const at::Tensor& num_faces_per_mesh,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int num_closest,
const bool perspective_correct,
@@ -332,8 +334,8 @@ RasterizeMeshesNaiveCuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int N = num_faces_per_mesh.size(0); // batch size.
const int H = image_size; // Assume square images.
const int W = image_size;
const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);
const int K = num_closest;
auto long_opts = num_faces_per_mesh.options().dtype(at::kLong);
@@ -405,8 +407,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
const int yi = H - 1 - pix_idx / W;
const int xi = W - 1 - pix_idx % W;
const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H);
const float xf = PixToNonSquareNdc(xi, W, H);
const float yf = PixToNonSquareNdc(yi, H, W);
const float2 pxy = make_float2(xf, yf);
// Loop over all the faces for this pixel.
@@ -589,12 +591,25 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
int* bin_faces) {
extern __shared__ char sbuf[];
const int M = max_faces_per_bin;
const int num_bins = 1 + (W - 1) / bin_size; // Integer divide round up
const float half_pix = 1.0f / W; // Size of half a pixel in NDC units
// Integer divide round up
const int num_bins_x = 1 + (W - 1) / bin_size;
const int num_bins_y = 1 + (H - 1) / bin_size;
// NDC range depends on the ratio of W/H
// The shorter side from (H, W) is given an NDC range of 2.0 and
// the other side is scaled by the ratio of H:W.
const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;
// Size of half a pixel in NDC units is the NDC half range
// divided by the corresponding image dimension
const float half_pix_x = NDC_x_half_range / W;
const float half_pix_y = NDC_y_half_range / H;
// This is a boolean array of shape (num_bins, num_bins, chunk_size)
// stored in shared memory that will track whether each point in the chunk
// falls into each bin of the image.
BitMask binmask((unsigned int*)sbuf, num_bins, num_bins, chunk_size);
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
// Have each block handle a chunk of faces
const int chunks_per_batch = 1 + (F - 1) / chunk_size;
@@ -641,21 +656,24 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
}
// Brute-force search over all bins; TODO(T54294966) something smarter.
for (int by = 0; by < num_bins; ++by) {
for (int by = 0; by < num_bins_y; ++by) {
// Y coordinate of the top and bottom of the bin.
// PixToNdc gives the location of the center of each pixel, so we
// need to add/subtract a half pixel to get the true extent of the bin.
// Reverse ordering of Y axis so that +Y is upwards in the image.
const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix;
const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix;
const float bin_y_min =
PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y;
const float bin_y_max =
PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y;
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);
for (int bx = 0; bx < num_bins; ++bx) {
for (int bx = 0; bx < num_bins_x; ++bx) {
// X coordinate of the left and right of the bin.
// Reverse ordering of x axis so that +X is left.
const float bin_x_max =
PixToNdc((bx + 1) * bin_size - 1, W) + half_pix;
const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix;
PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x;
const float bin_x_min =
PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x;
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
if (y_overlap && x_overlap) {
@@ -668,12 +686,13 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
// Now we have processed every face in the current chunk. We need to
// count the number of faces in each bin so we can write the indices
// out to global memory. We have each thread handle a different bin.
for (int byx = threadIdx.x; byx < num_bins * num_bins; byx += blockDim.x) {
const int by = byx / num_bins;
const int bx = byx % num_bins;
for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
byx += blockDim.x) {
const int by = byx / num_bins_x;
const int bx = byx % num_bins_x;
const int count = binmask.count(by, bx);
const int faces_per_bin_idx =
batch_idx * num_bins * num_bins + by * num_bins + bx;
batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;
// This atomically increments the (global) number of faces found
// in the current bin, and gets the previous value of the counter;
@@ -683,8 +702,8 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
// Now loop over the binmask and write the active bits for this bin
// out to bin_faces.
int next_idx = batch_idx * num_bins * num_bins * M + by * num_bins * M +
bx * M + start;
int next_idx = batch_idx * num_bins_y * num_bins_x * M +
by * num_bins_x * M + bx * M + start;
for (int f = 0; f < chunk_size; ++f) {
if (binmask.get(by, bx, f)) {
// TODO(T54296346) find the correct method for handling errors in
@@ -703,7 +722,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin) {
@@ -725,21 +744,27 @@ at::Tensor RasterizeMeshesCoarseCuda(
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int W = image_size;
const int H = image_size;
const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);
const int F = face_verts.size(0);
const int N = num_faces_per_mesh.size(0);
const int num_bins = 1 + (image_size - 1) / bin_size; // Divide round up.
const int M = max_faces_per_bin;
if (num_bins >= kMaxFacesPerBin) {
// Integer divide round up.
const int num_bins_y = 1 + (H - 1) / bin_size;
const int num_bins_x = 1 + (W - 1) / bin_size;
if (num_bins_y >= kMaxFacesPerBin || num_bins_x >= kMaxFacesPerBin) {
std::stringstream ss;
ss << "Got " << num_bins << "; that's too many!";
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
<< ", num_bins_x: " << num_bins_x << ", "
<< "; that's too many!";
AT_ERROR(ss.str());
}
auto opts = num_faces_per_mesh.options().dtype(at::kInt);
at::Tensor faces_per_bin = at::zeros({N, num_bins, num_bins}, opts);
at::Tensor bin_faces = at::full({N, num_bins, num_bins, M}, -1, opts);
at::Tensor faces_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
at::Tensor bin_faces = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
if (bin_faces.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError());
@@ -747,7 +772,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
}
const int chunk_size = 512;
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
const size_t blocks = 64;
const size_t threads = 512;
@@ -782,7 +807,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
const bool clip_barycentric_coords,
const bool cull_backfaces,
const int N,
const int B,
const int BH,
const int BW,
const int M,
const int H,
const int W,
@@ -793,7 +819,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
float* bary // (N, S, S, K, 3)
) {
// This can be more than S^2 if S % bin_size != 0
int num_pixels = N * B * B * bin_size * bin_size;
int num_pixels = N * BH * BW * bin_size * bin_size;
int num_threads = gridDim.x * blockDim.x;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -803,20 +829,26 @@ __global__ void RasterizeMeshesFineCudaKernel(
// into the same bin; this should give them coalesced memory reads when
// they read from faces and bin_faces.
int i = pid;
const int n = i / (B * B * bin_size * bin_size);
i %= B * B * bin_size * bin_size;
const int by = i / (B * bin_size * bin_size);
i %= B * bin_size * bin_size;
const int n = i / (BH * BW * bin_size * bin_size);
i %= BH * BW * bin_size * bin_size;
// bin index y
const int by = i / (BW * bin_size * bin_size);
i %= BW * bin_size * bin_size;
// bin index y
const int bx = i / (bin_size * bin_size);
// pixel within the bin
i %= bin_size * bin_size;
// Pixel x, y indices
const int yi = i / bin_size + by * bin_size;
const int xi = i % bin_size + bx * bin_size;
if (yi >= H || xi >= W)
continue;
const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H);
const float xf = PixToNonSquareNdc(xi, W, H);
const float yf = PixToNonSquareNdc(yi, H, W);
const float2 pxy = make_float2(xf, yf);
// This part looks like the naive rasterization kernel, except we use
@@ -828,7 +860,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
float q_max_z = -1000;
int q_max_idx = -1;
for (int m = 0; m < M; m++) {
const int f = bin_faces[n * B * B * M + by * B * M + bx * M + m];
const int f = bin_faces[n * BH * BW * M + by * BW * M + bx * M + m];
if (f < 0) {
continue; // bin_faces uses -1 as a sentinal value.
}
@@ -858,7 +890,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
// in the image +Y is pointing up and +X is pointing left.
const int yidx = H - 1 - yi;
const int xidx = W - 1 - xi;
const int pix_idx = n * H * W * K + yidx * H * K + xidx * K;
const int pix_idx = n * H * W * K + yidx * W * K + xidx * K;
for (int k = 0; k < q_size; k++) {
face_idxs[pix_idx + k] = q[k].idx;
zbuf[pix_idx + k] = q[k].z;
@@ -874,7 +907,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
RasterizeMeshesFineCuda(
const at::Tensor& face_verts,
const at::Tensor& bin_faces,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
const int faces_per_pixel,
@@ -897,12 +930,15 @@ RasterizeMeshesFineCuda(
at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// bin_faces shape (N, BH, BW, M)
const int N = bin_faces.size(0);
const int B = bin_faces.size(1);
const int BH = bin_faces.size(1);
const int BW = bin_faces.size(2);
const int M = bin_faces.size(3);
const int K = faces_per_pixel;
const int H = image_size; // Assume square images only.
const int W = image_size;
const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);
if (K > kMaxPointsPerPixel) {
AT_ERROR("Must have num_closest <= 150");
@@ -932,7 +968,8 @@ RasterizeMeshesFineCuda(
clip_barycentric_coords,
cull_backfaces,
N,
B,
BH,
BW,
M,
H,
W,

View File

@@ -15,7 +15,7 @@ RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct,
@@ -28,7 +28,7 @@ RasterizeMeshesNaiveCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int num_closest,
const bool perspective_correct,
@@ -48,8 +48,8 @@ RasterizeMeshesNaiveCuda(
// the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch.
// image_size: Size in pixels of the output image to be rasterized.
// Assume square images only.
// image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// blur_radius: float distance in NDC coordinates uses to expand the face
// bounding boxes for the rasterization. Set to 0.0 if no blur
// is required.
@@ -90,7 +90,7 @@ RasterizeMeshesNaive(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct,
@@ -223,7 +223,7 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const torch::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin);
@@ -233,7 +233,7 @@ torch::Tensor RasterizeMeshesCoarseCuda(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin);
@@ -249,7 +249,8 @@ torch::Tensor RasterizeMeshesCoarseCuda(
// the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch.
// image_size: Size in pixels of the output image to be rasterized.
// image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// blur_radius: float distance in NDC coordinates uses to expand the face
// bounding boxes for the rasterization. Set to 0.0 if no blur
// is required.
@@ -264,7 +265,7 @@ torch::Tensor RasterizeMeshesCoarse(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin) {
@@ -305,7 +306,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFineCuda(
const torch::Tensor& face_verts,
const torch::Tensor& bin_faces,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
const int faces_per_pixel,
@@ -321,7 +322,8 @@ RasterizeMeshesFineCuda(
// in NDC coordinates in the range [-1, 1].
// bin_faces: int32 Tensor of shape (N, B, B, M) giving the indices of faces
// that fall into each bin (output from coarse rasterization).
// image_size: Size in pixels of the output image to be rasterized.
// image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// blur_radius: float distance in NDC coordinates uses to expand the face
// bounding boxes for the rasterization. Set to 0.0 if no blur
// is required.
@@ -362,7 +364,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFine(
const torch::Tensor& face_verts,
const torch::Tensor& bin_faces,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
const int faces_per_pixel,
@@ -409,7 +411,8 @@ RasterizeMeshesFine(
// the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch.
// image_size: Size in pixels of the output image to be rasterized.
// image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// blur_radius: float distance in NDC coordinates uses to expand the face
// bounding boxes for the rasterization. Set to 0.0 if no blur
// is required.
@@ -453,7 +456,7 @@ RasterizeMeshes(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int faces_per_pixel,
const int bin_size,

View File

@@ -9,9 +9,35 @@
#include "utils/vec2.h"
#include "utils/vec3.h"
float PixToNdc(int i, int S) {
// NDC x-offset + (i * pixel_width + half_pixel_width)
return -1 + (2 * i + 1.0f) / S;
// The default value of the NDC range is [-1, 1], however in the case that
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
// the longer side is scaled by the ratio of H:W. S1 is the dimension for which
// the NDC range is calculated and S2 is the other image dimension.
// e.g. to get the NDC x range S1 = W and S2 = H
float NonSquareNdcRange(int S1, int S2) {
float range = 2.0f;
if (S1 > S2) {
range = ((S1 / S2) * range);
}
return range;
}
// Given a pixel coordinate 0 <= i < S1, convert it to a normalized device
// coordinates. We divide the NDC range into S1 evenly-sized
// pixels, and assume that each pixel falls in the *center* of its range.
// The default value of the NDC range is [-1, 1], however in the case that
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
// the longer side is scaled by the ratio of H:W. The dimension of i should be
// S1 and the other image dimension is S2 For example, to get the x and y NDC
// coordinates or a given pixel i:
// x = PixToNonSquareNdc(i, W, H)
// y = PixToNonSquareNdc(i, H, W)
float PixToNonSquareNdc(int i, int S1, int S2) {
float range = NonSquareNdcRange(S1, S2);
// NDC: offset + (i * pixel_width + half_pixel_width)
// The NDC range is [-range/2, range/2].
const float offset = (range / 2.0f);
return -offset + (range * i + offset) / S1;
}
// Get (x, y, z) values for vertex from (3, 3) tensor face.
@@ -108,7 +134,7 @@ RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct,
@@ -124,8 +150,8 @@ RasterizeMeshesNaiveCpu(
}
const int32_t N = mesh_to_face_first_idx.size(0); // batch_size.
const int H = image_size;
const int W = image_size;
const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);
const int K = faces_per_pixel;
auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64);
@@ -163,7 +189,7 @@ RasterizeMeshesNaiveCpu(
const int yidx = H - 1 - yi;
// Y coordinate of the top of the pixel.
const float yf = PixToNdc(yidx, H);
const float yf = PixToNonSquareNdc(yidx, H, W);
// Iterate through pixels on this horizontal line, left to right.
for (int xi = 0; xi < W; ++xi) {
// Reverse the order of xi so that +X is pointing to the left in the
@@ -171,7 +197,7 @@ RasterizeMeshesNaiveCpu(
const int xidx = W - 1 - xi;
// X coordinate of the left of the pixel.
const float xf = PixToNdc(xidx, W);
const float xf = PixToNonSquareNdc(xidx, W, H);
// Use a priority queue to hold values:
// (z, idx, r, bary.x, bary.y. bary.z)
std::priority_queue<std::tuple<float, int, float, float, float, float>>
@@ -295,7 +321,7 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const int yidx = H - 1 - y;
// Y coordinate of the top of the pixel.
const float yf = PixToNdc(yidx, H);
const float yf = PixToNonSquareNdc(yidx, H, W);
// Iterate through pixels on this horizontal line, left to right.
for (int x = 0; x < W; ++x) {
// Reverse the order of xi so that +X is pointing to the left in the
@@ -303,7 +329,7 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const int xidx = W - 1 - x;
// X coordinate of the left of the pixel.
const float xf = PixToNdc(xidx, W);
const float xf = PixToNonSquareNdc(xidx, W, H);
const vec2<float> pxy(xf, yf);
// Iterate through the faces that hit this pixel.
@@ -353,7 +379,6 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
const float sign = inside ? -1.0f : 1.0f;
// TODO(T52813608) Add support for non-square images.
const auto grad_dist_f = PointTriangleDistanceBackward(
pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream);
const auto ddist_d_v0 = std::get<1>(grad_dist_f);
@@ -415,7 +440,7 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
const int image_size,
const std::tuple<int, int> image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin) {
@@ -430,11 +455,12 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const int N = num_faces_per_mesh.size(0); // batch size.
const int M = max_faces_per_bin;
// Assume square images. TODO(T52813608) Support non square images.
const float height = image_size;
const float width = image_size;
const int BH = 1 + (height - 1) / bin_size; // Integer division round up.
const int BW = 1 + (width - 1) / bin_size; // Integer division round up.
const float H = std::get<0>(image_size);
const float W = std::get<1>(image_size);
// Integer division round up.
const int BH = 1 + (H - 1) / bin_size;
const int BW = 1 + (W - 1) / bin_size;
auto opts = num_faces_per_mesh.options().dtype(torch::kInt32);
torch::Tensor faces_per_bin = torch::zeros({N, BH, BW}, opts);
@@ -445,8 +471,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
auto face_bboxes_a = face_bboxes.accessor<float, 2>();
const float pixel_width = 2.0f / image_size;
const float bin_width = pixel_width * bin_size;
const float ndc_x_range = NonSquareNdcRange(W, H);
const float pixel_width_x = ndc_x_range / W;
const float bin_width_x = pixel_width_x * bin_size;
const float ndc_y_range = NonSquareNdcRange(H, W);
const float pixel_width_y = ndc_y_range / H;
const float bin_width_y = pixel_width_y * bin_size;
// Iterate through the meshes in the batch.
for (int n = 0; n < N; ++n) {
@@ -455,12 +486,12 @@ torch::Tensor RasterizeMeshesCoarseCpu(
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
float bin_y_min = -1.0f;
float bin_y_max = bin_y_min + bin_width;
float bin_y_max = bin_y_min + bin_width_y;
// Iterate through the horizontal bins from top to bottom.
for (int by = 0; by < BH; ++by) {
float bin_x_min = -1.0f;
float bin_x_max = bin_x_min + bin_width;
float bin_x_max = bin_x_min + bin_width_x;
// Iterate through bins on this horizontal line, left to right.
for (int bx = 0; bx < BW; ++bx) {
@@ -502,11 +533,11 @@ torch::Tensor RasterizeMeshesCoarseCpu(
// Shift the bin to the right for the next loop iteration
bin_x_min = bin_x_max;
bin_x_max = bin_x_min + bin_width;
bin_x_max = bin_x_min + bin_width_x;
}
// Shift the bin down for the next loop iteration
bin_y_min = bin_y_max;
bin_y_max = bin_y_min + bin_width;
bin_y_max = bin_y_min + bin_width_y;
}
}
return bin_faces;

View File

@@ -3,11 +3,44 @@
#pragma once
// Given a pixel coordinate 0 <= i < S, convert it to a normalized device
// coordinate in the range [-1, 1]. We divide the NDC range into S evenly-sized
// coordinates in the range [-1, 1]. We divide the NDC range into S evenly-sized
// pixels, and assume that each pixel falls in the *center* of its range.
// TODO: delete this function after updating the pointcloud rasterizer to
// support non square images.
__device__ inline float PixToNdc(int i, int S) {
// NDC x-offset + (i * pixel_width + half_pixel_width)
return -1 + (2 * i + 1.0f) / S;
// NDC: x-offset + (i * pixel_width + half_pixel_width)
return -1.0 + (2 * i + 1.0) / S;
}
// The default value of the NDC range is [-1, 1], however in the case that
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
// the longer side is scaled by the ratio of H:W. S1 is the dimension for which
// the NDC range is calculated and S2 is the other image dimension.
// e.g. to get the NDC x range S1 = W and S2 = H
__device__ inline float NonSquareNdcRange(int S1, int S2) {
float range = 2.0f;
if (S1 > S2) {
range = ((S1 / S2) * range);
}
return range;
}
// Given a pixel coordinate 0 <= i < S1, convert it to a normalized device
// coordinates. We divide the NDC range into S1 evenly-sized
// pixels, and assume that each pixel falls in the *center* of its range.
// The default value of the NDC range is [-1, 1], however in the case that
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
// the longer side is scaled by the ratio of H:W. The dimension of i should be
// S1 and the other image dimension is S2 For example, to get the x and y NDC
// coordinates or a given pixel i:
// x = PixToNonSquareNdc(i, W, H)
// y = PixToNonSquareNdc(i, H, W)
__device__ inline float PixToNonSquareNdc(int i, int S1, int S2) {
float range = NonSquareNdcRange(S1, S2);
// NDC: offset + (i * pixel_width + half_pixel_width)
// The NDC range is [-range/2, range/2].
float offset = (range / 2.0f);
return -offset + (range * i + offset) / S1;
}
// The maximum number of points per pixel that we can return. Since we use