diff --git a/pytorch3d/csrc/blending/sigmoid_alpha_blend.cu b/pytorch3d/csrc/blending/sigmoid_alpha_blend.cu new file mode 100644 index 00000000..7554e874 --- /dev/null +++ b/pytorch3d/csrc/blending/sigmoid_alpha_blend.cu @@ -0,0 +1,210 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#include +#include +#include +#include +#include + +template +__global__ void SigmoidAlphaBlendForwardKernel( + // clang-format off + const torch::PackedTensorAccessor distances, // (N, H, W, K) + const torch::PackedTensorAccessor pix_to_face, // (N, H, W, K) + torch::PackedTensorAccessor alphas, // (N, H, W) + // clang-format on + const scalar_t sigma, + const int N, + const int H, + const int W, + const int K) { + // Parallelize over each pixel in images of + // size H * W, for each image in the batch of size N. + const int num_threads = gridDim.x * blockDim.x; + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + + // TODO: revisit performance of this kernel with shared memory usage + + for (int t_i = tid; t_i < N * H * W; t_i += num_threads) { + // Convert linear index to 3D index + const int n = t_i / (H * W); // batch index. + const int pix_idx = t_i % (H * W); + + // TODO: fix index calculation for non square images. + const int yi = pix_idx / W; + const int xi = pix_idx % W; + scalar_t alpha = 1.0; + + // Loop over all the faces for this pixel. + for (int k = 0; k < K; k++) { + // Index into (N, H, W, K) tensors + const int f = pix_to_face[n][yi][xi][k]; + if (f < 0) { + // Sentinel value is -1 indicating no face overlaps the pixel. + continue; + } + // The distance is negative if a pixel is inside a face and positive + // outside the face. Therefore use -1.0 * the distance to get the + // correct sign. + scalar_t dist = -1.0 * distances[n][yi][xi][k]; + + // Calculate the sigmoid probability. + scalar_t prob = 1. / (1. + exp(-dist / sigma)); + + // The cumulative product ensures that alpha will be 0.0 if at least 1 + // face fully covers the pixel as for that face, prob will be 1.0. + // This results in a multiplication by 0.0 because of the (1.0 - prob) + // term. Therefore the final result of (1.0 - alpha) will be 1.0. + alpha *= (1.0 - prob); + } + alphas[n][yi][xi] = 1.0 - alpha; + } +} + +torch::Tensor SigmoidAlphaBlendForwardCuda( + const at::Tensor& distances, // (N, H, W, K) + const at::Tensor& pix_to_face, // (N, H, W, K) + const float sigma) { + const int N = distances.size(0); + const int H = distances.size(1); + const int W = distances.size(2); + const int K = distances.size(3); + + at::Tensor alphas = at::zeros({N, H, W}, distances.options()); + const size_t blocks = 1024; + const size_t threads = 128; + + // Check inputs are on the same device + at::TensorArg distances_t{distances, "distances", 1}, + pix_to_face_t{pix_to_face, "pix_to_face", 2}; + at::CheckedFrom c = "SigmoidAlphaBlendForwardCuda"; + at::checkAllSameGPU(c, {distances_t, pix_to_face_t}); + + // Set the device for the kernel launch based on the device of distances + at::cuda::CUDAGuard device_guard(distances.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (distances.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return alphas; + } + + AT_DISPATCH_FLOATING_TYPES( + distances.scalar_type(), "sigmoid_alpha_blend_kernel", ([&] { + // clang-format off + SigmoidAlphaBlendForwardKernel<<>>( + distances.packed_accessor(), + pix_to_face.packed_accessor(), + alphas.packed_accessor(), + sigma, + N, + H, + W, + K); + // clang-format on + })); + + AT_CUDA_CHECK(cudaGetLastError()); + return alphas; +} + +template +__global__ void SigmoidAlphaBlendBackwardKernel( + // clang-format off + const torch::PackedTensorAccessor grad_alphas, // (N, H, W) + const torch::PackedTensorAccessor alphas, // (N, H, W) + const torch::PackedTensorAccessor distances, // (N, H, W, K) + const torch::PackedTensorAccessor pix_to_face, // (N, H, W, K) + torch::PackedTensorAccessor grad_distances, // (N, H, W) + // clang-format on + const scalar_t sigma, + const int N, + const int H, + const int W, + const int K) { + // Parallelize over each of the top K faces for each pixel in images of + // size H * W * K, for each image in the batch of size N. + + // Get block and thread index. + const int n = blockIdx.x; + const int num_pixels = H * W * K; + const int num_threads = gridDim.y * blockDim.x; + const int tid = blockIdx.y * blockDim.x + threadIdx.x; + + for (int t_i = tid; t_i < num_pixels; t_i += num_threads) { + // Convert linear index to 3D index. + int yi = t_i / (W * K); + int xi = (t_i % (W * K)) / K; + int k = (t_i % (W * K)) % K; + + const scalar_t alpha = 1.0 - alphas[n][yi][xi]; + const scalar_t grad_alpha = grad_alphas[n][yi][xi]; + const int f = pix_to_face[n][yi][xi][k]; + + // Sentinel value is -1 indicating no face overlaps the pixel. + if (f >= 0) { + // The distance is negative if a pixel is inside a face and positive + // outside the face. Therefore use -1.0 * the distance to get the + // correct sign. + scalar_t dist = -1.0 * distances[n][yi][xi][k]; + + // Calculate the sigmoid probability. + scalar_t prob = 1. / (1. + exp(-dist / sigma)); + + grad_distances[n][yi][xi][k] = grad_alpha * (-1.0 / sigma) * prob * alpha; + } + } +} + +torch::Tensor SigmoidAlphaBlendBackwardCuda( + const at::Tensor& grad_alphas, // (N, H, W) + const at::Tensor& alphas, // (N, H, W) + const at::Tensor& distances, // (N, H, W, K) + const at::Tensor& pix_to_face, // (N, H, W, K) + float sigma) { + const int N = distances.size(0); + const int H = distances.size(1); + const int W = distances.size(2); + const int K = distances.size(3); + + at::Tensor grad_distances = at::zeros({N, H, W, K}, distances.options()); + + const dim3 threads(512); + const dim3 blocks(N, 1024 / N + 1); + + at::TensorArg grad_alphas_t{grad_alphas, "grad_alphas", 1}, + alphas_t{alphas, "alphas", 2}, distances_t{distances, "distances", 3}, + pix_to_face_t{pix_to_face, "pix_to_face", 4}; + at::CheckedFrom c = "SigmoidAlphaBlendBackwardCuda"; + at::checkAllSameGPU(c, {grad_alphas_t, alphas_t, distances_t, pix_to_face_t}); + + // Set the device for the kernel launch based on the device of distances + at::cuda::CUDAGuard device_guard(alphas.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (alphas.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return grad_alphas; + } + + AT_DISPATCH_FLOATING_TYPES( + distances.scalar_type(), "sigmoid_alpha_blend_backward_kernel", ([&] { + SigmoidAlphaBlendBackwardKernel + <<>>( + // clang-format off + grad_alphas.packed_accessor(), + alphas.packed_accessor(), + distances.packed_accessor(), + pix_to_face.packed_accessor(), + grad_distances.packed_accessor(), + // clang-format on + sigma, + N, + H, + W, + K); + })); + + AT_CUDA_CHECK(cudaGetLastError()); + return grad_distances; +} diff --git a/pytorch3d/csrc/blending/sigmoid_alpha_blend.h b/pytorch3d/csrc/blending/sigmoid_alpha_blend.h new file mode 100644 index 00000000..3565a086 --- /dev/null +++ b/pytorch3d/csrc/blending/sigmoid_alpha_blend.h @@ -0,0 +1,97 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#pragma once +#include +#include + +// clang-format off +// Function to blend the top K faces per pixel based on the 2d euclidean distance +// from the center of the pixel to the face. This method is adapted from [1]. +// The output can be used to set the alpha value in an RGBA image. +// Args: +// pix_to_face: LongTensor of shape (N, H, W, K), indices of faces overlapping +// with each pixel, where N is the batch size, H, W are the dimensions of the +// image and K is the number of faces rasterized per pixel. +// distances: FloatTensor of shape (N, H, W, K), 2d euclidean distance of each pixel +// relative to the faces in pix_to_face +// sigma: float, parameter which controls the width of the sigmoid for blending +// Returns: +// alphas: FloatTensor of shape (N, H, W), the blended values for each pixel +// in the image. +// +// [1] Shichen Liu et al, 'Soft Rasterizer: A Differentiable Renderer for +// Image-based 3D Reasoning' +// clang-format on +at::Tensor SigmoidAlphaBlendForwardCpu( + const at::Tensor& distances, + const at::Tensor& pix_to_face, + const float sigma); + +#ifdef WITH_CUDA +at::Tensor SigmoidAlphaBlendForwardCuda( + const at::Tensor& distances, + const at::Tensor& pix_to_face, + const float sigma); +#endif + +// clang-format off +// Args: +// grad_alphas: FloatTensor of shape (N, H, W), upstream gradients for alphas +// alphas: FloatTensor of shape (N, H, W), the alpha values from the forward pass +// pix_to_face: LongTensor of shape (N, H, W, K), indices of faces overlapping +// with each pixel, where N is the batch size, H, W are the dimensions of the +// image, and K is the number of faces rasterized per pixel +// distances: FloatTensor of shape (N, H, W, K), 2d euclidean distance of each pixel +// to the corresponding faces in pix_to_face +// sigma: float, parameter which controls the width of the sigmoid for blending +// Returns: +// grad_distances: FloatTensor of shape (N, H, W, K) +// clang-format on +at::Tensor SigmoidAlphaBlendBackwardCpu( + const at::Tensor& grad_alphas, + const at::Tensor& alphas, + const at::Tensor& distances, + const at::Tensor& pix_to_face, + const float sigma); + +#ifdef WITH_CUDA +at::Tensor SigmoidAlphaBlendBackwardCuda( + const at::Tensor& grad_alphas, + const at::Tensor& alphas, + const at::Tensor& distances, + const at::Tensor& pix_to_face, + const float sigma); +#endif + +// Implementation which is exposed. +at::Tensor +SigmoidAlphaBlend(at::Tensor& distances, at::Tensor& pix_to_face, float sigma) { + if (distances.is_cuda() && pix_to_face.is_cuda()) { +#ifdef WITH_CUDA + return SigmoidAlphaBlendForwardCuda(distances, pix_to_face, sigma); +#else + AT_ERROR("Not compiled with GPU support."); +#endif + } + return SigmoidAlphaBlendForwardCpu(distances, pix_to_face, sigma); +} + +// Implementation which is exposed. +at::Tensor SigmoidAlphaBlendBackward( + const at::Tensor& grad_alphas, + const at::Tensor& alphas, + const at::Tensor& distances, + const at::Tensor& pix_to_face, + const float sigma) { + if (distances.is_cuda() && pix_to_face.is_cuda() && alphas.is_cuda() && + grad_alphas.is_cuda()) { +#ifdef WITH_CUDA + return SigmoidAlphaBlendBackwardCuda( + grad_alphas, alphas, distances, pix_to_face, sigma); +#else + AT_ERROR("Not compiled with GPU support."); +#endif + } + return SigmoidAlphaBlendBackwardCpu( + grad_alphas, alphas, distances, pix_to_face, sigma); +} diff --git a/pytorch3d/csrc/blending/sigmoid_alpha_blend_cpu.cpp b/pytorch3d/csrc/blending/sigmoid_alpha_blend_cpu.cpp new file mode 100644 index 00000000..85d19bf8 --- /dev/null +++ b/pytorch3d/csrc/blending/sigmoid_alpha_blend_cpu.cpp @@ -0,0 +1,123 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#include +#include +#include + +at::Tensor SigmoidAlphaBlendForwardCpu( + const at::Tensor& distances, // (N, H, W, K) + const at::Tensor& pix_to_face, // (N, H, W, K) + const float sigma) { + const int N = distances.size(0); + const int H = distances.size(1); + const int W = distances.size(2); + const int K = distances.size(3); + + torch::Tensor out = torch::empty({N, H, W}, distances.options()); + + auto distances_a = distances.accessor(); + auto pix_to_face_a = pix_to_face.accessor(); + auto out_a = out.accessor(); + + // Iterate over the images in the batch. + for (int n = 0; n < N; ++n) { + // Iterate through the horizontal lines of the image from top to bottom. + for (int h = 0; h < H; ++h) { + // Iterate over the pixels on this horizontal line, left to right. + for (int w = 0; w < W; ++w) { + float alpha = 1.0; + + // Loop through the top K faces for each pixel. + for (int k = 0; k < K; ++k) { + const int f = pix_to_face_a[n][h][w][k]; + if (f < 0) { + // Sentinel value is -1 indicating no face overlaps the pixel. + continue; + } + // The distance is negative if a pixel is inside a face and positive + // outside the face. Therefore use -1.0 * the distance to get the + // correct sign. + float dist = -1.0 * distances_a[n][h][w][k]; + + // Calculate the sigmoid probability. + float prob = 1. / (1. + exp(-dist / sigma)); + + // The product ensures that alpha will be 0.0 if at least 1 + // face fully covers the pixel as for that face, prob will be 1.0. + // This results in a multiplication by 0.0 because of the (1.0 - prob) + // term. Therefore 1.0 - alpha will be 1.0. + alpha *= 1.0 - prob; + } + out_a[n][h][w] = 1.0 - alpha; + } + } + } + return out; +} + +at::Tensor SigmoidAlphaBlendBackwardCpu( + const at::Tensor& grad_alphas, // (N, H, W) + const at::Tensor& alphas, // (N, H, W) + const at::Tensor& distances, // (N, H, W, K) + const at::Tensor& pix_to_face, // (N, H, W, K) + const float sigma) { + const int N = distances.size(0); + const int H = distances.size(1); + const int W = distances.size(2); + const int K = distances.size(3); + + auto distances_a = distances.accessor(); + auto pix_to_face_a = pix_to_face.accessor(); + auto alphas_a = alphas.accessor(); + auto grad_alphas_a = grad_alphas.accessor(); + + torch::Tensor grad_distances = + torch::zeros({N, H, W, K}, distances.options()); + auto grad_distances_a = grad_distances.accessor(); + + // Iterate over the images in the batch. + for (int n = 0; n < N; ++n) { + // Iterate through the horizontal lines of the image from top to bottom. + for (int h = 0; h < H; ++h) { + // Iterate over the pixels on this horizontal line, left to right. + for (int w = 0; w < W; ++w) { + // Get the alpha value from the forward pass and the + // upstream gradient. + const float alpha = 1.0 - alphas_a[n][h][w]; + const float grad_alpha = grad_alphas_a[n][h][w]; + + // Loop through the top K faces for each pixel. + for (int k = 0; k < K; ++k) { + const int f = pix_to_face_a[n][h][w][k]; + if (f < 0) { + // Sentinel value is -1 indicating no face overlaps the pixel + continue; + } + // The distance is negative if a pixel is inside a face and positive + // outside the face. Therefore use -1.0 * distance to get the + // correct sign. + float dist = -1.0 * distances_a[n][h][w][k]; + + // Calculate the sigmoid probability. + float prob = 1. / (1. + exp(-dist / sigma)); + + // clang-format off + // We need to take the derivative of alpha w.r.t to the distance. + // alpha = 1.0 - (1.0- sigmoid(-x)) * (1.0 - sigmoid(-x2)) * ... * (1.0 - sigmoid(-xn)) + // + // Note that d/dx sigmoid(x) = sigmoid(x) * (1.0 - sigmoid(x)) + // + // This gives: + // d_alpha/d_dist = -1.0 * -1.0 * sigmoid(-x)(1. - sigmoid(-x)) * (-1.0/sigma) + // * ((1.0 - sigmoid(-x2) * ... * (1.0 - sigmoid(-xn)) + // = (-1.0/sigma) * prob * (1.0 - prob) * alpha/(1.0 - prob) + // = (-1.0/sigma) * prob * alpha + // clang-format on + grad_distances_a[n][h][w][k] = + grad_alpha * (-1.0 / sigma) * prob * alpha; + } + } + } + } + return grad_distances; +} diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 55da2ee7..c80800e3 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -1,6 +1,7 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. #include +#include "blending/sigmoid_alpha_blend.h" #include "compositing/alpha_composite.h" #include "compositing/norm_weighted_sum.h" #include "compositing/weighted_sum.h" @@ -31,6 +32,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rasterize_points_backward", &RasterizePointsBackward); m.def("rasterize_meshes_backward", &RasterizeMeshesBackward); m.def("rasterize_meshes", &RasterizeMeshes); + m.def("sigmoid_alpha_blend", &SigmoidAlphaBlend); + m.def("sigmoid_alpha_blend_backward", &SigmoidAlphaBlendBackward); // Accumulation functions m.def("accum_weightedsumnorm", &weightedSumNormForward); diff --git a/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu b/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu index 6e286add..54f15d33 100644 --- a/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu +++ b/pytorch3d/csrc/face_areas_normals/face_areas_normals.cu @@ -216,10 +216,9 @@ std::tuple FaceAreasNormalsForwardCuda( const auto F = faces.size(0); // Check inputs are on the same device - at::TensorArg verts_t{verts, "verts", 1}, faces_t{verts, "faces", 2}; + at::TensorArg verts_t{verts, "verts", 1}, faces_t{faces, "faces", 2}; at::CheckedFrom c = "FaceAreasNormalsForwardCuda"; at::checkAllSameGPU(c, {verts_t, faces_t}); - at::checkAllSameType(c, {verts_t, faces_t}); // Set the device for the kernel launch based on the device of verts at::cuda::CUDAGuard device_guard(verts.device()); @@ -256,12 +255,11 @@ at::Tensor FaceAreasNormalsBackwardCuda( const at::Tensor verts, const at::Tensor faces) { // Check inputs are on the same device - at::TensorArg verts_t{verts, "verts", 1}, faces_t{verts, "faces", 2}, - grad_areas_t{verts, "grad_areas", 3}, - grad_normals_t{verts, "grad_normals", 4}; + at::TensorArg verts_t{verts, "verts", 1}, faces_t{faces, "faces", 2}, + grad_areas_t{grad_areas, "grad_areas", 3}, + grad_normals_t{grad_normals, "grad_normals", 4}; at::CheckedFrom c = "FaceAreasNormalsBackwardCuda"; at::checkAllSameGPU(c, {verts_t, faces_t, grad_areas_t, grad_normals_t}); - at::checkAllSameType(c, {verts_t, faces_t, grad_areas_t, grad_normals_t}); // Set the device for the kernel launch based on the device of verts at::cuda::CUDAGuard device_guard(verts.device()); diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py index 29cb6f12..1bfa589b 100644 --- a/pytorch3d/renderer/blending.py +++ b/pytorch3d/renderer/blending.py @@ -5,6 +5,7 @@ from typing import NamedTuple, Sequence import numpy as np import torch +from pytorch3d import _C # Example functions for blending the top K colors per pixel using the outputs @@ -59,6 +60,29 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor: return torch.cat([pixel_colors, alpha], dim=-1) # (N, H, W, 4) +# Wrapper for the C++/CUDA Implementation of sigmoid alpha blend. +class _SigmoidAlphaBlend(torch.autograd.Function): + @staticmethod + def forward(ctx, dists, pix_to_face, sigma): + alphas = _C.sigmoid_alpha_blend(dists, pix_to_face, sigma) + ctx.save_for_backward(dists, pix_to_face, alphas) + ctx.sigma = sigma + return alphas + + @staticmethod + def backward(ctx, grad_alphas): + dists, pix_to_face, alphas = ctx.saved_tensors + sigma = ctx.sigma + grad_dists = _C.sigmoid_alpha_blend_backward( + grad_alphas, alphas, dists, pix_to_face, sigma + ) + return grad_dists, None, None + + +# pyre-fixme[16]: `_SigmoidAlphaBlend` has no attribute `apply`. +_sigmoid_alpha = _SigmoidAlphaBlend.apply + + def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor: """ Silhouette blending to return an RGBA image @@ -83,19 +107,9 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor: """ N, H, W, K = fragments.pix_to_face.shape pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device) - mask = fragments.pix_to_face >= 0 - - # The distance is negative if a pixel is inside a face and positive outside - # the face. Therefore use -1.0 * fragments.dists to get the correct sign. - prob = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask - - # The cumulative product ensures that alpha will be 0.0 if at least 1 - # face fully covers the pixel as for that face, prob will be 1.0. - # This results in a multiplication by 0.0 because of the (1.0 - prob) - # term. Therefore 1.0 - alpha will be 1.0. - alpha = torch.prod((1.0 - prob), dim=-1) - pixel_colors[..., :3] = colors[..., 0, :] # Hard assign for RGB - pixel_colors[..., 3] = 1.0 - alpha + pixel_colors[..., :3] = colors[..., 0, :] + alpha = _sigmoid_alpha(fragments.dists, fragments.pix_to_face, blend_params.sigma) + pixel_colors[..., 3] = alpha return pixel_colors diff --git a/tests/bm_blending.py b/tests/bm_blending.py index cdde1975..16aa11bc 100644 --- a/tests/bm_blending.py +++ b/tests/bm_blending.py @@ -8,17 +8,24 @@ from test_blending import TestBlending def bm_blending() -> None: - devices = ["cpu", "cuda"] + devices = ["cuda"] kwargs_list = [] - num_meshes = [16] - image_size = [128, 256] + num_meshes = [8] + image_size = [64, 128, 256] faces_per_pixel = [50, 100] - test_cases = product(num_meshes, image_size, faces_per_pixel, devices) + backend = ["pytorch", "custom"] + test_cases = product(num_meshes, image_size, faces_per_pixel, devices, backend) for case in test_cases: - n, s, k, d = case + n, s, k, d, b = case kwargs_list.append( - {"num_meshes": n, "image_size": s, "faces_per_pixel": k, "device": d} + { + "num_meshes": n, + "image_size": s, + "faces_per_pixel": k, + "device": d, + "backend": b, + } ) benchmark( @@ -28,6 +35,7 @@ def bm_blending() -> None: warmup_iters=1, ) + kwargs_list = [case for case in kwargs_list if case["backend"] == "pytorch"] benchmark( TestBlending.bm_softmax_blending, "SOFTMAX_BLENDING_PYTORCH", diff --git a/tests/test_blending.py b/tests/test_blending.py index 117ab11b..cb9ba41a 100644 --- a/tests/test_blending.py +++ b/tests/test_blending.py @@ -44,6 +44,16 @@ def sigmoid_blend_naive_loop(colors, fragments, blend_params): return pixel_colors +def sigmoid_alpha_blend_vectorized(colors, fragments, blend_params) -> torch.Tensor: + N, H, W, K = fragments.pix_to_face.shape + pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device) + mask = fragments.pix_to_face >= 0 + prob = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask + pixel_colors[..., :3] = colors[..., 0, :] + pixel_colors[..., 3] = 1.0 - torch.prod((1.0 - prob), dim=-1) + return pixel_colors + + def sigmoid_blend_naive_loop_backward(grad_images, images, fragments, blend_params): pix_to_face = fragments.pix_to_face dists = fragments.dists @@ -136,10 +146,9 @@ class TestBlending(TestCaseMixin, unittest.TestCase): def _compare_impls( self, fn1, fn2, args1, args2, grad_var1=None, grad_var2=None, compare_grads=True ): - out1 = fn1(*args1) out2 = fn2(*args2) - self.assertTrue(torch.allclose(out1.cpu(), out2.cpu(), atol=1e-7)) + self.assertClose(out1.cpu()[..., 3], out2.cpu()[..., 3], atol=1e-7) # Check gradients if not compare_grads: @@ -151,9 +160,7 @@ class TestBlending(TestCaseMixin, unittest.TestCase): (out2 * grad_out).sum().backward() self.assertTrue(hasattr(grad_var2, "grad")) - self.assertTrue( - torch.allclose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5) - ) + self.assertClose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5) def test_hard_rgb_blend(self): N, H, W, K = 5, 10, 10, 20 @@ -223,18 +230,15 @@ class TestBlending(TestCaseMixin, unittest.TestCase): torch.manual_seed(231) F = 32 # number of faces in the mesh # The python loop version is really slow so only using small input sizes. - N, S, K = 2, 10, 5 + N, S, K = 1, 4, 1 device = torch.device("cuda") - pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1 + pix_to_face = torch.randint(low=-1, high=F, size=(N, S, S, K), device=device) colors = torch.randn((N, S, S, K, 3), device=device) empty = torch.tensor([], device=device) - # # randomly flip the sign of the distance - # # (-) means inside triangle, (+) means outside triangle. - random_sign_flip = torch.rand((N, S, S, K)) - random_sign_flip[random_sign_flip > 0.5] *= -1.0 - dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device) - dists2 = dists1.detach().clone() + dists1 = torch.randn(size=(N, S, S, K), device=device) + dists2 = dists1.clone() + dists1.requires_grad = True dists2.requires_grad = True fragments1 = Fragments( @@ -256,7 +260,7 @@ class TestBlending(TestCaseMixin, unittest.TestCase): self._compare_impls( sigmoid_alpha_blend, - sigmoid_blend_naive_loop, + sigmoid_alpha_blend_vectorized, args1, args2, dists1, @@ -324,26 +328,21 @@ class TestBlending(TestCaseMixin, unittest.TestCase): num_meshes: int = 16, image_size: int = 128, faces_per_pixel: int = 100, - device: str = "cpu", + device="cuda", + backend: str = "pytorch", ): - if torch.cuda.is_available() and "cuda:" in device: - # If a device other than the default is used, set the device explicity. - torch.cuda.set_device(device) - device = torch.device(device) torch.manual_seed(231) # Create dummy outputs of rasterization N, S, K = num_meshes, image_size, faces_per_pixel F = 32 # num faces in the mesh - pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1 + pix_to_face = torch.randint( + low=-1, high=F + 1, size=(N, S, S, K), device=device + ) colors = torch.randn((N, S, S, K, 3), device=device) empty = torch.tensor([], device=device) - # # randomly flip the sign of the distance - # # (-) means inside triangle, (+) means outside triangle. - random_sign_flip = torch.rand((N, S, S, K), device=device) - random_sign_flip[random_sign_flip > 0.5] *= -1.0 dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device) fragments = Fragments( pix_to_face=pix_to_face, @@ -352,11 +351,18 @@ class TestBlending(TestCaseMixin, unittest.TestCase): dists=dists1, ) blend_params = BlendParams(sigma=1e-3) + + blend_fn = ( + sigmoid_alpha_blend_vectorized + if backend == "pytorch" + else sigmoid_alpha_blend + ) + torch.cuda.synchronize() def fn(): # test forward and backward pass - images = sigmoid_alpha_blend(colors, fragments, blend_params) + images = blend_fn(colors, fragments, blend_params) images.sum().backward() torch.cuda.synchronize() @@ -368,6 +374,7 @@ class TestBlending(TestCaseMixin, unittest.TestCase): image_size: int = 128, faces_per_pixel: int = 100, device: str = "cpu", + backend: str = "pytorch", ): if torch.cuda.is_available() and "cuda:" in device: # If a device other than the default is used, set the device explicity. @@ -379,14 +386,12 @@ class TestBlending(TestCaseMixin, unittest.TestCase): # Create dummy outputs of rasterization N, S, K = num_meshes, image_size, faces_per_pixel F = 32 # num faces in the mesh - pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1 + pix_to_face = torch.randint( + low=-1, high=F + 1, size=(N, S, S, K), device=device + ) colors = torch.randn((N, S, S, K, 3), device=device) empty = torch.tensor([], device=device) - # # randomly flip the sign of the distance - # # (-) means inside triangle, (+) means outside triangle. - random_sign_flip = torch.rand((N, S, S, K), device=device) - random_sign_flip[random_sign_flip > 0.5] *= -1.0 dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device) zbuf = torch.randn(size=(N, S, S, K), requires_grad=True, device=device) fragments = Fragments(