mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
C++/CUDA implementation of sigmoid alpha blend
Summary: C++/CUDA implementation of forward and backward passes for the sigmoid alpha blending function. This is slightly faster than the vectorized implementation in Python, but more importantly uses less memory due to fewer tensors being created. Reviewed By: gkioxari Differential Revision: D19980671 fbshipit-source-id: 0779055d2c68b1f20fb0870e60046077ef4613ff
This commit is contained in:
parent
dc08c30583
commit
bce396df93
210
pytorch3d/csrc/blending/sigmoid_alpha_blend.cu
Normal file
210
pytorch3d/csrc/blending/sigmoid_alpha_blend.cu
Normal file
@ -0,0 +1,210 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void SigmoidAlphaBlendForwardKernel(
|
||||
// clang-format off
|
||||
const torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> distances, // (N, H, W, K)
|
||||
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> pix_to_face, // (N, H, W, K)
|
||||
torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> alphas, // (N, H, W)
|
||||
// clang-format on
|
||||
const scalar_t sigma,
|
||||
const int N,
|
||||
const int H,
|
||||
const int W,
|
||||
const int K) {
|
||||
// Parallelize over each pixel in images of
|
||||
// size H * W, for each image in the batch of size N.
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
// TODO: revisit performance of this kernel with shared memory usage
|
||||
|
||||
for (int t_i = tid; t_i < N * H * W; t_i += num_threads) {
|
||||
// Convert linear index to 3D index
|
||||
const int n = t_i / (H * W); // batch index.
|
||||
const int pix_idx = t_i % (H * W);
|
||||
|
||||
// TODO: fix index calculation for non square images.
|
||||
const int yi = pix_idx / W;
|
||||
const int xi = pix_idx % W;
|
||||
scalar_t alpha = 1.0;
|
||||
|
||||
// Loop over all the faces for this pixel.
|
||||
for (int k = 0; k < K; k++) {
|
||||
// Index into (N, H, W, K) tensors
|
||||
const int f = pix_to_face[n][yi][xi][k];
|
||||
if (f < 0) {
|
||||
// Sentinel value is -1 indicating no face overlaps the pixel.
|
||||
continue;
|
||||
}
|
||||
// The distance is negative if a pixel is inside a face and positive
|
||||
// outside the face. Therefore use -1.0 * the distance to get the
|
||||
// correct sign.
|
||||
scalar_t dist = -1.0 * distances[n][yi][xi][k];
|
||||
|
||||
// Calculate the sigmoid probability.
|
||||
scalar_t prob = 1. / (1. + exp(-dist / sigma));
|
||||
|
||||
// The cumulative product ensures that alpha will be 0.0 if at least 1
|
||||
// face fully covers the pixel as for that face, prob will be 1.0.
|
||||
// This results in a multiplication by 0.0 because of the (1.0 - prob)
|
||||
// term. Therefore the final result of (1.0 - alpha) will be 1.0.
|
||||
alpha *= (1.0 - prob);
|
||||
}
|
||||
alphas[n][yi][xi] = 1.0 - alpha;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor SigmoidAlphaBlendForwardCuda(
|
||||
const at::Tensor& distances, // (N, H, W, K)
|
||||
const at::Tensor& pix_to_face, // (N, H, W, K)
|
||||
const float sigma) {
|
||||
const int N = distances.size(0);
|
||||
const int H = distances.size(1);
|
||||
const int W = distances.size(2);
|
||||
const int K = distances.size(3);
|
||||
|
||||
at::Tensor alphas = at::zeros({N, H, W}, distances.options());
|
||||
const size_t blocks = 1024;
|
||||
const size_t threads = 128;
|
||||
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg distances_t{distances, "distances", 1},
|
||||
pix_to_face_t{pix_to_face, "pix_to_face", 2};
|
||||
at::CheckedFrom c = "SigmoidAlphaBlendForwardCuda";
|
||||
at::checkAllSameGPU(c, {distances_t, pix_to_face_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of distances
|
||||
at::cuda::CUDAGuard device_guard(distances.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (distances.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return alphas;
|
||||
}
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
distances.scalar_type(), "sigmoid_alpha_blend_kernel", ([&] {
|
||||
// clang-format off
|
||||
SigmoidAlphaBlendForwardKernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
||||
distances.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
|
||||
pix_to_face.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>(),
|
||||
alphas.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
|
||||
sigma,
|
||||
N,
|
||||
H,
|
||||
W,
|
||||
K);
|
||||
// clang-format on
|
||||
}));
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return alphas;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void SigmoidAlphaBlendBackwardKernel(
|
||||
// clang-format off
|
||||
const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> grad_alphas, // (N, H, W)
|
||||
const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> alphas, // (N, H, W)
|
||||
const torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> distances, // (N, H, W, K)
|
||||
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> pix_to_face, // (N, H, W, K)
|
||||
torch::PackedTensorAccessor<scalar_t, 4, torch::RestrictPtrTraits, size_t> grad_distances, // (N, H, W)
|
||||
// clang-format on
|
||||
const scalar_t sigma,
|
||||
const int N,
|
||||
const int H,
|
||||
const int W,
|
||||
const int K) {
|
||||
// Parallelize over each of the top K faces for each pixel in images of
|
||||
// size H * W * K, for each image in the batch of size N.
|
||||
|
||||
// Get block and thread index.
|
||||
const int n = blockIdx.x;
|
||||
const int num_pixels = H * W * K;
|
||||
const int num_threads = gridDim.y * blockDim.x;
|
||||
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int t_i = tid; t_i < num_pixels; t_i += num_threads) {
|
||||
// Convert linear index to 3D index.
|
||||
int yi = t_i / (W * K);
|
||||
int xi = (t_i % (W * K)) / K;
|
||||
int k = (t_i % (W * K)) % K;
|
||||
|
||||
const scalar_t alpha = 1.0 - alphas[n][yi][xi];
|
||||
const scalar_t grad_alpha = grad_alphas[n][yi][xi];
|
||||
const int f = pix_to_face[n][yi][xi][k];
|
||||
|
||||
// Sentinel value is -1 indicating no face overlaps the pixel.
|
||||
if (f >= 0) {
|
||||
// The distance is negative if a pixel is inside a face and positive
|
||||
// outside the face. Therefore use -1.0 * the distance to get the
|
||||
// correct sign.
|
||||
scalar_t dist = -1.0 * distances[n][yi][xi][k];
|
||||
|
||||
// Calculate the sigmoid probability.
|
||||
scalar_t prob = 1. / (1. + exp(-dist / sigma));
|
||||
|
||||
grad_distances[n][yi][xi][k] = grad_alpha * (-1.0 / sigma) * prob * alpha;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor SigmoidAlphaBlendBackwardCuda(
|
||||
const at::Tensor& grad_alphas, // (N, H, W)
|
||||
const at::Tensor& alphas, // (N, H, W)
|
||||
const at::Tensor& distances, // (N, H, W, K)
|
||||
const at::Tensor& pix_to_face, // (N, H, W, K)
|
||||
float sigma) {
|
||||
const int N = distances.size(0);
|
||||
const int H = distances.size(1);
|
||||
const int W = distances.size(2);
|
||||
const int K = distances.size(3);
|
||||
|
||||
at::Tensor grad_distances = at::zeros({N, H, W, K}, distances.options());
|
||||
|
||||
const dim3 threads(512);
|
||||
const dim3 blocks(N, 1024 / N + 1);
|
||||
|
||||
at::TensorArg grad_alphas_t{grad_alphas, "grad_alphas", 1},
|
||||
alphas_t{alphas, "alphas", 2}, distances_t{distances, "distances", 3},
|
||||
pix_to_face_t{pix_to_face, "pix_to_face", 4};
|
||||
at::CheckedFrom c = "SigmoidAlphaBlendBackwardCuda";
|
||||
at::checkAllSameGPU(c, {grad_alphas_t, alphas_t, distances_t, pix_to_face_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of distances
|
||||
at::cuda::CUDAGuard device_guard(alphas.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if (alphas.numel() == 0) {
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return grad_alphas;
|
||||
}
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
distances.scalar_type(), "sigmoid_alpha_blend_backward_kernel", ([&] {
|
||||
SigmoidAlphaBlendBackwardKernel<scalar_t>
|
||||
<<<blocks, threads, 0, stream>>>(
|
||||
// clang-format off
|
||||
grad_alphas.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
|
||||
alphas.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
|
||||
distances.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
|
||||
pix_to_face.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>(),
|
||||
grad_distances.packed_accessor<scalar_t, 4, torch::RestrictPtrTraits, size_t>(),
|
||||
// clang-format on
|
||||
sigma,
|
||||
N,
|
||||
H,
|
||||
W,
|
||||
K);
|
||||
}));
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return grad_distances;
|
||||
}
|
97
pytorch3d/csrc/blending/sigmoid_alpha_blend.h
Normal file
97
pytorch3d/csrc/blending/sigmoid_alpha_blend.h
Normal file
@ -0,0 +1,97 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
#include <tuple>
|
||||
|
||||
// clang-format off
|
||||
// Function to blend the top K faces per pixel based on the 2d euclidean distance
|
||||
// from the center of the pixel to the face. This method is adapted from [1].
|
||||
// The output can be used to set the alpha value in an RGBA image.
|
||||
// Args:
|
||||
// pix_to_face: LongTensor of shape (N, H, W, K), indices of faces overlapping
|
||||
// with each pixel, where N is the batch size, H, W are the dimensions of the
|
||||
// image and K is the number of faces rasterized per pixel.
|
||||
// distances: FloatTensor of shape (N, H, W, K), 2d euclidean distance of each pixel
|
||||
// relative to the faces in pix_to_face
|
||||
// sigma: float, parameter which controls the width of the sigmoid for blending
|
||||
// Returns:
|
||||
// alphas: FloatTensor of shape (N, H, W), the blended values for each pixel
|
||||
// in the image.
|
||||
//
|
||||
// [1] Shichen Liu et al, 'Soft Rasterizer: A Differentiable Renderer for
|
||||
// Image-based 3D Reasoning'
|
||||
// clang-format on
|
||||
at::Tensor SigmoidAlphaBlendForwardCpu(
|
||||
const at::Tensor& distances,
|
||||
const at::Tensor& pix_to_face,
|
||||
const float sigma);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
at::Tensor SigmoidAlphaBlendForwardCuda(
|
||||
const at::Tensor& distances,
|
||||
const at::Tensor& pix_to_face,
|
||||
const float sigma);
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
// Args:
|
||||
// grad_alphas: FloatTensor of shape (N, H, W), upstream gradients for alphas
|
||||
// alphas: FloatTensor of shape (N, H, W), the alpha values from the forward pass
|
||||
// pix_to_face: LongTensor of shape (N, H, W, K), indices of faces overlapping
|
||||
// with each pixel, where N is the batch size, H, W are the dimensions of the
|
||||
// image, and K is the number of faces rasterized per pixel
|
||||
// distances: FloatTensor of shape (N, H, W, K), 2d euclidean distance of each pixel
|
||||
// to the corresponding faces in pix_to_face
|
||||
// sigma: float, parameter which controls the width of the sigmoid for blending
|
||||
// Returns:
|
||||
// grad_distances: FloatTensor of shape (N, H, W, K)
|
||||
// clang-format on
|
||||
at::Tensor SigmoidAlphaBlendBackwardCpu(
|
||||
const at::Tensor& grad_alphas,
|
||||
const at::Tensor& alphas,
|
||||
const at::Tensor& distances,
|
||||
const at::Tensor& pix_to_face,
|
||||
const float sigma);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
at::Tensor SigmoidAlphaBlendBackwardCuda(
|
||||
const at::Tensor& grad_alphas,
|
||||
const at::Tensor& alphas,
|
||||
const at::Tensor& distances,
|
||||
const at::Tensor& pix_to_face,
|
||||
const float sigma);
|
||||
#endif
|
||||
|
||||
// Implementation which is exposed.
|
||||
at::Tensor
|
||||
SigmoidAlphaBlend(at::Tensor& distances, at::Tensor& pix_to_face, float sigma) {
|
||||
if (distances.is_cuda() && pix_to_face.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return SigmoidAlphaBlendForwardCuda(distances, pix_to_face, sigma);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return SigmoidAlphaBlendForwardCpu(distances, pix_to_face, sigma);
|
||||
}
|
||||
|
||||
// Implementation which is exposed.
|
||||
at::Tensor SigmoidAlphaBlendBackward(
|
||||
const at::Tensor& grad_alphas,
|
||||
const at::Tensor& alphas,
|
||||
const at::Tensor& distances,
|
||||
const at::Tensor& pix_to_face,
|
||||
const float sigma) {
|
||||
if (distances.is_cuda() && pix_to_face.is_cuda() && alphas.is_cuda() &&
|
||||
grad_alphas.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return SigmoidAlphaBlendBackwardCuda(
|
||||
grad_alphas, alphas, distances, pix_to_face, sigma);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return SigmoidAlphaBlendBackwardCpu(
|
||||
grad_alphas, alphas, distances, pix_to_face, sigma);
|
||||
}
|
123
pytorch3d/csrc/blending/sigmoid_alpha_blend_cpu.cpp
Normal file
123
pytorch3d/csrc/blending/sigmoid_alpha_blend_cpu.cpp
Normal file
@ -0,0 +1,123 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
at::Tensor SigmoidAlphaBlendForwardCpu(
|
||||
const at::Tensor& distances, // (N, H, W, K)
|
||||
const at::Tensor& pix_to_face, // (N, H, W, K)
|
||||
const float sigma) {
|
||||
const int N = distances.size(0);
|
||||
const int H = distances.size(1);
|
||||
const int W = distances.size(2);
|
||||
const int K = distances.size(3);
|
||||
|
||||
torch::Tensor out = torch::empty({N, H, W}, distances.options());
|
||||
|
||||
auto distances_a = distances.accessor<float, 4>();
|
||||
auto pix_to_face_a = pix_to_face.accessor<int64_t, 4>();
|
||||
auto out_a = out.accessor<float, 3>();
|
||||
|
||||
// Iterate over the images in the batch.
|
||||
for (int n = 0; n < N; ++n) {
|
||||
// Iterate through the horizontal lines of the image from top to bottom.
|
||||
for (int h = 0; h < H; ++h) {
|
||||
// Iterate over the pixels on this horizontal line, left to right.
|
||||
for (int w = 0; w < W; ++w) {
|
||||
float alpha = 1.0;
|
||||
|
||||
// Loop through the top K faces for each pixel.
|
||||
for (int k = 0; k < K; ++k) {
|
||||
const int f = pix_to_face_a[n][h][w][k];
|
||||
if (f < 0) {
|
||||
// Sentinel value is -1 indicating no face overlaps the pixel.
|
||||
continue;
|
||||
}
|
||||
// The distance is negative if a pixel is inside a face and positive
|
||||
// outside the face. Therefore use -1.0 * the distance to get the
|
||||
// correct sign.
|
||||
float dist = -1.0 * distances_a[n][h][w][k];
|
||||
|
||||
// Calculate the sigmoid probability.
|
||||
float prob = 1. / (1. + exp(-dist / sigma));
|
||||
|
||||
// The product ensures that alpha will be 0.0 if at least 1
|
||||
// face fully covers the pixel as for that face, prob will be 1.0.
|
||||
// This results in a multiplication by 0.0 because of the (1.0 - prob)
|
||||
// term. Therefore 1.0 - alpha will be 1.0.
|
||||
alpha *= 1.0 - prob;
|
||||
}
|
||||
out_a[n][h][w] = 1.0 - alpha;
|
||||
}
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
at::Tensor SigmoidAlphaBlendBackwardCpu(
|
||||
const at::Tensor& grad_alphas, // (N, H, W)
|
||||
const at::Tensor& alphas, // (N, H, W)
|
||||
const at::Tensor& distances, // (N, H, W, K)
|
||||
const at::Tensor& pix_to_face, // (N, H, W, K)
|
||||
const float sigma) {
|
||||
const int N = distances.size(0);
|
||||
const int H = distances.size(1);
|
||||
const int W = distances.size(2);
|
||||
const int K = distances.size(3);
|
||||
|
||||
auto distances_a = distances.accessor<float, 4>();
|
||||
auto pix_to_face_a = pix_to_face.accessor<int64_t, 4>();
|
||||
auto alphas_a = alphas.accessor<float, 3>();
|
||||
auto grad_alphas_a = grad_alphas.accessor<float, 3>();
|
||||
|
||||
torch::Tensor grad_distances =
|
||||
torch::zeros({N, H, W, K}, distances.options());
|
||||
auto grad_distances_a = grad_distances.accessor<float, 4>();
|
||||
|
||||
// Iterate over the images in the batch.
|
||||
for (int n = 0; n < N; ++n) {
|
||||
// Iterate through the horizontal lines of the image from top to bottom.
|
||||
for (int h = 0; h < H; ++h) {
|
||||
// Iterate over the pixels on this horizontal line, left to right.
|
||||
for (int w = 0; w < W; ++w) {
|
||||
// Get the alpha value from the forward pass and the
|
||||
// upstream gradient.
|
||||
const float alpha = 1.0 - alphas_a[n][h][w];
|
||||
const float grad_alpha = grad_alphas_a[n][h][w];
|
||||
|
||||
// Loop through the top K faces for each pixel.
|
||||
for (int k = 0; k < K; ++k) {
|
||||
const int f = pix_to_face_a[n][h][w][k];
|
||||
if (f < 0) {
|
||||
// Sentinel value is -1 indicating no face overlaps the pixel
|
||||
continue;
|
||||
}
|
||||
// The distance is negative if a pixel is inside a face and positive
|
||||
// outside the face. Therefore use -1.0 * distance to get the
|
||||
// correct sign.
|
||||
float dist = -1.0 * distances_a[n][h][w][k];
|
||||
|
||||
// Calculate the sigmoid probability.
|
||||
float prob = 1. / (1. + exp(-dist / sigma));
|
||||
|
||||
// clang-format off
|
||||
// We need to take the derivative of alpha w.r.t to the distance.
|
||||
// alpha = 1.0 - (1.0- sigmoid(-x)) * (1.0 - sigmoid(-x2)) * ... * (1.0 - sigmoid(-xn))
|
||||
//
|
||||
// Note that d/dx sigmoid(x) = sigmoid(x) * (1.0 - sigmoid(x))
|
||||
//
|
||||
// This gives:
|
||||
// d_alpha/d_dist = -1.0 * -1.0 * sigmoid(-x)(1. - sigmoid(-x)) * (-1.0/sigma)
|
||||
// * ((1.0 - sigmoid(-x2) * ... * (1.0 - sigmoid(-xn))
|
||||
// = (-1.0/sigma) * prob * (1.0 - prob) * alpha/(1.0 - prob)
|
||||
// = (-1.0/sigma) * prob * alpha
|
||||
// clang-format on
|
||||
grad_distances_a[n][h][w][k] =
|
||||
grad_alpha * (-1.0 / sigma) * prob * alpha;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return grad_distances;
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include "blending/sigmoid_alpha_blend.h"
|
||||
#include "compositing/alpha_composite.h"
|
||||
#include "compositing/norm_weighted_sum.h"
|
||||
#include "compositing/weighted_sum.h"
|
||||
@ -31,6 +32,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("rasterize_points_backward", &RasterizePointsBackward);
|
||||
m.def("rasterize_meshes_backward", &RasterizeMeshesBackward);
|
||||
m.def("rasterize_meshes", &RasterizeMeshes);
|
||||
m.def("sigmoid_alpha_blend", &SigmoidAlphaBlend);
|
||||
m.def("sigmoid_alpha_blend_backward", &SigmoidAlphaBlendBackward);
|
||||
|
||||
// Accumulation functions
|
||||
m.def("accum_weightedsumnorm", &weightedSumNormForward);
|
||||
|
@ -216,10 +216,9 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForwardCuda(
|
||||
const auto F = faces.size(0);
|
||||
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg verts_t{verts, "verts", 1}, faces_t{verts, "faces", 2};
|
||||
at::TensorArg verts_t{verts, "verts", 1}, faces_t{faces, "faces", 2};
|
||||
at::CheckedFrom c = "FaceAreasNormalsForwardCuda";
|
||||
at::checkAllSameGPU(c, {verts_t, faces_t});
|
||||
at::checkAllSameType(c, {verts_t, faces_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of verts
|
||||
at::cuda::CUDAGuard device_guard(verts.device());
|
||||
@ -256,12 +255,11 @@ at::Tensor FaceAreasNormalsBackwardCuda(
|
||||
const at::Tensor verts,
|
||||
const at::Tensor faces) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg verts_t{verts, "verts", 1}, faces_t{verts, "faces", 2},
|
||||
grad_areas_t{verts, "grad_areas", 3},
|
||||
grad_normals_t{verts, "grad_normals", 4};
|
||||
at::TensorArg verts_t{verts, "verts", 1}, faces_t{faces, "faces", 2},
|
||||
grad_areas_t{grad_areas, "grad_areas", 3},
|
||||
grad_normals_t{grad_normals, "grad_normals", 4};
|
||||
at::CheckedFrom c = "FaceAreasNormalsBackwardCuda";
|
||||
at::checkAllSameGPU(c, {verts_t, faces_t, grad_areas_t, grad_normals_t});
|
||||
at::checkAllSameType(c, {verts_t, faces_t, grad_areas_t, grad_normals_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of verts
|
||||
at::cuda::CUDAGuard device_guard(verts.device());
|
||||
|
@ -5,6 +5,7 @@ from typing import NamedTuple, Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d import _C
|
||||
|
||||
|
||||
# Example functions for blending the top K colors per pixel using the outputs
|
||||
@ -59,6 +60,29 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
return torch.cat([pixel_colors, alpha], dim=-1) # (N, H, W, 4)
|
||||
|
||||
|
||||
# Wrapper for the C++/CUDA Implementation of sigmoid alpha blend.
|
||||
class _SigmoidAlphaBlend(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, dists, pix_to_face, sigma):
|
||||
alphas = _C.sigmoid_alpha_blend(dists, pix_to_face, sigma)
|
||||
ctx.save_for_backward(dists, pix_to_face, alphas)
|
||||
ctx.sigma = sigma
|
||||
return alphas
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_alphas):
|
||||
dists, pix_to_face, alphas = ctx.saved_tensors
|
||||
sigma = ctx.sigma
|
||||
grad_dists = _C.sigmoid_alpha_blend_backward(
|
||||
grad_alphas, alphas, dists, pix_to_face, sigma
|
||||
)
|
||||
return grad_dists, None, None
|
||||
|
||||
|
||||
# pyre-fixme[16]: `_SigmoidAlphaBlend` has no attribute `apply`.
|
||||
_sigmoid_alpha = _SigmoidAlphaBlend.apply
|
||||
|
||||
|
||||
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
"""
|
||||
Silhouette blending to return an RGBA image
|
||||
@ -83,19 +107,9 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
"""
|
||||
N, H, W, K = fragments.pix_to_face.shape
|
||||
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
|
||||
mask = fragments.pix_to_face >= 0
|
||||
|
||||
# The distance is negative if a pixel is inside a face and positive outside
|
||||
# the face. Therefore use -1.0 * fragments.dists to get the correct sign.
|
||||
prob = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask
|
||||
|
||||
# The cumulative product ensures that alpha will be 0.0 if at least 1
|
||||
# face fully covers the pixel as for that face, prob will be 1.0.
|
||||
# This results in a multiplication by 0.0 because of the (1.0 - prob)
|
||||
# term. Therefore 1.0 - alpha will be 1.0.
|
||||
alpha = torch.prod((1.0 - prob), dim=-1)
|
||||
pixel_colors[..., :3] = colors[..., 0, :] # Hard assign for RGB
|
||||
pixel_colors[..., 3] = 1.0 - alpha
|
||||
pixel_colors[..., :3] = colors[..., 0, :]
|
||||
alpha = _sigmoid_alpha(fragments.dists, fragments.pix_to_face, blend_params.sigma)
|
||||
pixel_colors[..., 3] = alpha
|
||||
return pixel_colors
|
||||
|
||||
|
||||
|
@ -8,17 +8,24 @@ from test_blending import TestBlending
|
||||
|
||||
|
||||
def bm_blending() -> None:
|
||||
devices = ["cpu", "cuda"]
|
||||
devices = ["cuda"]
|
||||
kwargs_list = []
|
||||
num_meshes = [16]
|
||||
image_size = [128, 256]
|
||||
num_meshes = [8]
|
||||
image_size = [64, 128, 256]
|
||||
faces_per_pixel = [50, 100]
|
||||
test_cases = product(num_meshes, image_size, faces_per_pixel, devices)
|
||||
backend = ["pytorch", "custom"]
|
||||
test_cases = product(num_meshes, image_size, faces_per_pixel, devices, backend)
|
||||
|
||||
for case in test_cases:
|
||||
n, s, k, d = case
|
||||
n, s, k, d, b = case
|
||||
kwargs_list.append(
|
||||
{"num_meshes": n, "image_size": s, "faces_per_pixel": k, "device": d}
|
||||
{
|
||||
"num_meshes": n,
|
||||
"image_size": s,
|
||||
"faces_per_pixel": k,
|
||||
"device": d,
|
||||
"backend": b,
|
||||
}
|
||||
)
|
||||
|
||||
benchmark(
|
||||
@ -28,6 +35,7 @@ def bm_blending() -> None:
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
kwargs_list = [case for case in kwargs_list if case["backend"] == "pytorch"]
|
||||
benchmark(
|
||||
TestBlending.bm_softmax_blending,
|
||||
"SOFTMAX_BLENDING_PYTORCH",
|
||||
|
@ -44,6 +44,16 @@ def sigmoid_blend_naive_loop(colors, fragments, blend_params):
|
||||
return pixel_colors
|
||||
|
||||
|
||||
def sigmoid_alpha_blend_vectorized(colors, fragments, blend_params) -> torch.Tensor:
|
||||
N, H, W, K = fragments.pix_to_face.shape
|
||||
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=colors.device)
|
||||
mask = fragments.pix_to_face >= 0
|
||||
prob = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask
|
||||
pixel_colors[..., :3] = colors[..., 0, :]
|
||||
pixel_colors[..., 3] = 1.0 - torch.prod((1.0 - prob), dim=-1)
|
||||
return pixel_colors
|
||||
|
||||
|
||||
def sigmoid_blend_naive_loop_backward(grad_images, images, fragments, blend_params):
|
||||
pix_to_face = fragments.pix_to_face
|
||||
dists = fragments.dists
|
||||
@ -136,10 +146,9 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
|
||||
def _compare_impls(
|
||||
self, fn1, fn2, args1, args2, grad_var1=None, grad_var2=None, compare_grads=True
|
||||
):
|
||||
|
||||
out1 = fn1(*args1)
|
||||
out2 = fn2(*args2)
|
||||
self.assertTrue(torch.allclose(out1.cpu(), out2.cpu(), atol=1e-7))
|
||||
self.assertClose(out1.cpu()[..., 3], out2.cpu()[..., 3], atol=1e-7)
|
||||
|
||||
# Check gradients
|
||||
if not compare_grads:
|
||||
@ -151,9 +160,7 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
(out2 * grad_out).sum().backward()
|
||||
self.assertTrue(hasattr(grad_var2, "grad"))
|
||||
self.assertTrue(
|
||||
torch.allclose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5)
|
||||
)
|
||||
self.assertClose(grad_var1.grad.cpu(), grad_var2.grad.cpu(), atol=2e-5)
|
||||
|
||||
def test_hard_rgb_blend(self):
|
||||
N, H, W, K = 5, 10, 10, 20
|
||||
@ -223,18 +230,15 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
|
||||
torch.manual_seed(231)
|
||||
F = 32 # number of faces in the mesh
|
||||
# The python loop version is really slow so only using small input sizes.
|
||||
N, S, K = 2, 10, 5
|
||||
N, S, K = 1, 4, 1
|
||||
device = torch.device("cuda")
|
||||
pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
|
||||
pix_to_face = torch.randint(low=-1, high=F, size=(N, S, S, K), device=device)
|
||||
colors = torch.randn((N, S, S, K, 3), device=device)
|
||||
empty = torch.tensor([], device=device)
|
||||
|
||||
# # randomly flip the sign of the distance
|
||||
# # (-) means inside triangle, (+) means outside triangle.
|
||||
random_sign_flip = torch.rand((N, S, S, K))
|
||||
random_sign_flip[random_sign_flip > 0.5] *= -1.0
|
||||
dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
|
||||
dists2 = dists1.detach().clone()
|
||||
dists1 = torch.randn(size=(N, S, S, K), device=device)
|
||||
dists2 = dists1.clone()
|
||||
dists1.requires_grad = True
|
||||
dists2.requires_grad = True
|
||||
|
||||
fragments1 = Fragments(
|
||||
@ -256,7 +260,7 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
self._compare_impls(
|
||||
sigmoid_alpha_blend,
|
||||
sigmoid_blend_naive_loop,
|
||||
sigmoid_alpha_blend_vectorized,
|
||||
args1,
|
||||
args2,
|
||||
dists1,
|
||||
@ -324,26 +328,21 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
|
||||
num_meshes: int = 16,
|
||||
image_size: int = 128,
|
||||
faces_per_pixel: int = 100,
|
||||
device: str = "cpu",
|
||||
device="cuda",
|
||||
backend: str = "pytorch",
|
||||
):
|
||||
if torch.cuda.is_available() and "cuda:" in device:
|
||||
# If a device other than the default is used, set the device explicity.
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
device = torch.device(device)
|
||||
torch.manual_seed(231)
|
||||
|
||||
# Create dummy outputs of rasterization
|
||||
N, S, K = num_meshes, image_size, faces_per_pixel
|
||||
F = 32 # num faces in the mesh
|
||||
pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
|
||||
pix_to_face = torch.randint(
|
||||
low=-1, high=F + 1, size=(N, S, S, K), device=device
|
||||
)
|
||||
colors = torch.randn((N, S, S, K, 3), device=device)
|
||||
empty = torch.tensor([], device=device)
|
||||
|
||||
# # randomly flip the sign of the distance
|
||||
# # (-) means inside triangle, (+) means outside triangle.
|
||||
random_sign_flip = torch.rand((N, S, S, K), device=device)
|
||||
random_sign_flip[random_sign_flip > 0.5] *= -1.0
|
||||
dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
|
||||
fragments = Fragments(
|
||||
pix_to_face=pix_to_face,
|
||||
@ -352,11 +351,18 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
|
||||
dists=dists1,
|
||||
)
|
||||
blend_params = BlendParams(sigma=1e-3)
|
||||
|
||||
blend_fn = (
|
||||
sigmoid_alpha_blend_vectorized
|
||||
if backend == "pytorch"
|
||||
else sigmoid_alpha_blend
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def fn():
|
||||
# test forward and backward pass
|
||||
images = sigmoid_alpha_blend(colors, fragments, blend_params)
|
||||
images = blend_fn(colors, fragments, blend_params)
|
||||
images.sum().backward()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@ -368,6 +374,7 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
|
||||
image_size: int = 128,
|
||||
faces_per_pixel: int = 100,
|
||||
device: str = "cpu",
|
||||
backend: str = "pytorch",
|
||||
):
|
||||
if torch.cuda.is_available() and "cuda:" in device:
|
||||
# If a device other than the default is used, set the device explicity.
|
||||
@ -379,14 +386,12 @@ class TestBlending(TestCaseMixin, unittest.TestCase):
|
||||
# Create dummy outputs of rasterization
|
||||
N, S, K = num_meshes, image_size, faces_per_pixel
|
||||
F = 32 # num faces in the mesh
|
||||
pix_to_face = torch.randint(F + 1, size=(N, S, S, K), device=device) - 1
|
||||
pix_to_face = torch.randint(
|
||||
low=-1, high=F + 1, size=(N, S, S, K), device=device
|
||||
)
|
||||
colors = torch.randn((N, S, S, K, 3), device=device)
|
||||
empty = torch.tensor([], device=device)
|
||||
|
||||
# # randomly flip the sign of the distance
|
||||
# # (-) means inside triangle, (+) means outside triangle.
|
||||
random_sign_flip = torch.rand((N, S, S, K), device=device)
|
||||
random_sign_flip[random_sign_flip > 0.5] *= -1.0
|
||||
dists1 = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
|
||||
zbuf = torch.randn(size=(N, S, S, K), requires_grad=True, device=device)
|
||||
fragments = Fragments(
|
||||
|
Loading…
x
Reference in New Issue
Block a user