mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
Fix coordinate system conventions in renderer
Summary:
## Updates
- Defined the world and camera coordinates according to this figure. The world coordinates are defined as having +Y up, +X left and +Z in.
{F230888499}
- Removed all flipping from blending functions.
- Updated the rasterizer to return images with +Y up and +X left.
- Updated all the mesh rasterizer tests
- The expected values are now defined in terms of the default +Y up, +X left
- Added tests where the triangles in the meshes are non symmetrical so that it is clear which direction +X and +Y are
## Questions:
- Should we have **scene settings** instead of raster settings?
- To be more correct we should be [z clipping in the rasterizer based on the far/near clipping planes](https://github.com/ShichenLiu/SoftRas/blob/master/soft_renderer/cuda/soft_rasterize_cuda_kernel.cu#L400) - these values are also required in the blending functions so should we make these scene level parameters and have a scene settings tuple which is available to the rasterizer and shader?
Reviewed By: gkioxari
Differential Revision: D20208604
fbshipit-source-id: 55787301b1bffa0afa9618f0a0886cc681da51f3
This commit is contained in:
committed by
Facebook Github Bot
parent
767d68a3af
commit
15c72be444
@@ -189,12 +189,12 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
||||
const float* face_verts,
|
||||
const int64_t* mesh_to_face_first_idx,
|
||||
const int64_t* num_faces_per_mesh,
|
||||
float blur_radius,
|
||||
bool perspective_correct,
|
||||
int N,
|
||||
int H,
|
||||
int W,
|
||||
int K,
|
||||
const float blur_radius,
|
||||
const bool perspective_correct,
|
||||
const int N,
|
||||
const int H,
|
||||
const int W,
|
||||
const int K,
|
||||
int64_t* face_idxs,
|
||||
float* zbuf,
|
||||
float* pix_dists,
|
||||
@@ -207,8 +207,10 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
||||
// Convert linear index to 3D index
|
||||
const int n = i / (H * W); // batch index.
|
||||
const int pix_idx = i % (H * W);
|
||||
const int yi = pix_idx / H;
|
||||
const int xi = pix_idx % W;
|
||||
|
||||
// Determine ordering based on axis convention.
|
||||
const int yi = H - 1 - pix_idx / W;
|
||||
const int xi = W - 1 - pix_idx % W;
|
||||
|
||||
// screen coordinates to ndc coordiantes of pixel.
|
||||
const float xf = PixToNdc(xi, W);
|
||||
@@ -254,7 +256,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
||||
|
||||
// TODO: make sorting an option as only top k is needed, not sorted values.
|
||||
BubbleSort(q, q_size);
|
||||
int idx = n * H * W * K + yi * H * K + xi * K;
|
||||
int idx = n * H * W * K + pix_idx * K;
|
||||
for (int k = 0; k < q_size; ++k) {
|
||||
face_idxs[idx + k] = q[k].idx;
|
||||
zbuf[idx + k] = q[k].z;
|
||||
@@ -274,7 +276,7 @@ RasterizeMeshesNaiveCuda(
|
||||
const int image_size,
|
||||
const float blur_radius,
|
||||
const int num_closest,
|
||||
bool perspective_correct) {
|
||||
const bool perspective_correct) {
|
||||
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
||||
face_verts.size(2) != 3) {
|
||||
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
||||
@@ -331,12 +333,12 @@ RasterizeMeshesNaiveCuda(
|
||||
__global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
const float* face_verts, // (F, 3, 3)
|
||||
const int64_t* pix_to_face, // (N, H, W, K)
|
||||
bool perspective_correct,
|
||||
int N,
|
||||
int F,
|
||||
int H,
|
||||
int W,
|
||||
int K,
|
||||
const bool perspective_correct,
|
||||
const int N,
|
||||
const int F,
|
||||
const int H,
|
||||
const int W,
|
||||
const int K,
|
||||
const float* grad_zbuf, // (N, H, W, K)
|
||||
const float* grad_bary, // (N, H, W, K, 3)
|
||||
const float* grad_dists, // (N, H, W, K)
|
||||
@@ -351,8 +353,11 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
// Convert linear index to 3D index
|
||||
const int n = t_i / (H * W); // batch index.
|
||||
const int pix_idx = t_i % (H * W);
|
||||
const int yi = pix_idx / H;
|
||||
const int xi = pix_idx % W;
|
||||
|
||||
// Determine ordering based on axis convention.
|
||||
const int yi = H - 1 - pix_idx / W;
|
||||
const int xi = W - 1 - pix_idx % W;
|
||||
|
||||
const float xf = PixToNdc(xi, W);
|
||||
const float yf = PixToNdc(yi, H);
|
||||
const float2 pxy = make_float2(xf, yf);
|
||||
@@ -360,8 +365,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
// Loop over all the faces for this pixel.
|
||||
for (int k = 0; k < K; k++) {
|
||||
// Index into (N, H, W, K, :) grad tensors
|
||||
const int i =
|
||||
n * H * W * K + yi * H * K + xi * K + k; // pixel index + face index
|
||||
// pixel index + top k index
|
||||
int i = n * H * W * K + pix_idx * K + k;
|
||||
|
||||
const int f = pix_to_face[i];
|
||||
if (f < 0) {
|
||||
@@ -451,7 +456,7 @@ torch::Tensor RasterizeMeshesBackwardCuda(
|
||||
const torch::Tensor& grad_zbuf, // (N, H, W, K)
|
||||
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
|
||||
const torch::Tensor& grad_dists, // (N, H, W, K)
|
||||
bool perspective_correct) {
|
||||
const bool perspective_correct) {
|
||||
const int F = face_verts.size(0);
|
||||
const int N = pix_to_face.size(0);
|
||||
const int H = pix_to_face.size(1);
|
||||
@@ -509,6 +514,7 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
|
||||
// Have each block handle a chunk of faces
|
||||
const int chunks_per_batch = 1 + (F - 1) / chunk_size;
|
||||
const int num_chunks = N * chunks_per_batch;
|
||||
|
||||
for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
|
||||
const int batch_idx = chunk / chunks_per_batch; // batch index
|
||||
const int chunk_idx = chunk % chunks_per_batch;
|
||||
@@ -551,17 +557,21 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
|
||||
// Y coordinate of the top and bottom of the bin.
|
||||
// PixToNdc gives the location of the center of each pixel, so we
|
||||
// need to add/subtract a half pixel to get the true extent of the bin.
|
||||
const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix;
|
||||
const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix;
|
||||
// Reverse ordering of Y axis so that +Y is upwards in the image.
|
||||
const int yidx = num_bins - by;
|
||||
float bin_y_max = PixToNdc(yidx * bin_size - 1, H) + half_pix;
|
||||
float bin_y_min = PixToNdc((yidx - 1) * bin_size, H) - half_pix;
|
||||
|
||||
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);
|
||||
|
||||
for (int bx = 0; bx < num_bins; ++bx) {
|
||||
// X coordinate of the left and right of the bin.
|
||||
const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix;
|
||||
const float bin_x_max =
|
||||
PixToNdc((bx + 1) * bin_size - 1, W) + half_pix;
|
||||
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
|
||||
// Reverse ordering of x axis so that +X is left.
|
||||
const int xidx = num_bins - bx;
|
||||
float bin_x_max = PixToNdc(xidx * bin_size - 1, W) + half_pix;
|
||||
float bin_x_min = PixToNdc((xidx - 1) * bin_size, W) - half_pix;
|
||||
|
||||
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
|
||||
if (y_overlap && x_overlap) {
|
||||
binmask.set(by, bx, f);
|
||||
}
|
||||
@@ -654,7 +664,6 @@ torch::Tensor RasterizeMeshesCoarseCuda(
|
||||
// ****************************************************************************
|
||||
// * FINE RASTERIZATION *
|
||||
// ****************************************************************************
|
||||
|
||||
__global__ void RasterizeMeshesFineCudaKernel(
|
||||
const float* face_verts, // (F, 3, 3)
|
||||
const int32_t* bin_faces, // (N, B, B, T)
|
||||
@@ -695,8 +704,14 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
||||
|
||||
if (yi >= H || xi >= W)
|
||||
continue;
|
||||
const float xf = PixToNdc(xi, W);
|
||||
const float yf = PixToNdc(yi, H);
|
||||
|
||||
// Reverse ordering of the X and Y axis so that
|
||||
// in the image +Y is pointing up and +X is pointing left.
|
||||
const int yidx = H - 1 - yi;
|
||||
const int xidx = W - 1 - xi;
|
||||
|
||||
const float xf = PixToNdc(xidx, W);
|
||||
const float yf = PixToNdc(yidx, H);
|
||||
const float2 pxy = make_float2(xf, yf);
|
||||
|
||||
// This part looks like the naive rasterization kernel, except we use
|
||||
@@ -751,7 +766,7 @@ RasterizeMeshesFineCuda(
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
const int faces_per_pixel,
|
||||
bool perspective_correct) {
|
||||
const bool perspective_correct) {
|
||||
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
||||
face_verts.size(2) != 3) {
|
||||
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
||||
|
||||
@@ -14,10 +14,10 @@ RasterizeMeshesNaiveCpu(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& mesh_to_face_first_idx,
|
||||
const torch::Tensor& num_faces_per_mesh,
|
||||
int image_size,
|
||||
float blur_radius,
|
||||
int faces_per_pixel,
|
||||
bool perspective_correct);
|
||||
const int image_size,
|
||||
const float blur_radius,
|
||||
const int faces_per_pixel,
|
||||
const bool perspective_correct);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
@@ -25,10 +25,10 @@ RasterizeMeshesNaiveCuda(
|
||||
const at::Tensor& face_verts,
|
||||
const at::Tensor& mesh_to_face_first_idx,
|
||||
const at::Tensor& num_faces_per_mesh,
|
||||
int image_size,
|
||||
float blur_radius,
|
||||
int num_closest,
|
||||
bool perspective_correct);
|
||||
const int image_size,
|
||||
const float blur_radius,
|
||||
const int num_closest,
|
||||
const bool perspective_correct);
|
||||
#endif
|
||||
// Forward pass for rasterizing a batch of meshes.
|
||||
//
|
||||
@@ -77,10 +77,10 @@ RasterizeMeshesNaive(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& mesh_to_face_first_idx,
|
||||
const torch::Tensor& num_faces_per_mesh,
|
||||
int image_size,
|
||||
float blur_radius,
|
||||
int faces_per_pixel,
|
||||
bool perspective_correct) {
|
||||
const int image_size,
|
||||
const float blur_radius,
|
||||
const int faces_per_pixel,
|
||||
const bool perspective_correct) {
|
||||
// TODO: Better type checking.
|
||||
if (face_verts.type().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
@@ -117,7 +117,7 @@ torch::Tensor RasterizeMeshesBackwardCpu(
|
||||
const torch::Tensor& grad_bary,
|
||||
const torch::Tensor& grad_zbuf,
|
||||
const torch::Tensor& grad_dists,
|
||||
bool perspective_correct);
|
||||
const bool perspective_correct);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
torch::Tensor RasterizeMeshesBackwardCuda(
|
||||
@@ -126,7 +126,7 @@ torch::Tensor RasterizeMeshesBackwardCuda(
|
||||
const torch::Tensor& grad_bary,
|
||||
const torch::Tensor& grad_zbuf,
|
||||
const torch::Tensor& grad_dists,
|
||||
bool perspective_correct);
|
||||
const bool perspective_correct);
|
||||
#endif
|
||||
|
||||
// Args:
|
||||
@@ -159,7 +159,7 @@ torch::Tensor RasterizeMeshesBackward(
|
||||
const torch::Tensor& grad_zbuf,
|
||||
const torch::Tensor& grad_bary,
|
||||
const torch::Tensor& grad_dists,
|
||||
bool perspective_correct) {
|
||||
const bool perspective_correct) {
|
||||
if (face_verts.type().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return RasterizeMeshesBackwardCuda(
|
||||
@@ -191,20 +191,20 @@ torch::Tensor RasterizeMeshesCoarseCpu(
|
||||
const torch::Tensor& face_verts,
|
||||
const at::Tensor& mesh_to_face_first_idx,
|
||||
const at::Tensor& num_faces_per_mesh,
|
||||
int image_size,
|
||||
float blur_radius,
|
||||
int bin_size,
|
||||
int max_faces_per_bin);
|
||||
const int image_size,
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
const int max_faces_per_bin);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
torch::Tensor RasterizeMeshesCoarseCuda(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& mesh_to_face_first_idx,
|
||||
const torch::Tensor& num_faces_per_mesh,
|
||||
int image_size,
|
||||
float blur_radius,
|
||||
int bin_size,
|
||||
int max_faces_per_bin);
|
||||
const int image_size,
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
const int max_faces_per_bin);
|
||||
#endif
|
||||
// Args:
|
||||
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
|
||||
@@ -232,10 +232,10 @@ torch::Tensor RasterizeMeshesCoarse(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& mesh_to_face_first_idx,
|
||||
const torch::Tensor& num_faces_per_mesh,
|
||||
int image_size,
|
||||
float blur_radius,
|
||||
int bin_size,
|
||||
int max_faces_per_bin) {
|
||||
const int image_size,
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
const int max_faces_per_bin) {
|
||||
if (face_verts.type().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return RasterizeMeshesCoarseCuda(
|
||||
@@ -270,11 +270,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
RasterizeMeshesFineCuda(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& bin_faces,
|
||||
int image_size,
|
||||
float blur_radius,
|
||||
int bin_size,
|
||||
int faces_per_pixel,
|
||||
bool perspective_correct);
|
||||
const int image_size,
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
const int faces_per_pixel,
|
||||
const bool perspective_correct);
|
||||
#endif
|
||||
// Args:
|
||||
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
|
||||
@@ -317,11 +317,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
RasterizeMeshesFine(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& bin_faces,
|
||||
int image_size,
|
||||
float blur_radius,
|
||||
int bin_size,
|
||||
int faces_per_pixel,
|
||||
bool perspective_correct) {
|
||||
const int image_size,
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
const int faces_per_pixel,
|
||||
const bool perspective_correct) {
|
||||
if (face_verts.type().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return RasterizeMeshesFineCuda(
|
||||
@@ -373,6 +373,7 @@ RasterizeMeshesFine(
|
||||
// this function instead returns screen-space
|
||||
// barycentric coordinates for each pixel.
|
||||
//
|
||||
//
|
||||
// Returns:
|
||||
// A 4 element tuple of:
|
||||
// pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of
|
||||
@@ -394,12 +395,12 @@ RasterizeMeshes(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& mesh_to_face_first_idx,
|
||||
const torch::Tensor& num_faces_per_mesh,
|
||||
int image_size,
|
||||
float blur_radius,
|
||||
int faces_per_pixel,
|
||||
int bin_size,
|
||||
int max_faces_per_bin,
|
||||
bool perspective_correct) {
|
||||
const int image_size,
|
||||
const float blur_radius,
|
||||
const int faces_per_pixel,
|
||||
const int bin_size,
|
||||
const int max_faces_per_bin,
|
||||
const bool perspective_correct) {
|
||||
if (bin_size > 0 && max_faces_per_bin > 0) {
|
||||
// Use coarse-to-fine rasterization
|
||||
auto bin_faces = RasterizeMeshesCoarse(
|
||||
|
||||
@@ -105,9 +105,9 @@ RasterizeMeshesNaiveCpu(
|
||||
const torch::Tensor& mesh_to_face_first_idx,
|
||||
const torch::Tensor& num_faces_per_mesh,
|
||||
int image_size,
|
||||
float blur_radius,
|
||||
int faces_per_pixel,
|
||||
bool perspective_correct) {
|
||||
const float blur_radius,
|
||||
const int faces_per_pixel,
|
||||
const bool perspective_correct) {
|
||||
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
||||
face_verts.size(2) != 3) {
|
||||
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
||||
@@ -153,12 +153,19 @@ RasterizeMeshesNaiveCpu(
|
||||
|
||||
// Iterate through the horizontal lines of the image from top to bottom.
|
||||
for (int yi = 0; yi < H; ++yi) {
|
||||
// Reverse the order of yi so that +Y is pointing upwards in the image.
|
||||
const int yidx = H - 1 - yi;
|
||||
|
||||
// Y coordinate of the top of the pixel.
|
||||
const float yf = PixToNdc(yi, H);
|
||||
const float yf = PixToNdc(yidx, H);
|
||||
// Iterate through pixels on this horizontal line, left to right.
|
||||
for (int xi = 0; xi < W; ++xi) {
|
||||
// Reverse the order of xi so that +X is pointing to the left in the
|
||||
// image.
|
||||
const int xidx = W - 1 - xi;
|
||||
|
||||
// X coordinate of the left of the pixel.
|
||||
const float xf = PixToNdc(xi, W);
|
||||
const float xf = PixToNdc(xidx, W);
|
||||
// Use a priority queue to hold values:
|
||||
// (z, idx, r, bary.x, bary.y. bary.z)
|
||||
std::priority_queue<std::tuple<float, int, float, float, float, float>>
|
||||
@@ -250,7 +257,7 @@ torch::Tensor RasterizeMeshesBackwardCpu(
|
||||
const torch::Tensor& grad_zbuf, // (N, H, W, K)
|
||||
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
|
||||
const torch::Tensor& grad_dists, // (N, H, W, K)
|
||||
bool perspective_correct) {
|
||||
const bool perspective_correct) {
|
||||
const int F = face_verts.size(0);
|
||||
const int N = pix_to_face.size(0);
|
||||
const int H = pix_to_face.size(1);
|
||||
@@ -267,12 +274,19 @@ torch::Tensor RasterizeMeshesBackwardCpu(
|
||||
for (int n = 0; n < N; ++n) {
|
||||
// Iterate through the horizontal lines of the image from top to bottom.
|
||||
for (int y = 0; y < H; ++y) {
|
||||
// Reverse the order of yi so that +Y is pointing upwards in the image.
|
||||
const int yidx = H - 1 - y;
|
||||
|
||||
// Y coordinate of the top of the pixel.
|
||||
const float yf = PixToNdc(y, H);
|
||||
const float yf = PixToNdc(yidx, H);
|
||||
// Iterate through pixels on this horizontal line, left to right.
|
||||
for (int x = 0; x < W; ++x) {
|
||||
// Reverse the order of xi so that +X is pointing to the left in the
|
||||
// image.
|
||||
const int xidx = W - 1 - x;
|
||||
|
||||
// X coordinate of the left of the pixel.
|
||||
const float xf = PixToNdc(x, W);
|
||||
const float xf = PixToNdc(xidx, W);
|
||||
const vec2<float> pxy(xf, yf);
|
||||
|
||||
// Iterate through the faces that hit this pixel.
|
||||
@@ -376,10 +390,10 @@ torch::Tensor RasterizeMeshesCoarseCpu(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& mesh_to_face_first_idx,
|
||||
const torch::Tensor& num_faces_per_mesh,
|
||||
int image_size,
|
||||
float blur_radius,
|
||||
int bin_size,
|
||||
int max_faces_per_bin) {
|
||||
const int image_size,
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
const int max_faces_per_bin) {
|
||||
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
||||
face_verts.size(2) != 3) {
|
||||
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
||||
@@ -387,6 +401,7 @@ torch::Tensor RasterizeMeshesCoarseCpu(
|
||||
if (num_faces_per_mesh.ndimension() != 1) {
|
||||
AT_ERROR("num_faces_per_mesh can only have one dimension");
|
||||
}
|
||||
|
||||
const int N = num_faces_per_mesh.size(0); // batch size.
|
||||
const int M = max_faces_per_bin;
|
||||
|
||||
@@ -415,13 +430,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
|
||||
const int face_stop_idx =
|
||||
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
|
||||
|
||||
float bin_y_min = -1.0f;
|
||||
float bin_y_max = bin_y_min + bin_width;
|
||||
float bin_y_max = 1.0f;
|
||||
float bin_y_min = bin_y_max - bin_width;
|
||||
|
||||
// Iterate through the horizontal bins from top to bottom.
|
||||
for (int by = 0; by < BH; ++by) {
|
||||
float bin_x_min = -1.0f;
|
||||
float bin_x_max = bin_x_min + bin_width;
|
||||
float bin_x_max = 1.0f;
|
||||
float bin_x_min = bin_x_max - bin_width;
|
||||
|
||||
// Iterate through bins on this horizontal line, left to right.
|
||||
for (int bx = 0; bx < BW; ++bx) {
|
||||
@@ -458,13 +473,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
|
||||
}
|
||||
}
|
||||
|
||||
// Shift the bin to the right for the next loop iteration.
|
||||
bin_x_min = bin_x_max;
|
||||
bin_x_max = bin_x_min + bin_width;
|
||||
// Shift the bin down for the next loop iteration.
|
||||
bin_x_max = bin_x_min;
|
||||
bin_x_min = bin_x_min - bin_width;
|
||||
}
|
||||
// Shift the bin down for the next loop iteration.
|
||||
bin_y_min = bin_y_max;
|
||||
bin_y_max = bin_y_min + bin_width;
|
||||
// Shift the bin left for the next loop iteration.
|
||||
bin_y_max = bin_y_min;
|
||||
bin_y_min = bin_y_min - bin_width;
|
||||
}
|
||||
}
|
||||
return bin_faces;
|
||||
|
||||
@@ -38,7 +38,7 @@ def hard_rgb_blend(colors, fragments) -> torch.Tensor:
|
||||
device = fragments.pix_to_face.device
|
||||
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device)
|
||||
pixel_colors[..., :3] = colors[..., 0, :]
|
||||
return torch.flip(pixel_colors, [1])
|
||||
return pixel_colors
|
||||
|
||||
|
||||
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
@@ -80,7 +80,7 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
alpha = torch.prod((1.0 - prob), dim=-1)
|
||||
pixel_colors[..., :3] = colors[..., 0, :] # Hard assign for RGB
|
||||
pixel_colors[..., 3] = 1.0 - alpha
|
||||
return torch.flip(pixel_colors, [1])
|
||||
return pixel_colors
|
||||
|
||||
|
||||
def softmax_rgb_blend(
|
||||
@@ -125,7 +125,7 @@ def softmax_rgb_blend(
|
||||
|
||||
N, H, W, K = fragments.pix_to_face.shape
|
||||
device = fragments.pix_to_face.device
|
||||
pix_colors = torch.ones(
|
||||
pixel_colors = torch.ones(
|
||||
(N, H, W, 4), dtype=colors.dtype, device=colors.device
|
||||
)
|
||||
background = blend_params.background_color
|
||||
@@ -166,7 +166,7 @@ def softmax_rgb_blend(
|
||||
# Sum: weights * textures + background color
|
||||
weighted_colors = (weights[..., None] * colors).sum(dim=-2)
|
||||
weighted_background = (delta / denom) * background
|
||||
pix_colors[..., :3] = weighted_colors + weighted_background
|
||||
pix_colors[..., 3] = 1.0 - alpha
|
||||
pixel_colors[..., :3] = weighted_colors + weighted_background
|
||||
pixel_colors[..., 3] = 1.0 - alpha
|
||||
|
||||
return torch.flip(pix_colors, [1])
|
||||
return pixel_colors
|
||||
|
||||
@@ -944,7 +944,7 @@ def camera_position_from_spherical_angles(
|
||||
azim = math.pi / 180.0 * azim
|
||||
x = dist * torch.cos(elev) * torch.sin(azim)
|
||||
y = dist * torch.sin(elev)
|
||||
z = -dist * torch.cos(elev) * torch.cos(azim)
|
||||
z = dist * torch.cos(elev) * torch.cos(azim)
|
||||
camera_position = torch.stack([x, y, z], dim=1)
|
||||
if camera_position.dim() == 0:
|
||||
camera_position = camera_position.view(1, -1) # add batch dim.
|
||||
|
||||
@@ -208,6 +208,11 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
||||
return grads
|
||||
|
||||
|
||||
def pix_to_ndc(i, S):
|
||||
# NDC x-offset + (i * pixel_width + half_pixel_width)
|
||||
return -1 + (2 * i + 1.0) / S
|
||||
|
||||
|
||||
def rasterize_meshes_python(
|
||||
meshes,
|
||||
image_size: int = 256,
|
||||
@@ -249,10 +254,6 @@ def rasterize_meshes_python(
|
||||
(N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
# NDC is from [-1, 1]. Get pixel size using specified image size.
|
||||
pixel_width = 2.0 / W
|
||||
pixel_height = 2.0 / H
|
||||
|
||||
# Calculate all face bounding boxes.
|
||||
x_mins = torch.min(faces_verts[:, :, 0], dim=1, keepdim=True).values
|
||||
x_maxs = torch.max(faces_verts[:, :, 0], dim=1, keepdim=True).values
|
||||
@@ -269,14 +270,20 @@ def rasterize_meshes_python(
|
||||
for n in range(N):
|
||||
face_start_idx = mesh_to_face_first_idx[n]
|
||||
face_stop_idx = face_start_idx + num_faces_per_mesh[n]
|
||||
# Y coordinate of the top of the image.
|
||||
yf = -1.0 + 0.5 * pixel_height
|
||||
|
||||
# Iterate through the horizontal lines of the image from top to bottom.
|
||||
for yi in range(H):
|
||||
# X coordinate of the left of the image.
|
||||
xf = -1.0 + 0.5 * pixel_width
|
||||
# Y coordinate of one end of the image. Reverse the ordering
|
||||
# of yi so that +Y is pointing up in the image.
|
||||
yfix = H - 1 - yi
|
||||
yf = pix_to_ndc(yfix, H)
|
||||
|
||||
# Iterate through pixels on this horizontal line, left to right.
|
||||
for xi in range(W):
|
||||
# X coordinate of one end of the image. Reverse the ordering
|
||||
# of xi so that +X is pointing to the left in the image.
|
||||
xfix = W - 1 - xi
|
||||
xf = pix_to_ndc(xfix, H)
|
||||
top_k_points = []
|
||||
|
||||
# Check whether each face in the mesh affects this pixel.
|
||||
@@ -347,12 +354,6 @@ def rasterize_meshes_python(
|
||||
bary_coords[n, yi, xi, k, 2] = bary[2]
|
||||
pix_dists[n, yi, xi, k] = dist
|
||||
|
||||
# Move to the next horizontal pixel
|
||||
xf += pixel_width
|
||||
|
||||
# Move to the next vertical pixel
|
||||
yf += pixel_height
|
||||
|
||||
return face_idxs, zbuf, bary_coords, pix_dists
|
||||
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ class MeshRenderer(nn.Module):
|
||||
if raster_settings.blur_radius > 0.0:
|
||||
# TODO: potentially move barycentric clipping to the rasterizer
|
||||
# if no downstream functions requires unclipped values.
|
||||
# This will avoid unnecssary re-interpolation of the z buffer.
|
||||
# This will avoid unnecssary re-interpolation of the z buffer.
|
||||
clipped_bary_coords = _clip_barycentric_coordinates(
|
||||
fragments.bary_coords
|
||||
)
|
||||
@@ -67,4 +67,5 @@ class MeshRenderer(nn.Module):
|
||||
pix_to_face=fragments.pix_to_face,
|
||||
)
|
||||
images = self.shader(fragments, meshes_world, **kwargs)
|
||||
|
||||
return images
|
||||
|
||||
Reference in New Issue
Block a user