Fix coordinate system conventions in renderer

Summary:
## Updates

- Defined the world and camera coordinates according to the figure below. The world coordinates are defined as having +Y up, +X left and +Z in.

{F230888499}

- Removed all flipping from blending functions.
- Updated the rasterizer to return images with +Y up and +X left.
- Updated all the mesh rasterizer tests
    - The expected values are now defined in terms of the default +Y up, +X left convention.
    - Added tests in which the triangles in the meshes are non-symmetrical, so that the directions of +X and +Y are unambiguous.

## Questions:
- Should we have **scene settings** instead of raster settings?
    - To be more correct, we should be [z clipping in the rasterizer based on the far/near clipping planes](https://github.com/ShichenLiu/SoftRas/blob/master/soft_renderer/cuda/soft_rasterize_cuda_kernel.cu#L400). These values are also required in the blending functions, so should we make them scene-level parameters and have a scene settings tuple that is available to both the rasterizer and the shader?

Reviewed By: gkioxari

Differential Revision: D20208604

fbshipit-source-id: 55787301b1bffa0afa9618f0a0886cc681da51f3
This commit is contained in:
Nikhila Ravi
2020-03-06 06:48:31 -08:00
committed by Facebook Github Bot
parent 767d68a3af
commit 15c72be444
27 changed files with 526 additions and 486 deletions

View File

@@ -189,12 +189,12 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
const float* face_verts,
const int64_t* mesh_to_face_first_idx,
const int64_t* num_faces_per_mesh,
float blur_radius,
bool perspective_correct,
int N,
int H,
int W,
int K,
const float blur_radius,
const bool perspective_correct,
const int N,
const int H,
const int W,
const int K,
int64_t* face_idxs,
float* zbuf,
float* pix_dists,
@@ -207,8 +207,10 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
// Convert linear index to 3D index
const int n = i / (H * W); // batch index.
const int pix_idx = i % (H * W);
const int yi = pix_idx / H;
const int xi = pix_idx % W;
// Determine ordering based on axis convention.
const int yi = H - 1 - pix_idx / W;
const int xi = W - 1 - pix_idx % W;
// screen coordinates to ndc coordinates of pixel.
const float xf = PixToNdc(xi, W);
@@ -254,7 +256,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
// TODO: make sorting an option as only top k is needed, not sorted values.
BubbleSort(q, q_size);
int idx = n * H * W * K + yi * H * K + xi * K;
int idx = n * H * W * K + pix_idx * K;
for (int k = 0; k < q_size; ++k) {
face_idxs[idx + k] = q[k].idx;
zbuf[idx + k] = q[k].z;
@@ -274,7 +276,7 @@ RasterizeMeshesNaiveCuda(
const int image_size,
const float blur_radius,
const int num_closest,
bool perspective_correct) {
const bool perspective_correct) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
@@ -331,12 +333,12 @@ RasterizeMeshesNaiveCuda(
__global__ void RasterizeMeshesBackwardCudaKernel(
const float* face_verts, // (F, 3, 3)
const int64_t* pix_to_face, // (N, H, W, K)
bool perspective_correct,
int N,
int F,
int H,
int W,
int K,
const bool perspective_correct,
const int N,
const int F,
const int H,
const int W,
const int K,
const float* grad_zbuf, // (N, H, W, K)
const float* grad_bary, // (N, H, W, K, 3)
const float* grad_dists, // (N, H, W, K)
@@ -351,8 +353,11 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
// Convert linear index to 3D index
const int n = t_i / (H * W); // batch index.
const int pix_idx = t_i % (H * W);
const int yi = pix_idx / H;
const int xi = pix_idx % W;
// Determine ordering based on axis convention.
const int yi = H - 1 - pix_idx / W;
const int xi = W - 1 - pix_idx % W;
const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H);
const float2 pxy = make_float2(xf, yf);
@@ -360,8 +365,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
// Loop over all the faces for this pixel.
for (int k = 0; k < K; k++) {
// Index into (N, H, W, K, :) grad tensors
const int i =
n * H * W * K + yi * H * K + xi * K + k; // pixel index + face index
// pixel index + top k index
int i = n * H * W * K + pix_idx * K + k;
const int f = pix_to_face[i];
if (f < 0) {
@@ -451,7 +456,7 @@ torch::Tensor RasterizeMeshesBackwardCuda(
const torch::Tensor& grad_zbuf, // (N, H, W, K)
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
const torch::Tensor& grad_dists, // (N, H, W, K)
bool perspective_correct) {
const bool perspective_correct) {
const int F = face_verts.size(0);
const int N = pix_to_face.size(0);
const int H = pix_to_face.size(1);
@@ -509,6 +514,7 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
// Have each block handle a chunk of faces
const int chunks_per_batch = 1 + (F - 1) / chunk_size;
const int num_chunks = N * chunks_per_batch;
for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
const int batch_idx = chunk / chunks_per_batch; // batch index
const int chunk_idx = chunk % chunks_per_batch;
@@ -551,17 +557,21 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
// Y coordinate of the top and bottom of the bin.
// PixToNdc gives the location of the center of each pixel, so we
// need to add/subtract a half pixel to get the true extent of the bin.
const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix;
const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix;
// Reverse ordering of Y axis so that +Y is upwards in the image.
const int yidx = num_bins - by;
float bin_y_max = PixToNdc(yidx * bin_size - 1, H) + half_pix;
float bin_y_min = PixToNdc((yidx - 1) * bin_size, H) - half_pix;
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);
for (int bx = 0; bx < num_bins; ++bx) {
// X coordinate of the left and right of the bin.
const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix;
const float bin_x_max =
PixToNdc((bx + 1) * bin_size - 1, W) + half_pix;
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
// Reverse ordering of x axis so that +X is left.
const int xidx = num_bins - bx;
float bin_x_max = PixToNdc(xidx * bin_size - 1, W) + half_pix;
float bin_x_min = PixToNdc((xidx - 1) * bin_size, W) - half_pix;
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
if (y_overlap && x_overlap) {
binmask.set(by, bx, f);
}
@@ -654,7 +664,6 @@ torch::Tensor RasterizeMeshesCoarseCuda(
// ****************************************************************************
// * FINE RASTERIZATION *
// ****************************************************************************
__global__ void RasterizeMeshesFineCudaKernel(
const float* face_verts, // (F, 3, 3)
const int32_t* bin_faces, // (N, B, B, T)
@@ -695,8 +704,14 @@ __global__ void RasterizeMeshesFineCudaKernel(
if (yi >= H || xi >= W)
continue;
const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H);
// Reverse ordering of the X and Y axis so that
// in the image +Y is pointing up and +X is pointing left.
const int yidx = H - 1 - yi;
const int xidx = W - 1 - xi;
const float xf = PixToNdc(xidx, W);
const float yf = PixToNdc(yidx, H);
const float2 pxy = make_float2(xf, yf);
// This part looks like the naive rasterization kernel, except we use
@@ -751,7 +766,7 @@ RasterizeMeshesFineCuda(
const float blur_radius,
const int bin_size,
const int faces_per_pixel,
bool perspective_correct) {
const bool perspective_correct) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");