Fix coordinate system conventions in renderer

Summary:
## Updates

- Defined the world and camera coordinates according to this figure. The world coordinates are defined as having +Y up, +X left and +Z in.

{F230888499}

- Removed all flipping from blending functions.
- Updated the rasterizer to return images with +Y up and +X left.
- Updated all the mesh rasterizer tests
    - The expected values are now defined in terms of the default +Y up, +X left
    - Added tests where the triangles in the meshes are non symmetrical so that it is clear which direction +X and +Y are

## Questions:
- Should we have **scene settings** instead of raster settings?
    - To be more correct we should be [z clipping in the rasterizer based on the far/near clipping planes](https://github.com/ShichenLiu/SoftRas/blob/master/soft_renderer/cuda/soft_rasterize_cuda_kernel.cu#L400) - these values are also required in the blending functions so should we make these scene level parameters and have a scene settings tuple which is available to the rasterizer and shader?

Reviewed By: gkioxari

Differential Revision: D20208604

fbshipit-source-id: 55787301b1bffa0afa9618f0a0886cc681da51f3
This commit is contained in:
Nikhila Ravi
2020-03-06 06:48:31 -08:00
committed by Facebook Github Bot
parent 767d68a3af
commit 15c72be444
27 changed files with 526 additions and 486 deletions

View File

@@ -189,12 +189,12 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
const float* face_verts,
const int64_t* mesh_to_face_first_idx,
const int64_t* num_faces_per_mesh,
float blur_radius,
bool perspective_correct,
int N,
int H,
int W,
int K,
const float blur_radius,
const bool perspective_correct,
const int N,
const int H,
const int W,
const int K,
int64_t* face_idxs,
float* zbuf,
float* pix_dists,
@@ -207,8 +207,10 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
// Convert linear index to 3D index
const int n = i / (H * W); // batch index.
const int pix_idx = i % (H * W);
const int yi = pix_idx / H;
const int xi = pix_idx % W;
// Determine ordering based on axis convention.
const int yi = H - 1 - pix_idx / W;
const int xi = W - 1 - pix_idx % W;
// screen coordinates to ndc coordiantes of pixel.
const float xf = PixToNdc(xi, W);
@@ -254,7 +256,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
// TODO: make sorting an option as only top k is needed, not sorted values.
BubbleSort(q, q_size);
int idx = n * H * W * K + yi * H * K + xi * K;
int idx = n * H * W * K + pix_idx * K;
for (int k = 0; k < q_size; ++k) {
face_idxs[idx + k] = q[k].idx;
zbuf[idx + k] = q[k].z;
@@ -274,7 +276,7 @@ RasterizeMeshesNaiveCuda(
const int image_size,
const float blur_radius,
const int num_closest,
bool perspective_correct) {
const bool perspective_correct) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
@@ -331,12 +333,12 @@ RasterizeMeshesNaiveCuda(
__global__ void RasterizeMeshesBackwardCudaKernel(
const float* face_verts, // (F, 3, 3)
const int64_t* pix_to_face, // (N, H, W, K)
bool perspective_correct,
int N,
int F,
int H,
int W,
int K,
const bool perspective_correct,
const int N,
const int F,
const int H,
const int W,
const int K,
const float* grad_zbuf, // (N, H, W, K)
const float* grad_bary, // (N, H, W, K, 3)
const float* grad_dists, // (N, H, W, K)
@@ -351,8 +353,11 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
// Convert linear index to 3D index
const int n = t_i / (H * W); // batch index.
const int pix_idx = t_i % (H * W);
const int yi = pix_idx / H;
const int xi = pix_idx % W;
// Determine ordering based on axis convention.
const int yi = H - 1 - pix_idx / W;
const int xi = W - 1 - pix_idx % W;
const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H);
const float2 pxy = make_float2(xf, yf);
@@ -360,8 +365,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
// Loop over all the faces for this pixel.
for (int k = 0; k < K; k++) {
// Index into (N, H, W, K, :) grad tensors
const int i =
n * H * W * K + yi * H * K + xi * K + k; // pixel index + face index
// pixel index + top k index
int i = n * H * W * K + pix_idx * K + k;
const int f = pix_to_face[i];
if (f < 0) {
@@ -451,7 +456,7 @@ torch::Tensor RasterizeMeshesBackwardCuda(
const torch::Tensor& grad_zbuf, // (N, H, W, K)
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
const torch::Tensor& grad_dists, // (N, H, W, K)
bool perspective_correct) {
const bool perspective_correct) {
const int F = face_verts.size(0);
const int N = pix_to_face.size(0);
const int H = pix_to_face.size(1);
@@ -509,6 +514,7 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
// Have each block handle a chunk of faces
const int chunks_per_batch = 1 + (F - 1) / chunk_size;
const int num_chunks = N * chunks_per_batch;
for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
const int batch_idx = chunk / chunks_per_batch; // batch index
const int chunk_idx = chunk % chunks_per_batch;
@@ -551,17 +557,21 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
// Y coordinate of the top and bottom of the bin.
// PixToNdc gives the location of the center of each pixel, so we
// need to add/subtract a half pixel to get the true extent of the bin.
const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix;
const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix;
// Reverse ordering of Y axis so that +Y is upwards in the image.
const int yidx = num_bins - by;
float bin_y_max = PixToNdc(yidx * bin_size - 1, H) + half_pix;
float bin_y_min = PixToNdc((yidx - 1) * bin_size, H) - half_pix;
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);
for (int bx = 0; bx < num_bins; ++bx) {
// X coordinate of the left and right of the bin.
const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix;
const float bin_x_max =
PixToNdc((bx + 1) * bin_size - 1, W) + half_pix;
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
// Reverse ordering of x axis so that +X is left.
const int xidx = num_bins - bx;
float bin_x_max = PixToNdc(xidx * bin_size - 1, W) + half_pix;
float bin_x_min = PixToNdc((xidx - 1) * bin_size, W) - half_pix;
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
if (y_overlap && x_overlap) {
binmask.set(by, bx, f);
}
@@ -654,7 +664,6 @@ torch::Tensor RasterizeMeshesCoarseCuda(
// ****************************************************************************
// * FINE RASTERIZATION *
// ****************************************************************************
__global__ void RasterizeMeshesFineCudaKernel(
const float* face_verts, // (F, 3, 3)
const int32_t* bin_faces, // (N, B, B, T)
@@ -695,8 +704,14 @@ __global__ void RasterizeMeshesFineCudaKernel(
if (yi >= H || xi >= W)
continue;
const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H);
// Reverse ordering of the X and Y axis so that
// in the image +Y is pointing up and +X is pointing left.
const int yidx = H - 1 - yi;
const int xidx = W - 1 - xi;
const float xf = PixToNdc(xidx, W);
const float yf = PixToNdc(yidx, H);
const float2 pxy = make_float2(xf, yf);
// This part looks like the naive rasterization kernel, except we use
@@ -751,7 +766,7 @@ RasterizeMeshesFineCuda(
const float blur_radius,
const int bin_size,
const int faces_per_pixel,
bool perspective_correct) {
const bool perspective_correct) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");

View File

@@ -14,10 +14,10 @@ RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int faces_per_pixel,
bool perspective_correct);
const int image_size,
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct);
#ifdef WITH_CUDA
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -25,10 +25,10 @@ RasterizeMeshesNaiveCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int num_closest,
bool perspective_correct);
const int image_size,
const float blur_radius,
const int num_closest,
const bool perspective_correct);
#endif
// Forward pass for rasterizing a batch of meshes.
//
@@ -77,10 +77,10 @@ RasterizeMeshesNaive(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int faces_per_pixel,
bool perspective_correct) {
const int image_size,
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct) {
// TODO: Better type checking.
if (face_verts.type().is_cuda()) {
#ifdef WITH_CUDA
@@ -117,7 +117,7 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const torch::Tensor& grad_bary,
const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_dists,
bool perspective_correct);
const bool perspective_correct);
#ifdef WITH_CUDA
torch::Tensor RasterizeMeshesBackwardCuda(
@@ -126,7 +126,7 @@ torch::Tensor RasterizeMeshesBackwardCuda(
const torch::Tensor& grad_bary,
const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_dists,
bool perspective_correct);
const bool perspective_correct);
#endif
// Args:
@@ -159,7 +159,7 @@ torch::Tensor RasterizeMeshesBackward(
const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_bary,
const torch::Tensor& grad_dists,
bool perspective_correct) {
const bool perspective_correct) {
if (face_verts.type().is_cuda()) {
#ifdef WITH_CUDA
return RasterizeMeshesBackwardCuda(
@@ -191,20 +191,20 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const torch::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int bin_size,
int max_faces_per_bin);
const int image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin);
#ifdef WITH_CUDA
torch::Tensor RasterizeMeshesCoarseCuda(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int bin_size,
int max_faces_per_bin);
const int image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin);
#endif
// Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
@@ -232,10 +232,10 @@ torch::Tensor RasterizeMeshesCoarse(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int bin_size,
int max_faces_per_bin) {
const int image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin) {
if (face_verts.type().is_cuda()) {
#ifdef WITH_CUDA
return RasterizeMeshesCoarseCuda(
@@ -270,11 +270,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFineCuda(
const torch::Tensor& face_verts,
const torch::Tensor& bin_faces,
int image_size,
float blur_radius,
int bin_size,
int faces_per_pixel,
bool perspective_correct);
const int image_size,
const float blur_radius,
const int bin_size,
const int faces_per_pixel,
const bool perspective_correct);
#endif
// Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
@@ -317,11 +317,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFine(
const torch::Tensor& face_verts,
const torch::Tensor& bin_faces,
int image_size,
float blur_radius,
int bin_size,
int faces_per_pixel,
bool perspective_correct) {
const int image_size,
const float blur_radius,
const int bin_size,
const int faces_per_pixel,
const bool perspective_correct) {
if (face_verts.type().is_cuda()) {
#ifdef WITH_CUDA
return RasterizeMeshesFineCuda(
@@ -373,6 +373,7 @@ RasterizeMeshesFine(
// this function instead returns screen-space
// barycentric coordinates for each pixel.
//
//
// Returns:
// A 4 element tuple of:
// pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of
@@ -394,12 +395,12 @@ RasterizeMeshes(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int faces_per_pixel,
int bin_size,
int max_faces_per_bin,
bool perspective_correct) {
const int image_size,
const float blur_radius,
const int faces_per_pixel,
const int bin_size,
const int max_faces_per_bin,
const bool perspective_correct) {
if (bin_size > 0 && max_faces_per_bin > 0) {
// Use coarse-to-fine rasterization
auto bin_faces = RasterizeMeshesCoarse(

View File

@@ -105,9 +105,9 @@ RasterizeMeshesNaiveCpu(
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int faces_per_pixel,
bool perspective_correct) {
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
@@ -153,12 +153,19 @@ RasterizeMeshesNaiveCpu(
// Iterate through the horizontal lines of the image from top to bottom.
for (int yi = 0; yi < H; ++yi) {
// Reverse the order of yi so that +Y is pointing upwards in the image.
const int yidx = H - 1 - yi;
// Y coordinate of the top of the pixel.
const float yf = PixToNdc(yi, H);
const float yf = PixToNdc(yidx, H);
// Iterate through pixels on this horizontal line, left to right.
for (int xi = 0; xi < W; ++xi) {
// Reverse the order of xi so that +X is pointing to the left in the
// image.
const int xidx = W - 1 - xi;
// X coordinate of the left of the pixel.
const float xf = PixToNdc(xi, W);
const float xf = PixToNdc(xidx, W);
// Use a priority queue to hold values:
// (z, idx, r, bary.x, bary.y. bary.z)
std::priority_queue<std::tuple<float, int, float, float, float, float>>
@@ -250,7 +257,7 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const torch::Tensor& grad_zbuf, // (N, H, W, K)
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
const torch::Tensor& grad_dists, // (N, H, W, K)
bool perspective_correct) {
const bool perspective_correct) {
const int F = face_verts.size(0);
const int N = pix_to_face.size(0);
const int H = pix_to_face.size(1);
@@ -267,12 +274,19 @@ torch::Tensor RasterizeMeshesBackwardCpu(
for (int n = 0; n < N; ++n) {
// Iterate through the horizontal lines of the image from top to bottom.
for (int y = 0; y < H; ++y) {
// Reverse the order of yi so that +Y is pointing upwards in the image.
const int yidx = H - 1 - y;
// Y coordinate of the top of the pixel.
const float yf = PixToNdc(y, H);
const float yf = PixToNdc(yidx, H);
// Iterate through pixels on this horizontal line, left to right.
for (int x = 0; x < W; ++x) {
// Reverse the order of xi so that +X is pointing to the left in the
// image.
const int xidx = W - 1 - x;
// X coordinate of the left of the pixel.
const float xf = PixToNdc(x, W);
const float xf = PixToNdc(xidx, W);
const vec2<float> pxy(xf, yf);
// Iterate through the faces that hit this pixel.
@@ -376,10 +390,10 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int bin_size,
int max_faces_per_bin) {
const int image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
@@ -387,6 +401,7 @@ torch::Tensor RasterizeMeshesCoarseCpu(
if (num_faces_per_mesh.ndimension() != 1) {
AT_ERROR("num_faces_per_mesh can only have one dimension");
}
const int N = num_faces_per_mesh.size(0); // batch size.
const int M = max_faces_per_bin;
@@ -415,13 +430,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const int face_stop_idx =
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
float bin_y_min = -1.0f;
float bin_y_max = bin_y_min + bin_width;
float bin_y_max = 1.0f;
float bin_y_min = bin_y_max - bin_width;
// Iterate through the horizontal bins from top to bottom.
for (int by = 0; by < BH; ++by) {
float bin_x_min = -1.0f;
float bin_x_max = bin_x_min + bin_width;
float bin_x_max = 1.0f;
float bin_x_min = bin_x_max - bin_width;
// Iterate through bins on this horizontal line, left to right.
for (int bx = 0; bx < BW; ++bx) {
@@ -458,13 +473,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
}
}
// Shift the bin to the right for the next loop iteration.
bin_x_min = bin_x_max;
bin_x_max = bin_x_min + bin_width;
// Shift the bin down for the next loop iteration.
bin_x_max = bin_x_min;
bin_x_min = bin_x_min - bin_width;
}
// Shift the bin down for the next loop iteration.
bin_y_min = bin_y_max;
bin_y_max = bin_y_min + bin_width;
// Shift the bin left for the next loop iteration.
bin_y_max = bin_y_min;
bin_y_min = bin_y_min - bin_width;
}
}
return bin_faces;

View File

@@ -38,7 +38,7 @@ def hard_rgb_blend(colors, fragments) -> torch.Tensor:
device = fragments.pix_to_face.device
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device)
pixel_colors[..., :3] = colors[..., 0, :]
return torch.flip(pixel_colors, [1])
return pixel_colors
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
@@ -80,7 +80,7 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
alpha = torch.prod((1.0 - prob), dim=-1)
pixel_colors[..., :3] = colors[..., 0, :] # Hard assign for RGB
pixel_colors[..., 3] = 1.0 - alpha
return torch.flip(pixel_colors, [1])
return pixel_colors
def softmax_rgb_blend(
@@ -125,7 +125,7 @@ def softmax_rgb_blend(
N, H, W, K = fragments.pix_to_face.shape
device = fragments.pix_to_face.device
pix_colors = torch.ones(
pixel_colors = torch.ones(
(N, H, W, 4), dtype=colors.dtype, device=colors.device
)
background = blend_params.background_color
@@ -166,7 +166,7 @@ def softmax_rgb_blend(
# Sum: weights * textures + background color
weighted_colors = (weights[..., None] * colors).sum(dim=-2)
weighted_background = (delta / denom) * background
pix_colors[..., :3] = weighted_colors + weighted_background
pix_colors[..., 3] = 1.0 - alpha
pixel_colors[..., :3] = weighted_colors + weighted_background
pixel_colors[..., 3] = 1.0 - alpha
return torch.flip(pix_colors, [1])
return pixel_colors

View File

@@ -944,7 +944,7 @@ def camera_position_from_spherical_angles(
azim = math.pi / 180.0 * azim
x = dist * torch.cos(elev) * torch.sin(azim)
y = dist * torch.sin(elev)
z = -dist * torch.cos(elev) * torch.cos(azim)
z = dist * torch.cos(elev) * torch.cos(azim)
camera_position = torch.stack([x, y, z], dim=1)
if camera_position.dim() == 0:
camera_position = camera_position.view(1, -1) # add batch dim.

View File

@@ -208,6 +208,11 @@ class _RasterizeFaceVerts(torch.autograd.Function):
return grads
def pix_to_ndc(i, S):
# NDC x-offset + (i * pixel_width + half_pixel_width)
return -1 + (2 * i + 1.0) / S
def rasterize_meshes_python(
meshes,
image_size: int = 256,
@@ -249,10 +254,6 @@ def rasterize_meshes_python(
(N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
)
# NDC is from [-1, 1]. Get pixel size using specified image size.
pixel_width = 2.0 / W
pixel_height = 2.0 / H
# Calculate all face bounding boxes.
x_mins = torch.min(faces_verts[:, :, 0], dim=1, keepdim=True).values
x_maxs = torch.max(faces_verts[:, :, 0], dim=1, keepdim=True).values
@@ -269,14 +270,20 @@ def rasterize_meshes_python(
for n in range(N):
face_start_idx = mesh_to_face_first_idx[n]
face_stop_idx = face_start_idx + num_faces_per_mesh[n]
# Y coordinate of the top of the image.
yf = -1.0 + 0.5 * pixel_height
# Iterate through the horizontal lines of the image from top to bottom.
for yi in range(H):
# X coordinate of the left of the image.
xf = -1.0 + 0.5 * pixel_width
# Y coordinate of one end of the image. Reverse the ordering
# of yi so that +Y is pointing up in the image.
yfix = H - 1 - yi
yf = pix_to_ndc(yfix, H)
# Iterate through pixels on this horizontal line, left to right.
for xi in range(W):
# X coordinate of one end of the image. Reverse the ordering
# of xi so that +X is pointing to the left in the image.
xfix = W - 1 - xi
xf = pix_to_ndc(xfix, H)
top_k_points = []
# Check whether each face in the mesh affects this pixel.
@@ -347,12 +354,6 @@ def rasterize_meshes_python(
bary_coords[n, yi, xi, k, 2] = bary[2]
pix_dists[n, yi, xi, k] = dist
# Move to the next horizontal pixel
xf += pixel_width
# Move to the next vertical pixel
yf += pixel_height
return face_idxs, zbuf, bary_coords, pix_dists

View File

@@ -53,7 +53,7 @@ class MeshRenderer(nn.Module):
if raster_settings.blur_radius > 0.0:
# TODO: potentially move barycentric clipping to the rasterizer
# if no downstream functions requires unclipped values.
# This will avoid unnecssary re-interpolation of the z buffer.
# This will avoid unnecssary re-interpolation of the z buffer.
clipped_bary_coords = _clip_barycentric_coordinates(
fragments.bary_coords
)
@@ -67,4 +67,5 @@ class MeshRenderer(nn.Module):
pix_to_face=fragments.pix_to_face,
)
images = self.shader(fragments, meshes_world, **kwargs)
return images