Non square image rasterization for meshes

Summary:
There are two options for supporting non-square images:
1) NDC stays at [-1, 1] in both directions, and every distance calculation is modified by the aspect ratio (W/H). There are many distance-based calculations (e.g. triangle areas for barycentric coordinates), so this requires changes in many places.
2) The NDC range of the longer side is scaled by the ratio of the longer side to the shorter side (W/H for a landscape image), so that the shorter side keeps the range [-1, 1]. In this case none of the distance calculations need to be updated; only the pixel-to-NDC calculation needs to be modified.

I decided to go with option 2 after trying option 1!
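
As a rough illustration of option 2, the sketch below maps pixel indices to the NDC coordinates of their pixel centres. The names `ndc_range` and `pix_to_ndc` are hypothetical stand-ins modelled on the helpers this commit adds, not the library's API, and the ratio is deliberately computed in floating point.

```cpp
#include <cassert>

// Hypothetical helper: full NDC extent along a side of length s1 when the
// other side has length s2. The shorter side keeps an extent of 2.0.
float ndc_range(int s1, int s2) {
  assert(s1 > 0 && s2 > 0);
  float range = 2.0f;
  if (s1 > s2) {
    range = (s1 * range) / s2;  // float arithmetic, no integer truncation
  }
  return range;
}

// Hypothetical helper: NDC coordinate of the centre of pixel i along a side
// of length s1, given the other side length s2.
float pix_to_ndc(int i, int s1, int s2) {
  assert(i >= 0 && i < s1);
  const float range = ndc_range(s1, s2);
  const float half = range / 2.0f;
  // -half is the edge of the image; each pixel is range / s1 wide.
  return -half + (range * i + half) / s1;
}
```

For an image with H = 2 and W = 4, `pix_to_ndc(xi, W, H)` gives -1.5, -0.5, 0.5, 1.5 for the four columns and `pix_to_ndc(yi, H, W)` gives -0.5, 0.5 for the two rows, so the pixel spacing is the same along both axes.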

API Changes:
- Image size can now optionally be specified as a tuple (H, W)

TODO:
- add a benchmark test for the non square case.

Reviewed By: jcjohnson

Differential Revision: D24404975

fbshipit-source-id: 545efb67c822d748ec35999b35762bce58db2cf4
Nikhila Ravi 2020-12-09 09:16:57 -08:00 committed by Facebook GitHub Bot
parent 0216e4689a
commit d07307a451
13 changed files with 774 additions and 115 deletions


@ -55,6 +55,19 @@ While we tried to emulate several aspects of OpenGL, there are differences in th
---
### Rasterizing Non Square Images
To rasterize an image where H != W, you can specify the `image_size` in the `RasterizationSettings` as a tuple of (H, W).
The aspect ratio needs special consideration. There are two aspect ratios to be aware of:
- the aspect ratio of each pixel
- the aspect ratio of the output image
In the cameras, e.g. `FoVPerspectiveCameras`, the `aspect_ratio` argument can be used to set the pixel aspect ratio. In the rasterizer, we assume square pixels but a variable image aspect ratio (i.e. rectangular images).
In most cases you will want to set the camera aspect ratio to 1.0 (i.e. square pixels) and only vary the `image_size` in the `RasterizationSettings` (i.e. the output image dimensions in pixels).
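
To make the square-pixel assumption concrete, here is a short derivation (not part of the original documentation) for an image of size (H, W) with W > H, assuming the shorter side spans [-1, 1] in NDC. The symbols $r_x$ and $r_y$ (NDC half-ranges) and $x_i$ (centre of pixel column $i$) are introduced here purely for illustration.

```latex
r_y = 1, \qquad r_x = \frac{W}{H}, \qquad
x_i = -r_x + \frac{2 r_x \left(i + \tfrac{1}{2}\right)}{W}, \qquad
\underbrace{\frac{2 r_x}{W}}_{\text{pixel width}}
  = \frac{2}{H}
  = \underbrace{\frac{2 r_y}{H}}_{\text{pixel height}}
```

For example, with H = 256 and W = 512 the image covers x in [-2, 2] and y in [-1, 1], and every pixel is 1/128 NDC units on a side.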
---
### The pulsar backend
Since v0.3, [pulsar](https://arxiv.org/abs/2004.07484) can be used as a backend for point-rendering. It has a focus on efficiency, which comes with pros and cons: it is highly optimized and all rendering stages are integrated in the CUDA kernels. This leads to significantly higher speed and better scaling behavior. We use it at Facebook Reality Labs to render and optimize scenes with millions of spheres in resolutions up to 4K. You can find a runtime comparison plot below (settings: `bin_size=None`, `points_per_pixel=5`, `image_size=1024`, `radius=1e-2`, `composite_params.radius=1e-4`; benchmarked on an RTX 2070 GPU).
@ -75,6 +88,8 @@ For mesh texturing we offer several options (in `pytorch3d/renderer/mesh/texturi
<img src="assets/texturing.jpg" width="1000">
---
### A simple renderer
A renderer in PyTorch3D is composed of a **rasterizer** and a **shader**. Create a renderer in a few simple steps:
@ -108,6 +123,8 @@ renderer = MeshRenderer(
) )
``` ```
---
### A custom shader
Shaders are the most flexible part of the PyTorch3D rendering API. We have created some examples of shaders in `shaders.py` but this is a non exhaustive set.


@ -234,8 +234,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
const int xi = W - 1 - pix_idx % W; const int xi = W - 1 - pix_idx % W;
// screen coordinates to ndc coordinates of pixel. // screen coordinates to ndc coordinates of pixel.
const float xf = PixToNdc(xi, W); const float xf = PixToNonSquareNdc(xi, W, H);
const float yf = PixToNdc(yi, H); const float yf = PixToNonSquareNdc(yi, H, W);
const float2 pxy = make_float2(xf, yf); const float2 pxy = make_float2(xf, yf);
// For keeping track of the K closest points we want a data structure // For keeping track of the K closest points we want a data structure
@ -262,6 +262,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
for (int f = face_start_idx; f < face_stop_idx; ++f) { for (int f = face_start_idx; f < face_stop_idx; ++f) {
// Check if the pixel pxy is inside the face bounding box and if it is, // Check if the pixel pxy is inside the face bounding box and if it is,
// update q, q_size, q_max_z and q_max_idx in place. // update q, q_size, q_max_z and q_max_idx in place.
CheckPixelInsideFace( CheckPixelInsideFace(
face_verts, face_verts,
f, f,
@ -280,6 +281,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
// TODO: make sorting an option as only top k is needed, not sorted values. // TODO: make sorting an option as only top k is needed, not sorted values.
BubbleSort(q, q_size); BubbleSort(q, q_size);
int idx = n * H * W * K + pix_idx * K; int idx = n * H * W * K + pix_idx * K;
for (int k = 0; k < q_size; ++k) { for (int k = 0; k < q_size; ++k) {
face_idxs[idx + k] = q[k].idx; face_idxs[idx + k] = q[k].idx;
zbuf[idx + k] = q[k].z; zbuf[idx + k] = q[k].z;
@ -296,7 +298,7 @@ RasterizeMeshesNaiveCuda(
const at::Tensor& face_verts, const at::Tensor& face_verts,
const at::Tensor& mesh_to_faces_packed_first_idx, const at::Tensor& mesh_to_faces_packed_first_idx,
const at::Tensor& num_faces_per_mesh, const at::Tensor& num_faces_per_mesh,
const int image_size, const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int num_closest, const int num_closest,
const bool perspective_correct, const bool perspective_correct,
@ -332,8 +334,8 @@ RasterizeMeshesNaiveCuda(
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int N = num_faces_per_mesh.size(0); // batch size. const int N = num_faces_per_mesh.size(0); // batch size.
const int H = image_size; // Assume square images. const int H = std::get<0>(image_size);
const int W = image_size; const int W = std::get<1>(image_size);
const int K = num_closest; const int K = num_closest;
auto long_opts = num_faces_per_mesh.options().dtype(at::kLong); auto long_opts = num_faces_per_mesh.options().dtype(at::kLong);
@ -405,8 +407,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
const int yi = H - 1 - pix_idx / W; const int yi = H - 1 - pix_idx / W;
const int xi = W - 1 - pix_idx % W; const int xi = W - 1 - pix_idx % W;
const float xf = PixToNdc(xi, W); const float xf = PixToNonSquareNdc(xi, W, H);
const float yf = PixToNdc(yi, H); const float yf = PixToNonSquareNdc(yi, H, W);
const float2 pxy = make_float2(xf, yf); const float2 pxy = make_float2(xf, yf);
// Loop over all the faces for this pixel. // Loop over all the faces for this pixel.
@ -589,12 +591,25 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
int* bin_faces) { int* bin_faces) {
extern __shared__ char sbuf[]; extern __shared__ char sbuf[];
const int M = max_faces_per_bin; const int M = max_faces_per_bin;
const int num_bins = 1 + (W - 1) / bin_size; // Integer divide round up // Integer divide round up
const float half_pix = 1.0f / W; // Size of half a pixel in NDC units const int num_bins_x = 1 + (W - 1) / bin_size;
const int num_bins_y = 1 + (H - 1) / bin_size;
// NDC range depends on the ratio of W/H
// The shorter side from (H, W) is given an NDC range of 2.0 and
// the other side is scaled by the ratio of H:W.
const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;
// Size of half a pixel in NDC units is the NDC half range
// divided by the corresponding image dimension
const float half_pix_x = NDC_x_half_range / W;
const float half_pix_y = NDC_y_half_range / H;
// This is a boolean array of shape (num_bins, num_bins, chunk_size) // This is a boolean array of shape (num_bins, num_bins, chunk_size)
// stored in shared memory that will track whether each point in the chunk // stored in shared memory that will track whether each point in the chunk
// falls into each bin of the image. // falls into each bin of the image.
BitMask binmask((unsigned int*)sbuf, num_bins, num_bins, chunk_size); BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
// Have each block handle a chunk of faces // Have each block handle a chunk of faces
const int chunks_per_batch = 1 + (F - 1) / chunk_size; const int chunks_per_batch = 1 + (F - 1) / chunk_size;
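
The per-axis quantities introduced in this hunk can be sketched on the host as follows. This is only an illustration under assumptions: `BinLayout`, `ceil_div`, `ndc_range` and `make_bin_layout` are hypothetical names that do not appear in the commit, and the NDC range is computed in floating point.

```cpp
#include <cassert>

// Hypothetical container for the per-axis binning quantities.
struct BinLayout {
  int num_bins_y;
  int num_bins_x;
  float half_pix_y;  // half a pixel height in NDC units
  float half_pix_x;  // half a pixel width in NDC units
};

// Ceiling division for positive integers, written as 1 + (a - 1) / b.
int ceil_div(int a, int b) {
  assert(a > 0 && b > 0);
  return 1 + (a - 1) / b;
}

// Full NDC extent of a side of length s1 given the other side s2; the
// shorter side keeps an extent of 2.0.
float ndc_range(int s1, int s2) {
  return s1 > s2 ? (2.0f * s1) / s2 : 2.0f;
}

BinLayout make_bin_layout(int H, int W, int bin_size) {
  BinLayout b;
  b.num_bins_y = ceil_div(H, bin_size);
  b.num_bins_x = ceil_div(W, bin_size);
  // Half a pixel is the NDC half-range divided by the number of pixels.
  b.half_pix_y = (ndc_range(H, W) / 2.0f) / H;
  b.half_pix_x = (ndc_range(W, H) / 2.0f) / W;
  return b;
}
```

With H = 100, W = 200 and `bin_size = 32` this yields 4 bins along y and 7 along x, and the two half-pixel sizes come out equal, which is what the square-pixel convention implies.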
@ -641,21 +656,24 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
} }
// Brute-force search over all bins; TODO(T54294966) something smarter. // Brute-force search over all bins; TODO(T54294966) something smarter.
for (int by = 0; by < num_bins; ++by) { for (int by = 0; by < num_bins_y; ++by) {
// Y coordinate of the top and bottom of the bin. // Y coordinate of the top and bottom of the bin.
// PixToNdc gives the location of the center of each pixel, so we // PixToNdc gives the location of the center of each pixel, so we
// need to add/subtract a half pixel to get the true extent of the bin. // need to add/subtract a half pixel to get the true extent of the bin.
// Reverse ordering of Y axis so that +Y is upwards in the image. // Reverse ordering of Y axis so that +Y is upwards in the image.
const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix; const float bin_y_min =
const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix; PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y;
const float bin_y_max =
PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y;
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax); const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);
for (int bx = 0; bx < num_bins; ++bx) { for (int bx = 0; bx < num_bins_x; ++bx) {
// X coordinate of the left and right of the bin. // X coordinate of the left and right of the bin.
// Reverse ordering of x axis so that +X is left. // Reverse ordering of x axis so that +X is left.
const float bin_x_max = const float bin_x_max =
PixToNdc((bx + 1) * bin_size - 1, W) + half_pix; PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x;
const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix; const float bin_x_min =
PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x;
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax); const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
if (y_overlap && x_overlap) { if (y_overlap && x_overlap) {
@ -668,12 +686,13 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
// Now we have processed every face in the current chunk. We need to // Now we have processed every face in the current chunk. We need to
// count the number of faces in each bin so we can write the indices // count the number of faces in each bin so we can write the indices
// out to global memory. We have each thread handle a different bin. // out to global memory. We have each thread handle a different bin.
for (int byx = threadIdx.x; byx < num_bins * num_bins; byx += blockDim.x) { for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
const int by = byx / num_bins; byx += blockDim.x) {
const int bx = byx % num_bins; const int by = byx / num_bins_x;
const int bx = byx % num_bins_x;
const int count = binmask.count(by, bx); const int count = binmask.count(by, bx);
const int faces_per_bin_idx = const int faces_per_bin_idx =
batch_idx * num_bins * num_bins + by * num_bins + bx; batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;
// This atomically increments the (global) number of faces found // This atomically increments the (global) number of faces found
// in the current bin, and gets the previous value of the counter; // in the current bin, and gets the previous value of the counter;
@ -683,8 +702,8 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
// Now loop over the binmask and write the active bits for this bin // Now loop over the binmask and write the active bits for this bin
// out to bin_faces. // out to bin_faces.
int next_idx = batch_idx * num_bins * num_bins * M + by * num_bins * M + int next_idx = batch_idx * num_bins_y * num_bins_x * M +
bx * M + start; by * num_bins_x * M + bx * M + start;
for (int f = 0; f < chunk_size; ++f) { for (int f = 0; f < chunk_size; ++f) {
if (binmask.get(by, bx, f)) { if (binmask.get(by, bx, f)) {
// TODO(T54296346) find the correct method for handling errors in // TODO(T54296346) find the correct method for handling errors in
@ -703,7 +722,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
const at::Tensor& face_verts, const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx, const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh, const at::Tensor& num_faces_per_mesh,
const int image_size, const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int bin_size, const int bin_size,
const int max_faces_per_bin) { const int max_faces_per_bin) {
@ -725,21 +744,27 @@ at::Tensor RasterizeMeshesCoarseCuda(
at::cuda::CUDAGuard device_guard(face_verts.device()); at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int W = image_size; const int H = std::get<0>(image_size);
const int H = image_size; const int W = std::get<1>(image_size);
const int F = face_verts.size(0); const int F = face_verts.size(0);
const int N = num_faces_per_mesh.size(0); const int N = num_faces_per_mesh.size(0);
const int num_bins = 1 + (image_size - 1) / bin_size; // Divide round up.
const int M = max_faces_per_bin; const int M = max_faces_per_bin;
if (num_bins >= kMaxFacesPerBin) { // Integer divide round up.
const int num_bins_y = 1 + (H - 1) / bin_size;
const int num_bins_x = 1 + (W - 1) / bin_size;
if (num_bins_y >= kMaxFacesPerBin || num_bins_x >= kMaxFacesPerBin) {
std::stringstream ss; std::stringstream ss;
ss << "Got " << num_bins << "; that's too many!"; ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
<< ", num_bins_x: " << num_bins_x
<< "; that's too many!";
AT_ERROR(ss.str()); AT_ERROR(ss.str());
} }
auto opts = num_faces_per_mesh.options().dtype(at::kInt); auto opts = num_faces_per_mesh.options().dtype(at::kInt);
at::Tensor faces_per_bin = at::zeros({N, num_bins, num_bins}, opts); at::Tensor faces_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
at::Tensor bin_faces = at::full({N, num_bins, num_bins, M}, -1, opts); at::Tensor bin_faces = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
if (bin_faces.numel() == 0) { if (bin_faces.numel() == 0) {
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
@ -747,7 +772,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
} }
const int chunk_size = 512; const int chunk_size = 512;
const size_t shared_size = num_bins * num_bins * chunk_size / 8; const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
const size_t blocks = 64; const size_t blocks = 64;
const size_t threads = 512; const size_t threads = 512;
@ -782,7 +807,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
const bool clip_barycentric_coords, const bool clip_barycentric_coords,
const bool cull_backfaces, const bool cull_backfaces,
const int N, const int N,
const int B, const int BH,
const int BW,
const int M, const int M,
const int H, const int H,
const int W, const int W,
@ -793,7 +819,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
float* bary // (N, S, S, K, 3) float* bary // (N, S, S, K, 3)
) { ) {
// This can be more than S^2 if S % bin_size != 0 // This can be more than S^2 if S % bin_size != 0
int num_pixels = N * B * B * bin_size * bin_size; int num_pixels = N * BH * BW * bin_size * bin_size;
int num_threads = gridDim.x * blockDim.x; int num_threads = gridDim.x * blockDim.x;
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
@ -803,20 +829,26 @@ __global__ void RasterizeMeshesFineCudaKernel(
// into the same bin; this should give them coalesced memory reads when // into the same bin; this should give them coalesced memory reads when
// they read from faces and bin_faces. // they read from faces and bin_faces.
int i = pid; int i = pid;
const int n = i / (B * B * bin_size * bin_size); const int n = i / (BH * BW * bin_size * bin_size);
i %= B * B * bin_size * bin_size; i %= BH * BW * bin_size * bin_size;
const int by = i / (B * bin_size * bin_size); // bin index y
i %= B * bin_size * bin_size; const int by = i / (BW * bin_size * bin_size);
i %= BW * bin_size * bin_size;
// bin index x
const int bx = i / (bin_size * bin_size); const int bx = i / (bin_size * bin_size);
// pixel within the bin
i %= bin_size * bin_size; i %= bin_size * bin_size;
// Pixel x, y indices
const int yi = i / bin_size + by * bin_size; const int yi = i / bin_size + by * bin_size;
const int xi = i % bin_size + bx * bin_size; const int xi = i % bin_size + bx * bin_size;
if (yi >= H || xi >= W) if (yi >= H || xi >= W)
continue; continue;
const float xf = PixToNdc(xi, W); const float xf = PixToNonSquareNdc(xi, W, H);
const float yf = PixToNdc(yi, H); const float yf = PixToNonSquareNdc(yi, H, W);
const float2 pxy = make_float2(xf, yf); const float2 pxy = make_float2(xf, yf);
// This part looks like the naive rasterization kernel, except we use // This part looks like the naive rasterization kernel, except we use
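
The index arithmetic in this hunk decodes a flat work-item id into a batch index, a bin and a pixel. A host-side sketch of the same decomposition is shown below; `PixelCoord` and `decode_pid` are hypothetical names used only for this illustration.

```cpp
#include <cassert>

// Hypothetical result type: batch index, bin indices and pixel indices.
struct PixelCoord {
  int n;
  int by;
  int bx;
  int yi;
  int xi;
};

// Decode a flat id laid out as (N, BH, BW, bin_size, bin_size), where BH and
// BW are the number of bins along y and x.
PixelCoord decode_pid(int pid, int BH, int BW, int bin_size) {
  assert(pid >= 0 && BH > 0 && BW > 0 && bin_size > 0);
  const int bin_area = bin_size * bin_size;
  int i = pid;
  PixelCoord p;
  p.n = i / (BH * BW * bin_area);
  i %= BH * BW * bin_area;
  p.by = i / (BW * bin_area);  // one row of bins holds BW bins
  i %= BW * bin_area;
  p.bx = i / bin_area;
  i %= bin_area;  // position of the pixel within its bin
  p.yi = i / bin_size + p.by * bin_size;
  p.xi = i % bin_size + p.bx * bin_size;
  return p;
}
```

Because the bins can overhang the image when H or W is not a multiple of `bin_size`, a caller still has to skip results with `yi >= H` or `xi >= W`, as the kernel does.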
@ -828,7 +860,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
float q_max_z = -1000; float q_max_z = -1000;
int q_max_idx = -1; int q_max_idx = -1;
for (int m = 0; m < M; m++) { for (int m = 0; m < M; m++) {
const int f = bin_faces[n * B * B * M + by * B * M + bx * M + m]; const int f = bin_faces[n * BH * BW * M + by * BW * M + bx * M + m];
if (f < 0) { if (f < 0) {
continue; // bin_faces uses -1 as a sentinel value. continue; // bin_faces uses -1 as a sentinel value.
} }
@ -858,7 +890,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
// in the image +Y is pointing up and +X is pointing left. // in the image +Y is pointing up and +X is pointing left.
const int yidx = H - 1 - yi; const int yidx = H - 1 - yi;
const int xidx = W - 1 - xi; const int xidx = W - 1 - xi;
const int pix_idx = n * H * W * K + yidx * H * K + xidx * K;
const int pix_idx = n * H * W * K + yidx * W * K + xidx * K;
for (int k = 0; k < q_size; k++) { for (int k = 0; k < q_size; k++) {
face_idxs[pix_idx + k] = q[k].idx; face_idxs[pix_idx + k] = q[k].idx;
zbuf[pix_idx + k] = q[k].z; zbuf[pix_idx + k] = q[k].z;
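
The change from `yidx * H * K` to `yidx * W * K` in this hunk follows from row-major indexing: one image row holds W pixels, so the row stride is W * K. A minimal sketch, with `pixel_offset` as a hypothetical helper name:

```cpp
#include <cassert>

// Offset of the first of K entries for pixel (yidx, xidx) of image n in a
// contiguous buffer of shape (N, H, W, K). The row stride is W * K.
int pixel_offset(int n, int yidx, int xidx, int H, int W, int K) {
  assert(yidx >= 0 && yidx < H && xidx >= 0 && xidx < W);
  return n * H * W * K + yidx * W * K + xidx * K;
}
```

For H = 2, W = 3, K = 2, pixel (1, 0) starts at offset 6. With H as the row stride it would start at 4, which is where pixel (0, 2) already lives; the two strides only coincide for square images.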
@ -874,7 +907,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
RasterizeMeshesFineCuda( RasterizeMeshesFineCuda(
const at::Tensor& face_verts, const at::Tensor& face_verts,
const at::Tensor& bin_faces, const at::Tensor& bin_faces,
const int image_size, const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int bin_size, const int bin_size,
const int faces_per_pixel, const int faces_per_pixel,
@ -897,12 +930,15 @@ RasterizeMeshesFineCuda(
at::cuda::CUDAGuard device_guard(face_verts.device()); at::cuda::CUDAGuard device_guard(face_verts.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// bin_faces shape (N, BH, BW, M)
const int N = bin_faces.size(0); const int N = bin_faces.size(0);
const int B = bin_faces.size(1); const int BH = bin_faces.size(1);
const int BW = bin_faces.size(2);
const int M = bin_faces.size(3); const int M = bin_faces.size(3);
const int K = faces_per_pixel; const int K = faces_per_pixel;
const int H = image_size; // Assume square images only.
const int W = image_size; const int H = std::get<0>(image_size);
const int W = std::get<1>(image_size);
if (K > kMaxPointsPerPixel) { if (K > kMaxPointsPerPixel) {
AT_ERROR("Must have num_closest <= 150"); AT_ERROR("Must have num_closest <= 150");
@ -932,7 +968,8 @@ RasterizeMeshesFineCuda(
clip_barycentric_coords, clip_barycentric_coords,
cull_backfaces, cull_backfaces,
N, N,
B, BH,
BW,
M, M,
H, H,
W, W,


@ -15,7 +15,7 @@ RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
const int image_size, const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int faces_per_pixel, const int faces_per_pixel,
const bool perspective_correct, const bool perspective_correct,
@ -28,7 +28,7 @@ RasterizeMeshesNaiveCuda(
const at::Tensor& face_verts, const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx, const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh, const at::Tensor& num_faces_per_mesh,
const int image_size, const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int num_closest, const int num_closest,
const bool perspective_correct, const bool perspective_correct,
@ -48,8 +48,8 @@ RasterizeMeshesNaiveCuda(
// the batch where N is the batch size. // the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces // num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch. // for each mesh in the batch.
// image_size: Size in pixels of the output image to be rasterized. // image_size: Tuple (H, W) giving the size in pixels of the output
// Assume square images only. // image to be rasterized.
// blur_radius: float distance in NDC coordinates used to expand the face // blur_radius: float distance in NDC coordinates used to expand the face
// bounding boxes for the rasterization. Set to 0.0 if no blur // bounding boxes for the rasterization. Set to 0.0 if no blur
// is required. // is required.
@ -90,7 +90,7 @@ RasterizeMeshesNaive(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
const int image_size, const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int faces_per_pixel, const int faces_per_pixel,
const bool perspective_correct, const bool perspective_correct,
@ -223,7 +223,7 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx, const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh, const at::Tensor& num_faces_per_mesh,
const int image_size, const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int bin_size, const int bin_size,
const int max_faces_per_bin); const int max_faces_per_bin);
@ -233,7 +233,7 @@ torch::Tensor RasterizeMeshesCoarseCuda(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
const int image_size, const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int bin_size, const int bin_size,
const int max_faces_per_bin); const int max_faces_per_bin);
@ -249,7 +249,8 @@ torch::Tensor RasterizeMeshesCoarseCuda(
// the batch where N is the batch size. // the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces // num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch. // for each mesh in the batch.
// image_size: Size in pixels of the output image to be rasterized. // image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// blur_radius: float distance in NDC coordinates used to expand the face // blur_radius: float distance in NDC coordinates used to expand the face
// bounding boxes for the rasterization. Set to 0.0 if no blur // bounding boxes for the rasterization. Set to 0.0 if no blur
// is required. // is required.
@ -264,7 +265,7 @@ torch::Tensor RasterizeMeshesCoarse(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
const int image_size, const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int bin_size, const int bin_size,
const int max_faces_per_bin) { const int max_faces_per_bin) {
@ -305,7 +306,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFineCuda( RasterizeMeshesFineCuda(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& bin_faces, const torch::Tensor& bin_faces,
const int image_size, const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int bin_size, const int bin_size,
const int faces_per_pixel, const int faces_per_pixel,
@ -321,7 +322,8 @@ RasterizeMeshesFineCuda(
// in NDC coordinates in the range [-1, 1]. // in NDC coordinates in the range [-1, 1].
// bin_faces: int32 Tensor of shape (N, B, B, M) giving the indices of faces // bin_faces: int32 Tensor of shape (N, B, B, M) giving the indices of faces
// that fall into each bin (output from coarse rasterization). // that fall into each bin (output from coarse rasterization).
// image_size: Size in pixels of the output image to be rasterized. // image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// blur_radius: float distance in NDC coordinates used to expand the face // blur_radius: float distance in NDC coordinates used to expand the face
// bounding boxes for the rasterization. Set to 0.0 if no blur // bounding boxes for the rasterization. Set to 0.0 if no blur
// is required. // is required.
@ -362,7 +364,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFine( RasterizeMeshesFine(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& bin_faces, const torch::Tensor& bin_faces,
const int image_size, const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int bin_size, const int bin_size,
const int faces_per_pixel, const int faces_per_pixel,
@ -409,7 +411,8 @@ RasterizeMeshesFine(
// the batch where N is the batch size. // the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces // num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch. // for each mesh in the batch.
// image_size: Size in pixels of the output image to be rasterized. // image_size: Tuple (H, W) giving the size in pixels of the output
// image to be rasterized.
// blur_radius: float distance in NDC coordinates used to expand the face // blur_radius: float distance in NDC coordinates used to expand the face
// bounding boxes for the rasterization. Set to 0.0 if no blur // bounding boxes for the rasterization. Set to 0.0 if no blur
// is required. // is required.
@ -453,7 +456,7 @@ RasterizeMeshes(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
const int image_size, const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int faces_per_pixel, const int faces_per_pixel,
const int bin_size, const int bin_size,


@ -9,9 +9,35 @@
#include "utils/vec2.h" #include "utils/vec2.h"
#include "utils/vec3.h" #include "utils/vec3.h"
float PixToNdc(int i, int S) { // The default value of the NDC range is [-1, 1], however in the case that
// NDC x-offset + (i * pixel_width + half_pixel_width) // H != W, the NDC range is set such that the shorter side has range [-1, 1] and
return -1 + (2 * i + 1.0f) / S; // the longer side is scaled by the ratio of H:W. S1 is the dimension for which
// the NDC range is calculated and S2 is the other image dimension.
// e.g. to get the NDC x range S1 = W and S2 = H
float NonSquareNdcRange(int S1, int S2) {
float range = 2.0f;
if (S1 > S2) {
range = (S1 * range) / S2; // multiply by the float first so the division is not truncated to an integer
}
return range;
}
// Given a pixel coordinate 0 <= i < S1, convert it to normalized device
// coordinates. We divide the NDC range into S1 evenly-sized
// pixels, and assume that each pixel falls in the *center* of its range.
// The default value of the NDC range is [-1, 1], however in the case that
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
// the longer side is scaled by the ratio of H:W. The dimension of i should be
// S1 and the other image dimension is S2. For example, to get the x and y NDC
// coordinates of a given pixel i:
// x = PixToNonSquareNdc(i, W, H)
// y = PixToNonSquareNdc(i, H, W)
float PixToNonSquareNdc(int i, int S1, int S2) {
float range = NonSquareNdcRange(S1, S2);
// NDC: offset + (i * pixel_width + half_pixel_width)
// The NDC range is [-range/2, range/2].
const float offset = (range / 2.0f);
return -offset + (range * i + offset) / S1;
} }
// Get (x, y, z) values for vertex from (3, 3) tensor face. // Get (x, y, z) values for vertex from (3, 3) tensor face.
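
When the two image dimensions are integers, the order of operations in the ratio matters. The sketch below compares the two orderings; the function names `range_int_div` and `range_float_div` are hypothetical and exist only for this comparison.

```cpp
#include <cassert>

// Ratio taken with integer division first: S1 / S2 truncates toward zero
// before the result is converted to float.
float range_int_div(int S1, int S2) {
  assert(S1 > 0 && S2 > 0);
  float range = 2.0f;
  if (S1 > S2) {
    range = (S1 / S2) * range;
  }
  return range;
}

// Multiplying by the float first promotes S1, so the division is done in
// floating point and keeps the fractional part of the ratio.
float range_float_div(int S1, int S2) {
  assert(S1 > 0 && S2 > 0);
  float range = 2.0f;
  if (S1 > S2) {
    range = (S1 * range) / S2;
  }
  return range;
}
```

The two agree whenever the longer side is an exact multiple of the shorter one (e.g. 4 and 2 both give 4.0), but for a 3:2 image the integer version returns 2.0 while the float version returns 3.0.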
@ -108,7 +134,7 @@ RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
int image_size, const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int faces_per_pixel, const int faces_per_pixel,
const bool perspective_correct, const bool perspective_correct,
@ -124,8 +150,8 @@ RasterizeMeshesNaiveCpu(
} }
const int32_t N = mesh_to_face_first_idx.size(0); // batch_size. const int32_t N = mesh_to_face_first_idx.size(0); // batch_size.
const int H = image_size; const int H = std::get<0>(image_size);
const int W = image_size; const int W = std::get<1>(image_size);
const int K = faces_per_pixel; const int K = faces_per_pixel;
auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64); auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64);
@ -163,7 +189,7 @@ RasterizeMeshesNaiveCpu(
const int yidx = H - 1 - yi; const int yidx = H - 1 - yi;
// Y coordinate of the top of the pixel. // Y coordinate of the top of the pixel.
const float yf = PixToNdc(yidx, H); const float yf = PixToNonSquareNdc(yidx, H, W);
// Iterate through pixels on this horizontal line, left to right. // Iterate through pixels on this horizontal line, left to right.
for (int xi = 0; xi < W; ++xi) { for (int xi = 0; xi < W; ++xi) {
// Reverse the order of xi so that +X is pointing to the left in the // Reverse the order of xi so that +X is pointing to the left in the
@ -171,7 +197,7 @@ RasterizeMeshesNaiveCpu(
const int xidx = W - 1 - xi; const int xidx = W - 1 - xi;
// X coordinate of the left of the pixel. // X coordinate of the left of the pixel.
const float xf = PixToNdc(xidx, W); const float xf = PixToNonSquareNdc(xidx, W, H);
// Use a priority queue to hold values: // Use a priority queue to hold values:
// (z, idx, r, bary.x, bary.y, bary.z) // (z, idx, r, bary.x, bary.y, bary.z)
std::priority_queue<std::tuple<float, int, float, float, float, float>> std::priority_queue<std::tuple<float, int, float, float, float, float>>
@ -295,7 +321,7 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const int yidx = H - 1 - y; const int yidx = H - 1 - y;
// Y coordinate of the top of the pixel. // Y coordinate of the top of the pixel.
const float yf = PixToNdc(yidx, H); const float yf = PixToNonSquareNdc(yidx, H, W);
// Iterate through pixels on this horizontal line, left to right. // Iterate through pixels on this horizontal line, left to right.
for (int x = 0; x < W; ++x) { for (int x = 0; x < W; ++x) {
// Reverse the order of xi so that +X is pointing to the left in the // Reverse the order of xi so that +X is pointing to the left in the
@ -303,7 +329,7 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const int xidx = W - 1 - x; const int xidx = W - 1 - x;
// X coordinate of the left of the pixel. // X coordinate of the left of the pixel.
const float xf = PixToNdc(xidx, W); const float xf = PixToNonSquareNdc(xidx, W, H);
const vec2<float> pxy(xf, yf); const vec2<float> pxy(xf, yf);
// Iterate through the faces that hit this pixel. // Iterate through the faces that hit this pixel.
@ -353,7 +379,6 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f; const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
const float sign = inside ? -1.0f : 1.0f; const float sign = inside ? -1.0f : 1.0f;
// TODO(T52813608) Add support for non-square images.
const auto grad_dist_f = PointTriangleDistanceBackward( const auto grad_dist_f = PointTriangleDistanceBackward(
pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream); pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream);
const auto ddist_d_v0 = std::get<1>(grad_dist_f); const auto ddist_d_v0 = std::get<1>(grad_dist_f);
@ -415,7 +440,7 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
const int image_size, const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int bin_size, const int bin_size,
const int max_faces_per_bin) { const int max_faces_per_bin) {
@ -430,11 +455,12 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const int N = num_faces_per_mesh.size(0); // batch size. const int N = num_faces_per_mesh.size(0); // batch size.
const int M = max_faces_per_bin; const int M = max_faces_per_bin;
// Assume square images. TODO(T52813608) Support non square images. const float H = std::get<0>(image_size);
const float height = image_size; const float W = std::get<1>(image_size);
const float width = image_size;
const int BH = 1 + (height - 1) / bin_size; // Integer division round up. // Integer division round up.
const int BW = 1 + (width - 1) / bin_size; // Integer division round up. const int BH = 1 + (H - 1) / bin_size;
const int BW = 1 + (W - 1) / bin_size;
auto opts = num_faces_per_mesh.options().dtype(torch::kInt32); auto opts = num_faces_per_mesh.options().dtype(torch::kInt32);
torch::Tensor faces_per_bin = torch::zeros({N, BH, BW}, opts); torch::Tensor faces_per_bin = torch::zeros({N, BH, BW}, opts);
@ -445,8 +471,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
auto face_bboxes = ComputeFaceBoundingBoxes(face_verts); auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
auto face_bboxes_a = face_bboxes.accessor<float, 2>(); auto face_bboxes_a = face_bboxes.accessor<float, 2>();
const float pixel_width = 2.0f / image_size; const float ndc_x_range = NonSquareNdcRange(W, H);
const float bin_width = pixel_width * bin_size; const float pixel_width_x = ndc_x_range / W;
const float bin_width_x = pixel_width_x * bin_size;
const float ndc_y_range = NonSquareNdcRange(H, W);
const float pixel_width_y = ndc_y_range / H;
const float bin_width_y = pixel_width_y * bin_size;
// Iterate through the meshes in the batch. // Iterate through the meshes in the batch.
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
@ -455,12 +486,12 @@ torch::Tensor RasterizeMeshesCoarseCpu(
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>()); (face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
float bin_y_min = -1.0f; float bin_y_min = -1.0f;
float bin_y_max = bin_y_min + bin_width; float bin_y_max = bin_y_min + bin_width_y;
// Iterate through the horizontal bins from top to bottom. // Iterate through the horizontal bins from top to bottom.
for (int by = 0; by < BH; ++by) { for (int by = 0; by < BH; ++by) {
float bin_x_min = -1.0f; float bin_x_min = -1.0f;
float bin_x_max = bin_x_min + bin_width; float bin_x_max = bin_x_min + bin_width_x;
// Iterate through bins on this horizontal line, left to right. // Iterate through bins on this horizontal line, left to right.
for (int bx = 0; bx < BW; ++bx) { for (int bx = 0; bx < BW; ++bx) {
@ -502,11 +533,11 @@ torch::Tensor RasterizeMeshesCoarseCpu(
// Shift the bin to the right for the next loop iteration // Shift the bin to the right for the next loop iteration
bin_x_min = bin_x_max; bin_x_min = bin_x_max;
bin_x_max = bin_x_min + bin_width; bin_x_max = bin_x_min + bin_width_x;
} }
// Shift the bin down for the next loop iteration // Shift the bin down for the next loop iteration
bin_y_min = bin_y_max; bin_y_min = bin_y_max;
bin_y_max = bin_y_min + bin_width; bin_y_max = bin_y_min + bin_width_y;
} }
} }
return bin_faces; return bin_faces;
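
As a rough illustration of the per-axis bin arithmetic in the coarse rasterizer hunk above, the following Python sketch computes the number of bins and the NDC width of one bin along each axis. The helper names `ndc_range` and `bin_layout` are hypothetical and not part of PyTorch3D; the sketch only mirrors the arithmetic, assuming the shorter image side spans an NDC range of 2.

```python
def ndc_range(s1: int, s2: int) -> float:
    """NDC extent along the axis of size s1; the shorter side spans 2.0."""
    return 2.0 * s1 / s2 if s1 > s2 else 2.0


def bin_layout(image_size, bin_size):
    """Return ((BH, BW), (bin_width_y, bin_width_x)) for an (H, W) image."""
    H, W = image_size
    # Integer division, rounded up.
    BH = 1 + (H - 1) // bin_size
    BW = 1 + (W - 1) // bin_size
    # Width of one pixel in NDC units along each axis.
    pixel_width_y = ndc_range(H, W) / H
    pixel_width_x = ndc_range(W, H) / W
    return (BH, BW), (pixel_width_y * bin_size, pixel_width_x * bin_size)


bins, widths = bin_layout((256, 512), bin_size=32)
print(bins)    # (8, 16)
print(widths)  # (0.25, 0.25)
```

Because the pixels stay square, the bin width in NDC units comes out the same on both axes; only the number of bins differs.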

View File

@ -3,11 +3,44 @@
#pragma once #pragma once
// Given a pixel coordinate 0 <= i < S, convert it to a normalized device // Given a pixel coordinate 0 <= i < S, convert it to a normalized device
// coordinate in the range [-1, 1]. We divide the NDC range into S evenly-sized // coordinate in the range [-1, 1]. We divide the NDC range into S evenly-sized
// pixels, and assume that each pixel falls in the *center* of its range. // pixels, and assume that each pixel falls in the *center* of its range.
// TODO: delete this function after updating the pointcloud rasterizer to
// support non square images.
__device__ inline float PixToNdc(int i, int S) { __device__ inline float PixToNdc(int i, int S) {
// NDC x-offset + (i * pixel_width + half_pixel_width) // NDC: x-offset + (i * pixel_width + half_pixel_width)
return -1 + (2 * i + 1.0f) / S; return -1.0 + (2 * i + 1.0) / S;
}
// The default value of the NDC range is [-1, 1], however in the case that
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
// the longer side is scaled by the ratio of the longer side to the shorter side.
// S1 is the dimension for which the NDC range is calculated and S2 is the other
// image dimension, e.g. to get the NDC x range, S1 = W and S2 = H.
__device__ inline float NonSquareNdcRange(int S1, int S2) {
float range = 2.0f;
if (S1 > S2) {
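// Multiply S1 by the float range first so that the division is done in
// floating point; S1 / S2 on ints would truncate (e.g. 96 / 64 == 1).
range = (S1 * range) / S2;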
}
return range;
}
// Given a pixel coordinate 0 <= i < S1, convert it to a normalized device
// coordinate. We divide the NDC range into S1 evenly-sized
// pixels, and assume that each pixel falls in the *center* of its range.
// The default value of the NDC range is [-1, 1], however in the case that
// H != W, the NDC range is set such that the shorter side has range [-1, 1] and
// the longer side is scaled by the ratio of the longer side to the shorter
// side. The dimension of i should be S1 and the other image dimension is S2.
// For example, to get the x and y NDC coordinates of a given pixel i:
// x = PixToNonSquareNdc(i, W, H)
// y = PixToNonSquareNdc(i, H, W)
__device__ inline float PixToNonSquareNdc(int i, int S1, int S2) {
float range = NonSquareNdcRange(S1, S2);
// NDC: offset + (i * pixel_width + half_pixel_width)
// The NDC range is [-range/2, range/2].
float offset = (range / 2.0f);
return -offset + (range * i + offset) / S1;
} }
// The maximum number of points per pixel that we can return. Since we use // The maximum number of points per pixel that we can return. Since we use
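
To make the pixel-to-NDC mapping in this hunk concrete, here is a small worked example in Python for an image of size (H, W) = (64, 128). It is a sketch of the same formula rather than the CUDA code itself, and it uses true division for the side ratio so that ratios which are not whole numbers are preserved.

```python
def non_square_ndc_range(s1: int, s2: int) -> float:
    # True division keeps the ratio fractional, e.g. 96 / 64 = 1.5.
    return 2.0 * s1 / s2 if s1 > s2 else 2.0


def pix_to_non_square_ndc(i: int, s1: int, s2: int) -> float:
    ndc_range = non_square_ndc_range(s1, s2)
    offset = ndc_range / 2.0
    # -offset + (i * pixel_width + half_pixel_width)
    return -offset + (ndc_range * i + offset) / s1


H, W = 64, 128
xs = [pix_to_non_square_ndc(i, W, H) for i in range(W)]
ys = [pix_to_non_square_ndc(i, H, W) for i in range(H)]
print(xs[0], xs[-1])  # -1.984375 1.984375
print(ys[0], ys[-1])  # -0.984375 0.984375
print(xs[1] - xs[0], ys[1] - ys[0])  # 0.03125 0.03125
```

The x centers span roughly [-2, 2] and the y centers roughly [-1, 1], while the spacing between neighbouring pixel centers is identical on both axes, which is what "square pixels" means here.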

View File

@ -2,6 +2,7 @@
import torch import torch
from pytorch3d.structures import Meshes, utils as struct_utils from pytorch3d.structures import Meshes, utils as struct_utils
# ------------------------ Mesh Smoothing ------------------------ # # ------------------------ Mesh Smoothing ------------------------ #
# This file contains differentiable operators to filter meshes # This file contains differentiable operators to filter meshes
# The ops include # The ops include

View File

@ -322,6 +322,22 @@ class FoVPerspectiveCameras(CamerasBase):
and then applies it to the input points. and then applies it to the input points.
The transforms can also be returned separately as Transform3d objects. The transforms can also be returned separately as Transform3d objects.
* Setting the Aspect Ratio for Non Square Images *
If the desired output image size is non square (i.e. a tuple of (H, W) where H != W)
the aspect ratio needs special consideration: There are two aspect ratios
to be aware of:
- the aspect ratio of each pixel
- the aspect ratio of the output image
The `aspect_ratio` setting in the FoVPerspectiveCameras sets the
pixel aspect ratio. When using this camera with the differentiable rasterizer
be aware that in the rasterizer we assume square pixels, but allow
variable image aspect ratio (i.e. rectangular images).
In most cases you will want to set the camera `aspect_ratio=1.0`
(i.e. square pixels) and only vary the output image dimensions in pixels
for rasterization.
""" """
def __init__( def __init__(
@ -341,7 +357,8 @@ class FoVPerspectiveCameras(CamerasBase):
Args: Args:
znear: near clipping plane of the view frustum. znear: near clipping plane of the view frustum.
zfar: far clipping plane of the view frustum. zfar: far clipping plane of the view frustum.
aspect_ratio: ratio of screen_width/screen_height. aspect_ratio: aspect ratio of the image pixels.
1.0 indicates square pixels.
fov: field of view angle of the camera. fov: field of view angle of the camera.
degrees: bool, set to True if fov is specified in degrees. degrees: bool, set to True if fov is specified in degrees.
R: Rotation matrix of shape (N, 3, 3) R: Rotation matrix of shape (N, 3, 3)
@ -376,7 +393,8 @@ class FoVPerspectiveCameras(CamerasBase):
znear: near clipping plane of the view frustum. znear: near clipping plane of the view frustum.
zfar: far clipping plane of the view frustum. zfar: far clipping plane of the view frustum.
fov: field of view angle of the camera. fov: field of view angle of the camera.
aspect_ratio: ratio of screen_width/screen_height. aspect_ratio: aspect ratio of the image pixels.
1.0 indicates square pixels.
degrees: bool, set to True if fov is specified in degrees. degrees: bool, set to True if fov is specified in degrees.
Returns: Returns:

View File

@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Optional from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -20,7 +20,7 @@ kMaxFacesPerBin = 22
def rasterize_meshes( def rasterize_meshes(
meshes, meshes,
image_size: int = 256, image_size: Union[int, Tuple[int, int]] = 256,
blur_radius: float = 0.0, blur_radius: float = 0.0,
faces_per_pixel: int = 8, faces_per_pixel: int = 8,
bin_size: Optional[int] = None, bin_size: Optional[int] = None,
@ -32,12 +32,25 @@ def rasterize_meshes(
""" """
Rasterize a batch of meshes given the shape of the desired output image. Rasterize a batch of meshes given the shape of the desired output image.
Each mesh is rasterized onto a separate image of shape Each mesh is rasterized onto a separate image of shape
(image_size, image_size). (H, W) if `image_size` is a tuple or (image_size, image_size) if it
is an int.
If the desired image size is non square (i.e. a tuple of (H, W) where H != W)
the aspect ratio needs special consideration. There are two aspect ratios
to be aware of:
- the aspect ratio of each pixel
- the aspect ratio of the output image
The camera can be used to set the pixel aspect ratio. In the rasterizer,
we assume square pixels, but variable image aspect ratio (i.e. rectangular images).
In most cases you will want to set the camera aspect ratio to
1.0 (i.e. square pixels) and only vary the
`image_size` (i.e. the output image dimensions in pixels).
Args: Args:
meshes: A Meshes object representing a batch of meshes, batch size N. meshes: A Meshes object representing a batch of meshes, batch size N.
image_size: Size in pixels of the output raster image for each mesh image_size: Size in pixels of the output image to be rasterized.
in the batch. Assumes square images. Can optionally be a tuple of (H, W) in the case of non square images.
blur_radius: Float distance in the range [0, 2] used to expand the face blur_radius: Float distance in the range [0, 2] used to expand the face
bounding boxes for rasterization. Setting blur radius bounding boxes for rasterization. Setting blur radius
results in blurred edges around the shape instead of a results in blurred edges around the shape instead of a
@ -98,6 +111,9 @@ def rasterize_meshes(
squared distance between the pixel (y, x) and the face given squared distance between the pixel (y, x) and the face given
by vertices ``face_verts[f]``. Pixels hit with fewer than by vertices ``face_verts[f]``. Pixels hit with fewer than
``faces_per_pixel`` are padded with -1. ``faces_per_pixel`` are padded with -1.
In the case that image_size is a tuple of (H, W) then the outputs
will be of shape `(N, H, W, ...)`.
""" """
verts_packed = meshes.verts_packed() verts_packed = meshes.verts_packed()
faces_packed = meshes.faces_packed() faces_packed = meshes.faces_packed()
@ -105,6 +121,26 @@ def rasterize_meshes(
mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx() mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
num_faces_per_mesh = meshes.num_faces_per_mesh() num_faces_per_mesh = meshes.num_faces_per_mesh()
# In the case that H != W use the max image size to set the bin_size
# to accommodate the num bins constraint in the coarse rasterizer.
# If the ratio of H:W is large this might cause issues as the smaller
# dimension will have fewer bins.
# TODO: consider a better way of setting the bin size.
if isinstance(image_size, (tuple, list)):
if len(image_size) != 2:
raise ValueError("Image size can only be a tuple/list of (H, W)")
if not all(i > 0 for i in image_size):
raise ValueError(
# Convert to a tuple so that %-formatting also works for list inputs.
"Image sizes must be greater than 0; got %d, %d" % tuple(image_size)
)
if not all(type(i) == int for i in image_size):
raise ValueError(
"Image sizes must be integers; got %f, %f" % tuple(image_size)
)
max_image_size = max(*image_size)
im_size = image_size
else:
im_size = (image_size, image_size)
max_image_size = image_size
# TODO: Choose naive vs coarse-to-fine based on mesh size and image size. # TODO: Choose naive vs coarse-to-fine based on mesh size and image size.
if bin_size is None: if bin_size is None:
if not verts_packed.is_cuda: if not verts_packed.is_cuda:
@ -112,20 +148,20 @@ def rasterize_meshes(
bin_size = 0 bin_size = 0
else: else:
# TODO better heuristics for bin size. # TODO better heuristics for bin size.
if image_size <= 64: if max_image_size <= 64:
bin_size = 8 bin_size = 8
else: else:
# Heuristic based formula maps image_size -> bin_size as follows: # Heuristic based formula maps max_image_size -> bin_size as follows:
# image_size < 64 -> 8 # max_image_size < 64 -> 8
# 64 < image_size < 256 -> 16 # 64 < max_image_size < 256 -> 16
# 256 < image_size < 512 -> 32 # 256 < max_image_size < 512 -> 32
# 512 < image_size < 1024 -> 64 # 512 < max_image_size < 1024 -> 64
# 1024 < image_size < 2048 -> 128 # 1024 < max_image_size < 2048 -> 128
bin_size = int(2 ** max(np.ceil(np.log2(image_size)) - 4, 4)) bin_size = int(2 ** max(np.ceil(np.log2(max_image_size)) - 4, 4))
if bin_size != 0: if bin_size != 0:
# There is a limit on the number of faces per bin in the cuda kernel. # There is a limit on the number of faces per bin in the cuda kernel.
faces_per_bin = 1 + (image_size - 1) // bin_size faces_per_bin = 1 + (max_image_size - 1) // bin_size
if faces_per_bin >= kMaxFacesPerBin: if faces_per_bin >= kMaxFacesPerBin:
raise ValueError( raise ValueError(
"bin_size too small, number of faces per bin must be less than %d; got %d" "bin_size too small, number of faces per bin must be less than %d; got %d"
@ -140,7 +176,7 @@ def rasterize_meshes(
face_verts, face_verts,
mesh_to_face_first_idx, mesh_to_face_first_idx,
num_faces_per_mesh, num_faces_per_mesh,
image_size, im_size,
blur_radius, blur_radius,
faces_per_pixel, faces_per_pixel,
bin_size, bin_size,
@ -181,7 +217,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
face_verts, face_verts,
mesh_to_face_first_idx, mesh_to_face_first_idx,
num_faces_per_mesh, num_faces_per_mesh,
image_size: int = 256, image_size: Tuple[int, int] = (256, 256),
blur_radius: float = 0.01, blur_radius: float = 0.01,
faces_per_pixel: int = 0, faces_per_pixel: int = 0,
bin_size: int = 0, bin_size: int = 0,
@ -254,9 +290,53 @@ def pix_to_ndc(i, S):
return -1 + (2 * i + 1.0) / S return -1 + (2 * i + 1.0) / S
def non_square_ndc_range(S1, S2):
"""
In the case of non square images, we scale the NDC range
to maintain the aspect ratio. The smaller dimension has NDC
range of 2.0.
Args:
S1: dimension along which the NDC range is needed
S2: the other image dimension
Returns:
ndc_range: NDC range for dimension S1
"""
ndc_range = 2.0
if S1 > S2:
ndc_range = (S1 / S2) * ndc_range
return ndc_range
def pix_to_non_square_ndc(i, S1, S2):
"""
The default value of the NDC range is [-1, 1].
However in the case of non square images, we scale the NDC range
to maintain the aspect ratio. The smaller dimension has NDC
range [-1, 1] and the other dimension is scaled by
the ratio of the longer side to the shorter side.
e.g. for image size (H, W) = (64, 128)
Height NDC range: [-1, 1]
Width NDC range: [-2, 2]
Args:
i: pixel position on axis S1
S1: dimension along which i is given
S2: the other image dimension
Returns:
pixel: NDC coordinate of point i for dimension S1
"""
# NDC: x-offset + (i * pixel_width + half_pixel_width)
ndc_range = non_square_ndc_range(S1, S2)
offset = ndc_range / 2.0
return -offset + (ndc_range * i + offset) / S1
def rasterize_meshes_python( def rasterize_meshes_python(
meshes, meshes,
image_size: int = 256, image_size: Union[int, Tuple[int, int]] = 256,
blur_radius: float = 0.0, blur_radius: float = 0.0,
faces_per_pixel: int = 8, faces_per_pixel: int = 8,
perspective_correct: bool = False, perspective_correct: bool = False,
@ -271,9 +351,8 @@ def rasterize_meshes_python(
C++/CUDA implementations. C++/CUDA implementations.
""" """
N = len(meshes) N = len(meshes)
# Assume only square images. H, W = image_size if isinstance(image_size, tuple) else (image_size, image_size)
# TODO(T52813608) extend support for non-square images.
H, W = image_size, image_size
K = faces_per_pixel K = faces_per_pixel
device = meshes.device device = meshes.device
@ -319,14 +398,14 @@ def rasterize_meshes_python(
# Y coordinate of one end of the image. Reverse the ordering # Y coordinate of one end of the image. Reverse the ordering
# of yi so that +Y is pointing up in the image. # of yi so that +Y is pointing up in the image.
yfix = H - 1 - yi yfix = H - 1 - yi
yf = pix_to_ndc(yfix, H) yf = pix_to_non_square_ndc(yfix, H, W)
# Iterate through pixels on this horizontal line, left to right. # Iterate through pixels on this horizontal line, left to right.
for xi in range(W): for xi in range(W):
# X coordinate of one end of the image. Reverse the ordering # X coordinate of one end of the image. Reverse the ordering
# of xi so that +X is pointing to the left in the image. # of xi so that +X is pointing to the left in the image.
xfix = W - 1 - xi xfix = W - 1 - xi
xf = pix_to_ndc(xfix, W) xf = pix_to_non_square_ndc(xfix, W, H)
top_k_points = [] top_k_points = []
# Check whether each face in the mesh affects this pixel. # Check whether each face in the mesh affects this pixel.
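
The default bin size heuristic in this file is driven by the larger image dimension, so the shorter side ends up with fewer bins. The following standalone Python sketch reproduces that arithmetic for a few sizes; `default_bin_size` and `bins_per_axis` are hypothetical helper names used only for illustration, and `math` is used in place of `numpy`.

```python
import math


def default_bin_size(max_image_size: int) -> int:
    """Heuristic from the diff: powers of two, never smaller than 16 above 64 px."""
    if max_image_size <= 64:
        return 8
    return int(2 ** max(math.ceil(math.log2(max_image_size)) - 4, 4))


def bins_per_axis(image_size, bin_size):
    """Number of bins along (H, W), using integer division rounded up."""
    H, W = image_size
    return 1 + (H - 1) // bin_size, 1 + (W - 1) // bin_size


for image_size in [(64, 64), (128, 64), (512, 256), (256, 1024)]:
    bin_size = default_bin_size(max(image_size))
    print(image_size, bin_size, bins_per_axis(image_size, bin_size))
# (64, 64) 8 (8, 8)
# (128, 64) 16 (8, 4)
# (512, 256) 32 (16, 8)
# (256, 1024) 64 (4, 16)
```

For strongly elongated images the shorter axis can be left with very few bins, which is the concern raised in the TODO comment in the hunk above.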

View File

@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import NamedTuple, Optional from typing import NamedTuple, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -31,7 +31,7 @@ class RasterizationSettings:
def __init__( def __init__(
self, self,
image_size: int = 256, image_size: Union[int, Tuple[int, int]] = 256,
blur_radius: float = 0.0, blur_radius: float = 0.0,
faces_per_pixel: int = 1, faces_per_pixel: int = 1,
bin_size: Optional[int] = None, bin_size: Optional[int] = None,
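
Since `image_size` may now be either an int or an (H, W) tuple, callers and helpers need a single normalized form. The sketch below shows one way to do that in plain Python; `normalize_image_size` is a hypothetical helper, not a PyTorch3D function, and its error messages are only modelled on the ones in this commit.

```python
from typing import Tuple, Union


def normalize_image_size(
    image_size: Union[int, Tuple[int, int]]
) -> Tuple[int, int]:
    """Return image_size as an (H, W) tuple of positive ints."""
    if isinstance(image_size, (tuple, list)):
        if len(image_size) != 2:
            raise ValueError("Image size can only be a tuple/list of (H, W)")
        h, w = image_size
    else:
        # A single int means a square image.
        h, w = image_size, image_size
    if not all(isinstance(s, int) for s in (h, w)):
        raise ValueError("Image sizes must be integers; got %r, %r" % (h, w))
    if not all(s > 0 for s in (h, w)):
        raise ValueError("Image sizes must be greater than 0; got %d, %d" % (h, w))
    return (h, w)


print(normalize_image_size(256))          # (256, 256)
print(normalize_image_size((512, 1024)))  # (512, 1024)
```

Checking the type before the sign is a design choice in this sketch: it means a non-integer size is always reported as a type problem, whatever its value.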

View File

@ -63,7 +63,8 @@ def bm_rasterize_meshes() -> None:
kwargs_list = [] kwargs_list = []
num_meshes = [8, 16] num_meshes = [8, 16]
ico_level = [4, 5, 6] ico_level = [4, 5, 6]
image_size = [64, 128, 512] # Square and non square cases
image_size = [64, 128, 512, (512, 256), (256, 512)]
blur = [1e-6] blur = [1e-6]
faces_per_pixel = [50] faces_per_pixel = [50]
test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel) test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)

Binary file not shown (size: 32 KiB).

View File

@ -304,7 +304,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
def test_compare_coarse_cpu_vs_cuda(self): def test_compare_coarse_cpu_vs_cuda(self):
torch.manual_seed(231) torch.manual_seed(231)
N = 1 N = 1
image_size = 512 image_size = (512, 512)
blur_radius = 0.0 blur_radius = 0.0
bin_size = 32 bin_size = 32
max_faces_per_bin = 20 max_faces_per_bin = 20
@ -1077,7 +1077,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
self.assertClose(dists, expected_dists) self.assertClose(dists, expected_dists)
def _test_coarse_rasterize(self, device): def _test_coarse_rasterize(self, device):
image_size = 16 image_size = (16, 16)
# No blurring. This test checks that the XY directions are # No blurring. This test checks that the XY directions are
# correctly oriented. # correctly oriented.
blur_radius = 0.0 blur_radius = 0.0

View File

@ -0,0 +1,439 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from itertools import product
from pathlib import Path
import numpy as np
import torch
from common_testing import TestCaseMixin, load_rgb_image
from PIL import Image
from pytorch3d.io import load_obj
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh import TexturesUV
from pytorch3d.renderer.mesh.rasterize_meshes import (
rasterize_meshes,
rasterize_meshes_python,
)
from pytorch3d.renderer.mesh.rasterizer import (
Fragments,
MeshRasterizer,
RasterizationSettings,
)
from pytorch3d.renderer.mesh.renderer import MeshRenderer
from pytorch3d.renderer.mesh.shader import BlendParams, SoftPhongShader
from pytorch3d.structures import Meshes
DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"
# Verts/Faces for a simple mesh with two faces.
verts0 = torch.tensor(
[
[-0.7, -0.70, 1.0],
[0.0, -0.1, 1.0],
[0.7, -0.7, 1.0],
[-0.7, 0.1, 1.0],
[0.0, 0.7, 1.0],
[0.7, 0.1, 1.0],
],
dtype=torch.float32,
)
faces0 = torch.tensor([[1, 0, 2], [4, 3, 5]], dtype=torch.int64)
class TestRasterizeRectanglesErrors(TestCaseMixin, unittest.TestCase):
def test_image_size_arg(self):
meshes = Meshes(verts=[verts0], faces=[faces0])
with self.assertRaises(ValueError) as cm:
rasterize_meshes(
meshes,
(100, 200, 3),
0.0001,
faces_per_pixel=1,
)
self.assertTrue("tuple/list of (H, W)" in str(cm.exception))
with self.assertRaises(ValueError) as cm:
rasterize_meshes(
meshes,
(0, 10),
0.0001,
faces_per_pixel=1,
)
self.assertTrue("sizes must be greater than 0" in str(cm.exception))
with self.assertRaises(ValueError) as cm:
rasterize_meshes(
meshes,
(100.5, 120.5),
0.0001,
faces_per_pixel=1,
)
self.assertTrue("sizes must be integers" in str(cm.exception))
class TestRasterizeRectangles(TestCaseMixin, unittest.TestCase):
@staticmethod
def _clone_mesh(verts0, faces0, device, batch_size):
"""
Helper function to detach and clone the verts/faces.
This is needed in order to set up the tensors for
gradient computation in different tests.
"""
verts = verts0.detach().clone()
verts.requires_grad = True
meshes = Meshes(verts=[verts], faces=[faces0])
meshes = meshes.to(device).extend(batch_size)
return verts, meshes
def _rasterize(self, meshes, image_size, bin_size, blur):
"""
Simple wrapper around the rasterize function to return
the fragment data.
"""
face_idxs, zbuf, bary_coords, pix_dists = rasterize_meshes(
meshes,
image_size,
blur,
faces_per_pixel=1,
bin_size=bin_size,
)
return Fragments(
pix_to_face=face_idxs,
zbuf=zbuf,
bary_coords=bary_coords,
dists=pix_dists,
)
@staticmethod
def _save_debug_image(fragments, image_size, bin_size, blur):
"""
Save a mask image from the rasterization output for debugging.
"""
H, W = image_size
# Save out the last image for debugging
rgb = (fragments.pix_to_face[-1, ..., :3].cpu() > -1).squeeze()
suffix = "square" if H == W else "non_square"
filename = "triangle_%s_bin_size_%s_blur_%.3f_%dx%d.png"
filename = filename % (suffix, str(bin_size), blur, H, W)
if DEBUG:
filename = "DEBUG_%s" % filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
def _check_fragments(self, frag_1, frag_2):
"""
Helper function to check that the tensors in
the Fragments frag_1 and frag_2 are the same.
"""
self.assertClose(frag_1.pix_to_face, frag_2.pix_to_face)
self.assertClose(frag_1.dists, frag_2.dists)
self.assertClose(frag_1.bary_coords, frag_2.bary_coords)
self.assertClose(frag_1.zbuf, frag_2.zbuf)
def _compare_square_with_nonsq(
self,
image_size,
blur,
device,
verts0,
faces0,
nonsq_fragment_gradtensor_list,
batch_size=1,
):
"""
Calculate the output from rasterizing a square image with the minimum of (H, W).
Then compare this with the same square region in the non square image.
The input mesh faces given by faces0 and verts0 are contained within the
[-1, 1] range of the image so all the relevant pixels will be within the square region.
`nonsq_fragment_gradtensor_list` is a list of fragments and verts grad tensors
from rasterizing non square images.
"""
# Rasterize the square version of the image
H, W = image_size
S = min(H, W)
verts_square, meshes_sq = self._clone_mesh(verts0, faces0, device, batch_size)
square_fragments = self._rasterize(
meshes_sq, image_size=(S, S), bin_size=0, blur=blur
)
# Save debug image
self._save_debug_image(square_fragments, (S, S), 0, blur)
# Extract the values in the square image which are non zero.
square_mask = square_fragments.pix_to_face > -1
square_dists = square_fragments.dists[square_mask]
square_zbuf = square_fragments.zbuf[square_mask]
square_bary = square_fragments.bary_coords[square_mask]
# Retain gradients on the output of fragments to check
# intermediate values with the non square outputs.
square_fragments.dists.retain_grad()
square_fragments.bary_coords.retain_grad()
square_fragments.zbuf.retain_grad()
# Calculate gradient for the square image
torch.manual_seed(231)
grad_zbuf = torch.randn_like(square_zbuf)
grad_dist = torch.randn_like(square_dists)
grad_bary = torch.randn_like(square_bary)
loss0 = (
(grad_dist * square_dists).sum()
+ (grad_zbuf * square_zbuf).sum()
+ (grad_bary * square_bary).sum()
)
loss0.backward()
# Now compare against the non square outputs provided
# in the nonsq_fragment_gradtensor_list list
for fragments, grad_tensor, _name in nonsq_fragment_gradtensor_list:
# Check that there are the same number of non zero pixels
# in both the square and non square images.
non_square_mask = fragments.pix_to_face > -1
self.assertEqual(non_square_mask.sum().item(), square_mask.sum().item())
# Check dists, zbuf and bary match the square image
non_square_dists = fragments.dists[non_square_mask]
non_square_zbuf = fragments.zbuf[non_square_mask]
non_square_bary = fragments.bary_coords[non_square_mask]
self.assertClose(square_dists, non_square_dists)
self.assertClose(square_zbuf, non_square_zbuf)
self.assertClose(
square_bary,
non_square_bary,
atol=2e-7,
)
# Retain gradients to compare values with outputs from
# square image
fragments.dists.retain_grad()
fragments.bary_coords.retain_grad()
fragments.zbuf.retain_grad()
loss1 = (
(grad_dist * non_square_dists).sum()
+ (grad_zbuf * non_square_zbuf).sum()
+ (grad_bary * non_square_bary).sum()
)
loss1.sum().backward()
# Get the non zero values in the intermediate gradients
# and compare with the values from the square image
non_square_grad_dists = fragments.dists.grad[non_square_mask]
non_square_grad_bary = fragments.bary_coords.grad[non_square_mask]
non_square_grad_zbuf = fragments.zbuf.grad[non_square_mask]
self.assertClose(
non_square_grad_dists,
square_fragments.dists.grad[square_mask],
)
self.assertClose(
non_square_grad_bary,
square_fragments.bary_coords.grad[square_mask],
)
self.assertClose(
non_square_grad_zbuf,
square_fragments.zbuf.grad[square_mask],
)
# Finally check the gradients of the input vertices for
# the square and non square case
self.assertClose(verts_square.grad, grad_tensor.grad, rtol=2e-4)
def test_gpu(self):
"""
Test that the output of rendering non square images
gives the same result as square images. i.e. the
dists, zbuf, bary are all the same for the square
region which is present in both images.
"""
# Test both cases: (W > H), (H > W)
image_sizes = [(64, 128), (128, 64), (128, 256), (256, 128)]
devices = ["cuda:0"]
blurs = [0.0, 0.001]
batch_sizes = [1, 4]
test_cases = product(image_sizes, blurs, devices, batch_sizes)
for image_size, blur, device, batch_size in test_cases:
# Initialize the verts grad tensor and the meshes objects
verts_nonsq_naive, meshes_nonsq_naive = self._clone_mesh(
verts0, faces0, device, batch_size
)
verts_nonsq_binned, meshes_nonsq_binned = self._clone_mesh(
verts0, faces0, device, batch_size
)
# Get the outputs for both naive and coarse to fine rasterization
fragments_naive = self._rasterize(
meshes_nonsq_naive,
image_size,
blur=blur,
bin_size=0,
)
fragments_binned = self._rasterize(
meshes_nonsq_binned,
image_size,
blur=blur,
bin_size=None,
)
# Save out debug images if needed
self._save_debug_image(fragments_naive, image_size, 0, blur)
self._save_debug_image(fragments_binned, image_size, None, blur)
# Check naive and binned fragments give the same outputs
self._check_fragments(fragments_naive, fragments_binned)
# Here we want to compare the square image with the naive and the
# coarse to fine methods outputs
nonsq_fragment_gradtensor_list = [
(fragments_naive, verts_nonsq_naive, "naive"),
(fragments_binned, verts_nonsq_binned, "coarse-to-fine"),
]
self._compare_square_with_nonsq(
image_size,
blur,
device,
verts0,
faces0,
nonsq_fragment_gradtensor_list,
batch_size,
)
def test_cpu(self):
"""
Test that the output of rendering non square images
gives the same result as square images. i.e. the
dists, zbuf, bary are all the same for the square
region which is present in both images.
In this test we compare between the naive C++ implementation
and the naive python implementation as the Coarse/Fine
method is not fully implemented in C++
"""
# Test both when (W > H) and (H > W).
# Using smaller image sizes here as the Python rasterizer is really slow.
image_sizes = [(32, 64), (64, 32)]
devices = ["cpu"]
blurs = [0.0, 0.001]
batch_sizes = [1]
test_cases = product(image_sizes, blurs, devices, batch_sizes)
for image_size, blur, device, batch_size in test_cases:
# Initialize the verts grad tensor and the meshes objects
verts_nonsq_naive, meshes_nonsq_naive = self._clone_mesh(
verts0, faces0, device, batch_size
)
verts_nonsq_python, meshes_nonsq_python = self._clone_mesh(
verts0, faces0, device, batch_size
)
# Compare Naive CPU with Python as Coarse/Fine rasterization
# is not implemented for CPU
fragments_naive = self._rasterize(
meshes_nonsq_naive, image_size, bin_size=0, blur=blur
)
face_idxs, zbuf, bary_coords, pix_dists = rasterize_meshes_python(
meshes_nonsq_python,
image_size,
blur,
faces_per_pixel=1,
)
fragments_python = Fragments(
pix_to_face=face_idxs,
zbuf=zbuf,
bary_coords=bary_coords,
dists=pix_dists,
)
# Save debug images if DEBUG is set to true at the top of the file.
self._save_debug_image(fragments_naive, image_size, 0, blur)
self._save_debug_image(fragments_python, image_size, "python", blur)
# List of non square outputs to compare with the square output
nonsq_fragment_gradtensor_list = [
(fragments_naive, verts_nonsq_naive, "naive"),
(fragments_python, verts_nonsq_python, "python"),
]
self._compare_square_with_nonsq(
image_size,
blur,
device,
verts0,
faces0,
nonsq_fragment_gradtensor_list,
batch_size,
)

    def test_render_cow(self):
        """
        Test a larger textured mesh is rendered correctly in a non square image.
        """
        device = torch.device("cuda:0")
        obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
        obj_filename = obj_dir / "cow_mesh/cow.obj"

        # Load mesh + texture
        verts, faces, aux = load_obj(
            obj_filename, device=device, load_textures=True, texture_wrap=None
        )
        tex_map = list(aux.texture_images.values())[0]
        tex_map = tex_map[None, ...].to(faces.textures_idx.device)
        textures = TexturesUV(
            maps=tex_map, faces_uvs=[faces.textures_idx], verts_uvs=[aux.verts_uvs]
        )
        mesh = Meshes(verts=[verts], faces=[faces.verts_idx], textures=textures)

        # Init rasterizer settings
        R, T = look_at_view_transform(2.7, 0, 180)
        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
        raster_settings = RasterizationSettings(
            image_size=(512, 1024), blur_radius=0.0, faces_per_pixel=1
        )

        # Init shader settings
        materials = Materials(device=device)
        lights = PointLights(device=device)
        lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
        blend_params = BlendParams(
            sigma=1e-1,
            gamma=1e-4,
            background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
        )

        # Init renderer
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
            shader=SoftPhongShader(
                lights=lights,
                cameras=cameras,
                materials=materials,
                blend_params=blend_params,
            ),
        )

        # Load reference image
        image_ref = load_rgb_image("test_cow_image_rectangle.png", DATA_DIR)

        for bin_size in [0, None]:
            # Check both naive and coarse to fine produce the same output.
            renderer.rasterizer.raster_settings.bin_size = bin_size
            images = renderer(mesh)
            rgb = images[0, ..., :3].squeeze().cpu()

            if DEBUG:
                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                    DATA_DIR / "DEBUG_cow_image_rectangle.png"
                )

            # NOTE some pixels can be flaky
            cond1 = torch.allclose(rgb, image_ref, atol=0.05)
            self.assertTrue(cond1)