Accumulate points (#4)

Summary:
Code for accumulating points in the z-buffer in three ways (a toy single-pixel sketch of each follows the list):
1. weighted sum
2. normalised weighted sum
3. alpha compositing
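
For orientation, here is a toy single-pixel PyTorch sketch (not part of the diff below) of what the three modes compute, given the K z-sorted weights of the points covering one pixel and their features; the function names, shapes and values are illustrative only, and the code in this PR operates on full (N, K, H, W) z-buffers:

import torch

def weighted_sum(feats, alphas):
    # 1. weighted sum: sum_k alpha_k * f_k
    return (feats * alphas).sum(dim=1)

def norm_weighted_sum(feats, alphas, eps=1e-4):
    # 2. normalised weighted sum: divide by the (clamped) total weight
    return (feats * alphas).sum(dim=1) / alphas.sum().clamp(min=eps)

def alpha_composite(feats, alphas):
    # 3. alpha compositing: each point is attenuated by the points in front of it
    occlusion = torch.cumprod(torch.cat([alphas.new_ones(1), 1 - alphas[:-1]]), dim=0)
    return (feats * (alphas * occlusion)).sum(dim=1)

# feats: (C, K) features of the K closest points; alphas: (K,) weights in [0, 1]
feats, alphas = torch.rand(3, 4), torch.rand(4)
print(weighted_sum(feats, alphas), norm_weighted_sum(feats, alphas), alpha_composite(feats, alphas))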

Pull Request resolved: https://github.com/fairinternal/pytorch3d/pull/4

Reviewed By: nikhilaravi

Differential Revision: D20522422

Pulled By: gkioxari

fbshipit-source-id: 5023baa05f15e338f3821ef08f5552c2dcbfc06c
Author: Olivia, 2020-03-19 11:19:39 -07:00 (committed by Facebook GitHub Bot)
Parent: 5218f45c2c
Commit: 53599770dd
21 changed files with 2466 additions and 4 deletions


@@ -673,9 +673,9 @@
"provenance": []
},
"kernelspec": {
"display_name": "pytorch3d (local)",
"display_name": "Python 3",
"language": "python",
"name": "pytorch3d_local"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -687,7 +687,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5+"
"version": "3.7.6"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {

File diff suppressed because one or more lines are too long


@@ -0,0 +1,187 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <vector>
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void alphaCompositeCudaForwardKernel(
// clang-format off
torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> result,
const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
// clang-format on
const int64_t batch_size = result.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
// Get the batch and index
const int batch = blockIdx.x;
const int num_pixels = C * W * H;
const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
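// Each thread then handles pixels pid = tid, tid + num_threads, tid + 2 * num_threads, ...
// (a grid-stride loop), so any grid/block configuration covers all num_pixels entries.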
// Iterate over each feature in each pixel
for (int pid = tid; pid < num_pixels; pid += num_threads) {
int ch = pid / (W * H);
int j = (pid % (W * H)) / H;
int i = (pid % (W * H)) % H;
// alphacomposite the different values
float cum_alpha = 1.;
// Iterate through the closest K points for this pixel
for (int k = 0; k < points_idx.size(1); ++k) {
int n_idx = points_idx[batch][k][j][i];
// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}
float alpha = alphas[batch][k][j][i];
// TODO(gkioxari) It might be more efficient to have threads write in a
// local variable, and move atomicAdd outside of the loop such that
// atomicAdd is executed once per thread.
atomicAdd(
&result[batch][ch][j][i], features[ch][n_idx] * cum_alpha * alpha);
cum_alpha = cum_alpha * (1 - alpha);
}
}
}
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void alphaCompositeCudaBackwardKernel(
// clang-format off
torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> grad_features,
torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_alphas,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_outputs,
const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
// clang-format on
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
// Get the batch and index
const int batch = blockIdx.x;
const int num_pixels = C * W * H;
const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size
for (int pid = tid; pid < num_pixels; pid += num_threads) {
int ch = pid / (W * H);
int j = (pid % (W * H)) / H;
int i = (pid % (W * H)) % H;
// alphacomposite the different values
float cum_alpha = 1.;
// Iterate through the closest K points for this pixel
for (int k = 0; k < points_idx.size(1); ++k) {
int n_idx = points_idx[batch][k][j][i];
// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}
float alpha = alphas[batch][k][j][i];
// TODO(gkioxari) It might be more efficient to have threads write in a
// local variable, and move atomicAdd outside of the loop such that
// atomicAdd is executed once per thread.
atomicAdd(
&grad_alphas[batch][k][j][i],
cum_alpha * features[ch][n_idx] * grad_outputs[batch][ch][j][i]);
atomicAdd(
&grad_features[ch][n_idx],
cum_alpha * alpha * grad_outputs[batch][ch][j][i]);
// Iterate over all (K-1) nearer points to update gradient
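// Each nearer point t contributes a factor (1 - alpha_t) to cum_alpha, so
// differentiating the k-th term of the output w.r.t. alpha_t gives
// -features * alpha * cum_alpha / (1 - alpha_t), which (times grad_outputs)
// is what the atomicAdd below accumulates.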
for (int t = 0; t < k; ++t) {
int t_idx = points_idx[batch][t][j][i];
// Sentinel value is -1, indicating no point overlaps this pixel
if (t_idx < 0) {
continue;
}
float alpha_tvalue = alphas[batch][t][j][i];
// TODO(gkioxari) It might be more efficient to have threads write in a
// local variable, and move atomicAdd outside of the loop such that
// atomicAdd is executed once per thread.
atomicAdd(
&grad_alphas[batch][t][j][i],
-grad_outputs[batch][ch][j][i] * features[ch][n_idx] * cum_alpha *
alpha / (1 - alpha_tvalue));
}
cum_alpha = cum_alpha * (1 - alphas[batch][k][j][i]);
}
}
}
torch::Tensor alphaCompositeCudaForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
auto result = torch::zeros({batch_size, C, H, W}, features.options());
const dim3 threadsPerBlock(64);
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
// clang-format off
result.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
// clang-format on
return result;
}
std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
auto grad_features = torch::zeros_like(features);
auto grad_alphas = torch::zeros_like(alphas);
const int64_t bs = alphas.size(0);
const dim3 threadsPerBlock(64);
const dim3 numBlocks(bs, 1024 / bs + 1);
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
// clang-format off
grad_features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
grad_alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
grad_outputs.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
// clang-format on
return std::make_tuple(grad_features, grad_alphas);
}


@@ -0,0 +1,110 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include "pytorch3d_cutils.h"
#include <vector>
// Perform alpha compositing of points in a z-buffer.
//
// Inputs:
// features: FloatTensor of shape (C, P) which gives the features
// of each point where C is the size of the feature and
// P the number of points.
// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
// points_per_pixel is the number of points in the z-buffer
// sorted in z-order, and W is the image size.
// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
// indices of the nearest points at each pixel, sorted in z-order.
// Returns:
// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
// feature at each pixel. Concretely, it gives:
// weighted_fs[b,c,i,j] = sum_k cum_alpha_k *
// features[c,points_idx[b,k,i,j]]
// where cum_alpha_k =
// alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j])
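// Worked example (illustrative numbers, not taken from this diff): with K = 2
// points at a pixel and alphas = [0.6, 0.5], the nearest point gets weight
// cum_alpha_0 = 0.6 and the second gets cum_alpha_1 = (1 - 0.6) * 0.5 = 0.2,
// so points nearer in depth dominate the accumulated feature.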
// CUDA declarations
#ifdef WITH_CUDA
torch::Tensor alphaCompositeCudaForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);
std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);
#endif
// C++ declarations
torch::Tensor alphaCompositeCpuForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);
std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCpuBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);
torch::Tensor alphaCompositeForward(
torch::Tensor& features,
torch::Tensor& alphas,
torch::Tensor& points_idx) {
features = features.contiguous();
alphas = alphas.contiguous();
points_idx = points_idx.contiguous();
if (features.type().is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(features);
CHECK_CONTIGUOUS_CUDA(alphas);
CHECK_CONTIGUOUS_CUDA(points_idx);
#else
AT_ERROR("Not compiled with GPU support");
#endif
return alphaCompositeCudaForward(features, alphas, points_idx);
} else {
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(alphas);
CHECK_CONTIGUOUS(points_idx);
return alphaCompositeCpuForward(features, alphas, points_idx);
}
}
std::tuple<torch::Tensor, torch::Tensor> alphaCompositeBackward(
torch::Tensor& grad_outputs,
torch::Tensor& features,
torch::Tensor& alphas,
torch::Tensor& points_idx) {
grad_outputs = grad_outputs.contiguous();
features = features.contiguous();
alphas = alphas.contiguous();
points_idx = points_idx.contiguous();
if (grad_outputs.type().is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(grad_outputs);
CHECK_CONTIGUOUS_CUDA(features);
CHECK_CONTIGUOUS_CUDA(alphas);
CHECK_CONTIGUOUS_CUDA(points_idx);
#else
AT_ERROR("Not compiled with GPU support");
#endif
return alphaCompositeCudaBackward(
grad_outputs, features, alphas, points_idx);
} else {
CHECK_CONTIGUOUS(grad_outputs);
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(alphas);
CHECK_CONTIGUOUS(points_idx);
return alphaCompositeCpuBackward(
grad_outputs, features, alphas, points_idx);
}
}


@@ -0,0 +1,114 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <cmath>
#include <vector>
torch::Tensor alphaCompositeCpuForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
const int64_t B = points_idx.size(0);
const int64_t K = points_idx.size(1);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
const int64_t C = features.size(0);
torch::Tensor result = torch::zeros({B, C, H, W}, features.options());
auto features_a = features.accessor<float, 2>();
auto alphas_a = alphas.accessor<float, 4>();
auto points_idx_a = points_idx.accessor<int64_t, 4>();
auto result_a = result.accessor<float, 4>();
// Iterate over the batch
for (int b = 0; b < B; ++b) {
// Iterate over the features
for (int c = 0; c < C; ++c) {
// Iterate through the horizontal lines of the image from top to bottom
for (int j = 0; j < H; ++j) {
// Iterate over pixels in a horizontal line, left to right
for (int i = 0; i < W; ++i) {
float cum_alpha = 1.;
// Iterate through the closest K points for this pixel
for (int k = 0; k < K; ++k) {
int64_t n_idx = points_idx_a[b][k][j][i];
// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}
float alpha = alphas_a[b][k][j][i];
result_a[b][c][j][i] += cum_alpha * alpha * features_a[c][n_idx];
cum_alpha = cum_alpha * (1 - alpha);
}
}
}
}
}
return result;
}
std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCpuBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
torch::Tensor grad_features = torch::zeros_like(features);
torch::Tensor grad_alphas = torch::zeros_like(alphas);
const int64_t B = points_idx.size(0);
const int64_t K = points_idx.size(1);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
const int64_t C = features.size(0);
auto grad_outputs_a = grad_outputs.accessor<float, 4>();
auto features_a = features.accessor<float, 2>();
auto alphas_a = alphas.accessor<float, 4>();
auto points_idx_a = points_idx.accessor<int64_t, 4>();
auto grad_features_a = grad_features.accessor<float, 2>();
auto grad_alphas_a = grad_alphas.accessor<float, 4>();
// Iterate over the batch
for (int b = 0; b < B; ++b) {
// Iterate over the features
for (int c = 0; c < C; ++c) {
// Iterate through the horizontal lines of the image from top to bottom
for (int j = 0; j < H; ++j) {
// Iterate over pixels in a horizontal line, left to right
for (int i = 0; i < W; ++i) {
float cum_alpha = 1.;
// Iterate through the closest K points for this pixel
for (int k = 0; k < K; ++k) {
int64_t n_idx = points_idx_a[b][k][j][i];
// Sentinel value is -1, indicating no point overlaps this pixel
if (n_idx < 0) {
continue;
}
float alpha = alphas_a[b][k][j][i];
grad_alphas_a[b][k][j][i] +=
grad_outputs_a[b][c][j][i] * features_a[c][n_idx] * cum_alpha;
grad_features_a[c][n_idx] +=
grad_outputs_a[b][c][j][i] * cum_alpha * alpha;
// Iterate over all (K-1) nearer points to update gradient
for (int t = 0; t < k; t++) {
int64_t t_idx = points_idx_a[b][t][j][i];
// Sentinel value is -1, indicating no point overlaps this pixel
if (t_idx < 0) {
continue;
}
float alpha_tvalue = alphas_a[b][t][j][i];
grad_alphas_a[b][t][j][i] -= grad_outputs_a[b][c][j][i] *
features_a[c][n_idx] * cum_alpha * alpha / (1 - alpha_tvalue);
}
cum_alpha = cum_alpha * (1 - alpha);
}
}
}
}
}
return std::make_tuple(grad_features, grad_alphas);
}


@@ -0,0 +1,202 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <vector>
__constant__ const float kEpsilon = 1e-4;
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void weightedSumNormCudaForwardKernel(
// clang-format off
torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> result,
const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
// clang-format on
const int64_t batch_size = result.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
// Get the batch and index
const int batch = blockIdx.x;
const int num_pixels = C * W * H;
const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size
for (int pid = tid; pid < num_pixels; pid += num_threads) {
int ch = pid / (W * H);
int j = (pid % (W * H)) / H;
int i = (pid % (W * H)) % H;
// Store the accumulated alpha value
float cum_alpha = 0.;
// Iterate through the closest K points for this pixel
for (int k = 0; k < points_idx.size(1); ++k) {
int n_idx = points_idx[batch][k][j][i];
// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}
cum_alpha += alphas[batch][k][j][i];
}
if (cum_alpha < kEpsilon) {
cum_alpha = kEpsilon;
}
// Iterate through the closest K points for this pixel
for (int k = 0; k < points_idx.size(1); ++k) {
int n_idx = points_idx[batch][k][j][i];
// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}
float alpha = alphas[batch][k][j][i];
// TODO(gkioxari) It might be more efficient to have threads write in a
// local variable, and move atomicAdd outside of the loop such that
// atomicAdd is executed once per thread.
atomicAdd(
&result[batch][ch][j][i], features[ch][n_idx] * alpha / cum_alpha);
}
}
}
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void weightedSumNormCudaBackwardKernel(
// clang-format off
torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> grad_features,
torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_alphas,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_outputs,
const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
// clang-format on
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
// Get the batch and index
const int batch = blockIdx.x;
const int num_pixels = C * W * H;
const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size
for (int pid = tid; pid < num_pixels; pid += num_threads) {
int ch = pid / (W * H);
int j = (pid % (W * H)) / H;
int i = (pid % (W * H)) % H;
float sum_alpha = 0.;
float sum_alphafs = 0.;
// Iterate through the closest K points for this pixel to calculate the
// cumulative sum of the alphas for this pixel
for (int k = 0; k < points_idx.size(1); ++k) {
int n_idx = points_idx[batch][k][j][i];
// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}
sum_alpha += alphas[batch][k][j][i];
sum_alphafs += alphas[batch][k][j][i] * features[ch][n_idx];
}
if (sum_alpha < kEpsilon) {
sum_alpha = kEpsilon;
}
// Iterate again through the closest K points for this pixel to calculate
// the gradient.
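// For output = sum_j alpha_j * f_j / sum_alpha, the quotient rule gives
// d(output)/d(alpha_k) = (f_k * sum_alpha - sum_alphafs) / sum_alpha^2 and
// d(output)/d(f_k) = alpha_k / sum_alpha, which is what is accumulated below.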
for (int k = 0; k < points_idx.size(1); ++k) {
int n_idx = points_idx[batch][k][j][i];
// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}
float alpha = alphas[batch][k][j][i];
// TODO(gkioxari) It might be more efficient to have threads write in a
// local variable, and move atomicAdd outside of the loop such that
// atomicAdd is executed once per thread.
atomicAdd(
&grad_alphas[batch][k][j][i],
(features[ch][n_idx] * sum_alpha - sum_alphafs) /
(sum_alpha * sum_alpha) * grad_outputs[batch][ch][j][i]);
atomicAdd(
&grad_features[ch][n_idx],
alpha * grad_outputs[batch][ch][j][i] / sum_alpha);
}
}
}
torch::Tensor weightedSumNormCudaForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
auto result = torch::zeros({batch_size, C, H, W}, features.options());
const dim3 threadsPerBlock(64);
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
// clang-format off
weightedSumNormCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
result.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
// clang-format on
return result;
}
std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCudaBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
auto grad_features = torch::zeros_like(features);
auto grad_alphas = torch::zeros_like(alphas);
const int64_t bs = points_idx.size(0);
const dim3 threadsPerBlock(64);
const dim3 numBlocks(bs, 1024 / bs + 1);
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
weightedSumNormCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
// clang-format off
grad_features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
grad_alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
grad_outputs.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
// clang-format on
return std::make_tuple(grad_features, grad_alphas);
}


@@ -0,0 +1,109 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include "pytorch3d_cutils.h"
#include <vector>
// Perform normalized weighted sum compositing of points in a z-buffer.
//
// Inputs:
// features: FloatTensor of shape (C, P) which gives the features
// of each point where C is the size of the feature and
// P the number of points.
// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
// points_per_pixel is the number of points in the z-buffer
// sorted in z-order, and W is the image size.
// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
// indices of the nearest points at each pixel, sorted in z-order.
// Returns:
// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
// feature at each pixel. Concretely, it gives:
// weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] *
// features[c,points_idx[b,k,i,j]] / sum_k alphas[b,k,i,j]
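// Worked example (illustrative numbers, not taken from this diff): with K = 2
// points at a pixel and alphas = [0.6, 0.5], the points contribute their
// features with weights 0.6 / 1.1 and 0.5 / 1.1 respectively; unlike alpha
// compositing, z-order does not change the result.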
// CUDA declarations
#ifdef WITH_CUDA
torch::Tensor weightedSumNormCudaForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);
std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCudaBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);
#endif
// C++ declarations
torch::Tensor weightedSumNormCpuForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);
std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCpuBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);
torch::Tensor weightedSumNormForward(
torch::Tensor& features,
torch::Tensor& alphas,
torch::Tensor& points_idx) {
features = features.contiguous();
alphas = alphas.contiguous();
points_idx = points_idx.contiguous();
if (features.type().is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(features);
CHECK_CONTIGUOUS_CUDA(alphas);
CHECK_CONTIGUOUS_CUDA(points_idx);
#else
AT_ERROR("Not compiled with GPU support");
#endif
return weightedSumNormCudaForward(features, alphas, points_idx);
} else {
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(alphas);
CHECK_CONTIGUOUS(points_idx);
return weightedSumNormCpuForward(features, alphas, points_idx);
}
}
std::tuple<torch::Tensor, torch::Tensor> weightedSumNormBackward(
torch::Tensor& grad_outputs,
torch::Tensor& features,
torch::Tensor& alphas,
torch::Tensor& points_idx) {
grad_outputs = grad_outputs.contiguous();
features = features.contiguous();
alphas = alphas.contiguous();
points_idx = points_idx.contiguous();
if (grad_outputs.type().is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(grad_outputs);
CHECK_CONTIGUOUS_CUDA(features);
CHECK_CONTIGUOUS_CUDA(alphas);
CHECK_CONTIGUOUS_CUDA(points_idx);
#else
AT_ERROR("Not compiled with GPU support");
#endif
return weightedSumNormCudaBackward(
grad_outputs, features, alphas, points_idx);
} else {
CHECK_CONTIGUOUS(grad_outputs);
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(alphas);
CHECK_CONTIGUOUS(points_idx);
return weightedSumNormCpuBackward(
grad_outputs, features, alphas, points_idx);
}
}


@@ -0,0 +1,134 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <cmath>
#include <vector>
// Epsilon float
const float kEps = 1e-4;
torch::Tensor weightedSumNormCpuForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
const int64_t B = points_idx.size(0);
const int64_t K = points_idx.size(1);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
const int64_t C = features.size(0);
torch::Tensor result = torch::zeros({B, C, H, W}, features.options());
auto features_a = features.accessor<float, 2>();
auto alphas_a = alphas.accessor<float, 4>();
auto points_idx_a = points_idx.accessor<int64_t, 4>();
auto result_a = result.accessor<float, 4>();
// Iterate over the batch
for (int b = 0; b < B; ++b) {
// Iterate over the features
for (int c = 0; c < C; ++c) {
// Iterate through the horizontal lines of the image from top to bottom
for (int j = 0; j < H; ++j) {
// Iterate over pixels in a horizontal line, left to right
for (int i = 0; i < W; ++i) {
float t_alpha = 0.;
for (int k = 0; k < K; ++k) {
int64_t n_idx = points_idx_a[b][k][j][i];
if (n_idx < 0) {
continue;
}
t_alpha += alphas_a[b][k][j][i];
}
if (t_alpha < kEps) {
t_alpha = kEps;
}
// Iterate over the different zs to combine
for (int k = 0; k < K; ++k) {
int64_t n_idx = points_idx_a[b][k][j][i];
// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}
float alpha = alphas_a[b][k][j][i];
result_a[b][c][j][i] += alpha * features_a[c][n_idx] / t_alpha;
}
}
}
}
}
return result;
}
std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCpuBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
torch::Tensor grad_features = torch::zeros_like(features);
torch::Tensor grad_alphas = torch::zeros_like(alphas);
const int64_t B = points_idx.size(0);
const int64_t K = points_idx.size(1);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
const int64_t C = features.size(0);
auto grad_outputs_a = grad_outputs.accessor<float, 4>();
auto features_a = features.accessor<float, 2>();
auto alphas_a = alphas.accessor<float, 4>();
auto points_idx_a = points_idx.accessor<int64_t, 4>();
auto grad_features_a = grad_features.accessor<float, 2>();
auto grad_alphas_a = grad_alphas.accessor<float, 4>();
// Iterate over the batch
for (int b = 0; b < B; ++b) {
// Iterate over the features
for (int c = 0; c < C; ++c) {
// Iterate through the horizontal lines of the image from top to bottom
for (int j = 0; j < H; ++j) {
// Iterate over pixels in a horizontal line, left to right
for (int i = 0; i < W; ++i) {
float t_alpha = 0.;
float t_alphafs = 0.;
// Iterate through the closest K points for this pixel
for (int k = 0; k < K; ++k) {
int64_t n_idx = points_idx_a[b][k][j][i];
// Sentinel value is -1, indicating no point overlaps this pixel
if (n_idx < 0) {
continue;
}
t_alpha += alphas_a[b][k][j][i];
t_alphafs += alphas_a[b][k][j][i] * features_a[c][n_idx];
}
if (t_alpha < kEps) {
t_alpha = kEps;
}
// Iterate through the closest K points for this pixel ordered by z
// distance.
for (int k = 0; k < K; ++k) {
int64_t n_idx = points_idx_a[b][k][j][i];
// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}
float alpha = alphas_a[b][k][j][i];
grad_alphas_a[b][k][j][i] += grad_outputs_a[b][c][j][i] *
(features_a[c][n_idx] * t_alpha - t_alphafs) /
(t_alpha * t_alpha);
grad_features_a[c][n_idx] +=
grad_outputs_a[b][c][j][i] * alpha / t_alpha;
}
}
}
}
}
return std::make_tuple(grad_features, grad_alphas);
}


@@ -0,0 +1,161 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <vector>
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void weightedSumCudaForwardKernel(
// clang-format off
torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> result,
const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
// clang-format on
const int64_t batch_size = result.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
// Get the batch and index
const int batch = blockIdx.x;
const int num_pixels = C * W * H;
const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size
for (int pid = tid; pid < num_pixels; pid += num_threads) {
int ch = pid / (W * H);
int j = (pid % (W * H)) / H;
int i = (pid % (W * H)) % H;
// Iterate through the closest K points for this pixel
for (int k = 0; k < points_idx.size(1); ++k) {
int n_idx = points_idx[batch][k][j][i];
// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}
// Accumulate the values
float alpha = alphas[batch][k][j][i];
// TODO(gkioxari) It might be more efficient to have threads write in a
// local variable, and move atomicAdd outside of the loop such that
// atomicAdd is executed once per thread.
atomicAdd(&result[batch][ch][j][i], features[ch][n_idx] * alpha);
}
}
}
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void weightedSumCudaBackwardKernel(
// clang-format off
torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> grad_features,
torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_alphas,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_outputs,
const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
// clang-format on
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
// Get the batch and index
const int batch = blockIdx.x;
const int num_pixels = C * W * H;
const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Iterate over each pixel to compute the contribution to the
// gradient for the features and weights
for (int pid = tid; pid < num_pixels; pid += num_threads) {
int ch = pid / (W * H);
int j = (pid % (W * H)) / H;
int i = (pid % (W * H)) % H;
// Iterate through the closest K points for this pixel
for (int k = 0; k < points_idx.size(1); ++k) {
int n_idx = points_idx[batch][k][j][i];
// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}
float alpha = alphas[batch][k][j][i];
// TODO(gkioxari) It might be more efficient to have threads write in a
// local variable, and move atomicAdd outside of the loop such that
// atomicAdd is executed once per thread.
atomicAdd(
&grad_alphas[batch][k][j][i],
features[ch][n_idx] * grad_outputs[batch][ch][j][i]);
atomicAdd(
&grad_features[ch][n_idx], alpha * grad_outputs[batch][ch][j][i]);
}
}
}
torch::Tensor weightedSumCudaForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
const int64_t batch_size = points_idx.size(0);
const int64_t C = features.size(0);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
auto result = torch::zeros({batch_size, C, H, W}, features.options());
const dim3 threadsPerBlock(64);
const dim3 numBlocks(batch_size, 1024 / batch_size + 1);
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
weightedSumCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
// clang-format off
result.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
// clang-format on
return result;
}
std::tuple<torch::Tensor, torch::Tensor> weightedSumCudaBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
auto grad_features = torch::zeros_like(features);
auto grad_alphas = torch::zeros_like(alphas);
const int64_t bs = points_idx.size(0);
const dim3 threadsPerBlock(64);
const dim3 numBlocks(bs, 1024 / bs + 1);
// TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
// doubles. Currently, support is for floats only.
weightedSumCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
// clang-format off
grad_features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
grad_alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
grad_outputs.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
// clang-format on
return std::make_tuple(grad_features, grad_alphas);
}


@@ -0,0 +1,107 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include "pytorch3d_cutils.h"
#include <vector>
// Perform weighted sum compositing of points in a z-buffer.
//
// Inputs:
// features: FloatTensor of shape (C, P) which gives the features
// of each point where C is the size of the feature and
// P the number of points.
// alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
// points_per_pixel is the number of points in the z-buffer
// sorted in z-order, and W is the image size.
// points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
// indices of the nearest points at each pixel, sorted in z-order.
// Returns:
// weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
// feature at each pixel. Concretely, it gives:
// weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] *
// features[c,points_idx[b,k,i,j]]
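// Worked example (illustrative numbers, not taken from this diff): with K = 2
// points at a pixel and alphas = [0.6, 0.5], the points contribute their
// features with weights 0.6 and 0.5 directly; unlike the normalized variant,
// the weights are not rescaled to sum to one.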
// CUDA declarations
#ifdef WITH_CUDA
torch::Tensor weightedSumCudaForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);
std::tuple<torch::Tensor, torch::Tensor> weightedSumCudaBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);
#endif
// C++ declarations
torch::Tensor weightedSumCpuForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);
std::tuple<torch::Tensor, torch::Tensor> weightedSumCpuBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx);
torch::Tensor weightedSumForward(
torch::Tensor& features,
torch::Tensor& alphas,
torch::Tensor& points_idx) {
features = features.contiguous();
alphas = alphas.contiguous();
points_idx = points_idx.contiguous();
if (features.type().is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(features);
CHECK_CONTIGUOUS_CUDA(alphas);
CHECK_CONTIGUOUS_CUDA(points_idx);
#else
AT_ERROR("Not compiled with GPU support");
#endif
return weightedSumCudaForward(features, alphas, points_idx);
} else {
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(alphas);
CHECK_CONTIGUOUS(points_idx);
return weightedSumCpuForward(features, alphas, points_idx);
}
}
std::tuple<torch::Tensor, torch::Tensor> weightedSumBackward(
torch::Tensor& grad_outputs,
torch::Tensor& features,
torch::Tensor& alphas,
torch::Tensor& points_idx) {
grad_outputs = grad_outputs.contiguous();
features = features.contiguous();
alphas = alphas.contiguous();
points_idx = points_idx.contiguous();
if (grad_outputs.type().is_cuda()) {
#ifdef WITH_CUDA
CHECK_CONTIGUOUS_CUDA(grad_outputs);
CHECK_CONTIGUOUS_CUDA(features);
CHECK_CONTIGUOUS_CUDA(alphas);
CHECK_CONTIGUOUS_CUDA(points_idx);
#else
AT_ERROR("Not compiled with GPU support");
#endif
return weightedSumCudaBackward(grad_outputs, features, alphas, points_idx);
} else {
CHECK_CONTIGUOUS(grad_outputs);
CHECK_CONTIGUOUS(features);
CHECK_CONTIGUOUS(alphas);
CHECK_CONTIGUOUS(points_idx);
return weightedSumCpuBackward(grad_outputs, features, alphas, points_idx);
}
}


@@ -0,0 +1,98 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <cmath>
#include <vector>
torch::Tensor weightedSumCpuForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
const int64_t B = points_idx.size(0);
const int64_t K = points_idx.size(1);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
const int64_t C = features.size(0);
torch::Tensor result = torch::zeros({B, C, H, W}, features.options());
auto features_a = features.accessor<float, 2>();
auto alphas_a = alphas.accessor<float, 4>();
auto points_idx_a = points_idx.accessor<int64_t, 4>();
auto result_a = result.accessor<float, 4>();
// Iterate over the batch
for (int b = 0; b < B; ++b) {
// Iterate over the features
for (int c = 0; c < C; ++c) {
// Iterate through the horizontal lines of the image from top to bottom
for (int j = 0; j < H; ++j) {
// Iterate over pixels in a horizontal line, left to right
for (int i = 0; i < W; ++i) {
// Iterate through the closest K points for this pixel
for (int k = 0; k < K; ++k) {
int64_t n_idx = points_idx_a[b][k][j][i];
// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}
float alpha = alphas_a[b][k][j][i];
result_a[b][c][j][i] += alpha * features_a[c][n_idx];
}
}
}
}
}
return result;
}
std::tuple<torch::Tensor, torch::Tensor> weightedSumCpuBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
const int64_t B = points_idx.size(0);
const int64_t K = points_idx.size(1);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
const int64_t C = features.size(0);
torch::Tensor grad_features = torch::zeros_like(features);
torch::Tensor grad_alphas = torch::zeros_like(alphas);
auto grad_outputs_a = grad_outputs.accessor<float, 4>();
auto features_a = features.accessor<float, 2>();
auto alphas_a = alphas.accessor<float, 4>();
auto points_idx_a = points_idx.accessor<int64_t, 4>();
auto grad_features_a = grad_features.accessor<float, 2>();
auto grad_alphas_a = grad_alphas.accessor<float, 4>();
// Iterate over the batch
for (int b = 0; b < B; ++b) {
// Iterate over the features
for (int c = 0; c < C; ++c) {
// Iterate through the horizontal lines of the image from top to bottom
for (int j = 0; j < H; ++j) {
// Iterate over pixels in a horizontal line, left to right
for (int i = 0; i < W; ++i) {
// Iterate through the closest K points for this pixel
for (int k = 0; k < K; ++k) {
int64_t n_idx = points_idx_a[b][k][j][i];
// Sentinel value is -1, indicating no point overlaps this pixel
if (n_idx < 0) {
continue;
}
float alpha = alphas_a[b][k][j][i];
grad_alphas_a[b][k][j][i] +=
grad_outputs_a[b][c][j][i] * features_a[c][n_idx];
grad_features_a[c][n_idx] += grad_outputs_a[b][c][j][i] * alpha;
}
}
}
}
}
return std::make_tuple(grad_features, grad_alphas);
}


@@ -1,6 +1,9 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include "compositing/alpha_composite.h"
#include "compositing/norm_weighted_sum.h"
#include "compositing/weighted_sum.h"
#include "face_areas_normals/face_areas_normals.h"
#include "gather_scatter/gather_scatter.h"
#include "nearest_neighbor_points/nearest_neighbor_points.h"
@@ -20,6 +23,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rasterize_meshes_backward", &RasterizeMeshesBackward);
m.def("rasterize_meshes", &RasterizeMeshes);
// Accumulation functions
m.def("accum_weightedsumnorm", &weightedSumNormForward);
m.def("accum_weightedsum", &weightedSumForward);
m.def("accum_alphacomposite", &alphaCompositeForward);
m.def("accum_weightedsumnorm_backward", &weightedSumNormBackward);
m.def("accum_weightedsum_backward", &weightedSumBackward);
m.def("accum_alphacomposite_backward", &alphaCompositeBackward);
// These are only visible for testing; users should not call them directly
m.def("_rasterize_points_coarse", &RasterizePointsCoarse);
m.def("_rasterize_points_naive", &RasterizePointsNaive);


@@ -34,6 +34,14 @@ from .mesh import (
phong_shading,
rasterize_meshes,
)
from .points import (
AlphaCompositor,
NormWeightedCompositor,
PointsRasterizationSettings,
PointsRasterizer,
PointsRenderer,
rasterize_points,
)
from .utils import TensorProperties, convert_to_tensors_and_broadcast
__all__ = [k for k in globals().keys() if not k.startswith("_")]


@@ -0,0 +1,255 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import NamedTuple
import torch
from pytorch3d import _C
# Example functions for blending the top K features per pixel using the outputs
# from rasterization.
# NOTE: All blending functions should return an (N, H, W, C) tensor per batch element.
# This can be an image (C=3) or a set of features.
# Data class to store blending params with defaults
class CompositeParams(NamedTuple):
radius: float = 4.0 / 256.0
class _CompositeAlphaPoints(torch.autograd.Function):
"""
Composite features within a z-buffer using alpha compositing. Given a zbuffer
with corresponding features and weights, these values are accumulated according
to their weights such that features nearer in depth contribute more to the final
feature than ones further away.
Concretely this means:
weighted_fs[b,c,i,j] = sum_k cum_alpha_k * features[c,pointsidx[b,k,i,j]]
cum_alpha_k = alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j])
Args:
features: Packed Tensor of shape (C, P) giving the features of each point.
alphas: float32 Tensor of shape (N, points_per_pixel, image_size,
image_size) giving the weight of each point in the z-buffer.
Values should be in the interval [0, 1].
pointsidx: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
giving the indices of the nearest points at each pixel, sorted in z-order.
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
This is weighted by alphas[n, k, y, x].
Returns:
weighted_fs: Tensor of shape (N, C, image_size, image_size)
giving the accumulated features at each pixel.
"""
@staticmethod
def forward(ctx, features, alphas, points_idx):
pt_cld = _C.accum_alphacomposite(features, alphas, points_idx)
ctx.save_for_backward(
features.clone(), alphas.clone(), points_idx.clone()
)
return pt_cld
@staticmethod
def backward(ctx, grad_output):
grad_features = None
grad_alphas = None
grad_points_idx = None
features, alphas, points_idx = ctx.saved_tensors
grad_features, grad_alphas = _C.accum_alphacomposite_backward(
grad_output, features, alphas, points_idx
)
return grad_features, grad_alphas, grad_points_idx, None
def alpha_composite(
pointsidx, alphas, pt_clds, blend_params=None
) -> torch.Tensor:
"""
Composite features within a z-buffer using alpha compositing. Given a zbuffer
with corresponding features and weights, these values are accumulated according
to their weights such that features nearer in depth contribute more to the final
feature than ones further away.
Concretely this means:
weighted_fs[b,c,i,j] = sum_k cum_alpha_k * features[c,pointsidx[b,k,i,j]]
cum_alpha_k = alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j])
Args:
pt_clds: Tensor of shape (C, P) giving the features of each point (can use RGB for example).
alphas: float32 Tensor of shape (N, points_per_pixel, image_size,
image_size) giving the weight of each point in the z-buffer.
Values should be in the interval [0, 1].
pointsidx: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
giving the indices of the nearest points at each pixel, sorted in z-order.
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
This is weighted by alphas[n, k, y, x].
Returns:
Combined features: Tensor of shape (N, C, image_size, image_size)
giving the accumulated features at each pixel.
"""
return _CompositeAlphaPoints.apply(pt_clds, alphas, pointsidx)
class _CompositeNormWeightedSumPoints(torch.autograd.Function):
"""
Composite features within a z-buffer using normalized weighted sum. Given a zbuffer
with corresponding features and weights, these values are accumulated
according to their weights such that depth is ignored; the weights are used to perform
a weighted sum.
Concretely this means:
weighted_fs[b,c,i,j] =
sum_k alphas[b,k,i,j] * features[c,pointsidx[b,k,i,j]] / sum_k alphas[b,k,i,j]
Args:
features: Packed Tensor of shape (C, P) giving the features of each point.
alphas: float32 Tensor of shape (N, points_per_pixel, image_size,
image_size) giving the weight of each point in the z-buffer.
Values should be in the interval [0, 1].
pointsidx: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
giving the indices of the nearest points at each pixel, sorted in z-order.
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
This is weighted by alphas[n, k, y, x].
Returns:
weighted_fs: Tensor of shape (N, C, image_size, image_size)
giving the accumulated features at each pixel.
"""
@staticmethod
def forward(ctx, features, alphas, points_idx):
pt_cld = _C.accum_weightedsumnorm(features, alphas, points_idx)
ctx.save_for_backward(
features.clone(), alphas.clone(), points_idx.clone()
)
return pt_cld
@staticmethod
def backward(ctx, grad_output):
grad_features = None
grad_alphas = None
grad_points_idx = None
features, alphas, points_idx = ctx.saved_tensors
grad_features, grad_alphas = _C.accum_weightedsumnorm_backward(
grad_output, features, alphas, points_idx
)
return grad_features, grad_alphas, grad_points_idx, None
def norm_weighted_sum(
pointsidx, alphas, pt_clds, blend_params=None
) -> torch.Tensor:
"""
Composite features within a z-buffer using normalized weighted sum. Given a zbuffer
with corresponding features and weights, these values are accumulated
according to their weights such that depth is ignored; the weights are used to perform
a weighted sum.
Concretely this means:
weighted_fs[b,c,i,j] =
sum_k alphas[b,k,i,j] * features[c,pointsidx[b,k,i,j]] / sum_k alphas[b,k,i,j]
Args:
pt_clds: Packed feature tensor of shape (C, P) giving the features of each point
(can use RGB for example).
alphas: float32 Tensor of shape (N, points_per_pixel, image_size,
image_size) giving the weight of each point in the z-buffer.
Values should be in the interval [0, 1].
pointsidx: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
giving the indices of the nearest points at each pixel, sorted in z-order.
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
This is weighted by alphas[n, k, y, x].
Returns:
Combined features: Tensor of shape (N, C, image_size, image_size)
giving the accumulated features at each pixel.
"""
return _CompositeNormWeightedSumPoints.apply(pt_clds, alphas, pointsidx)
class _CompositeWeightedSumPoints(torch.autograd.Function):
"""
Composite features within a z-buffer using a weighted sum. Given a zbuffer
with corresponding features and weights, these values are accumulated
according to their weights such that depth is ignored; the weights are used to
perform a weighted sum. As opposed to norm weighted sum, the weights are not
normalized to sum to 1.
Concretely this means:
weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] * features[c,pointsidx[b,k,i,j]]
Args:
features: Packed Tensor of shape (C, P) giving the features of each point.
alphas: float32 Tensor of shape (N, points_per_pixel, image_size,
image_size) giving the weight of each point in the z-buffer.
Values should be in the interval [0, 1].
pointsidx: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
giving the indices of the nearest points at each pixel, sorted in z-order.
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
This is weighted by alphas[n, k, y, x].
Returns:
weighted_fs: Tensor of shape (N, C, image_size, image_size)
giving the accumulated features at each pixel.
"""
@staticmethod
def forward(ctx, features, alphas, points_idx):
pt_cld = _C.accum_weightedsum(features, alphas, points_idx)
ctx.save_for_backward(
features.clone(), alphas.clone(), points_idx.clone()
)
return pt_cld
@staticmethod
def backward(ctx, grad_output):
grad_features = None
grad_alphas = None
grad_points_idx = None
features, alphas, points_idx = ctx.saved_tensors
grad_features, grad_alphas = _C.accum_weightedsum_backward(
grad_output, features, alphas, points_idx
)
return grad_features, grad_alphas, grad_points_idx, None
def weighted_sum(pointsidx, alphas, pt_clds, blend_params=None) -> torch.Tensor:
"""
Composite features within a z-buffer using a weighted sum (without normalizing the weights).
Args:
pt_clds: Packed Tensor of shape (C, P) giving the features of each point
(can use RGB for example).
alphas: float32 Tensor of shape (N, points_per_pixel, image_size,
image_size) giving the weight of each point in the z-buffer.
Values should be in the interval [0, 1].
pointsidx: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
giving the indices of the nearest points at each pixel, sorted in z-order.
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
This is weighted by alphas[n, k, y, x].
Returns:
Combined features: Tensor of shape (N, C, image_size, image_size)
giving the accumulated features at each pixel.
"""
return _CompositeWeightedSumPoints.apply(pt_clds, alphas, pointsidx)
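
A usage sketch (illustrative, not part of this file): calling the three functions above directly with toy tensors shaped as the docstrings describe; the sizes and random values below are assumptions.

import torch
from pytorch3d.renderer.compositing import (
    alpha_composite,
    norm_weighted_sum,
    weighted_sum,
)

C, P, N, K, S = 3, 5, 2, 4, 8  # feature dim, total points, batch, points_per_pixel, image size
features = torch.rand(C, P)  # packed per-point features
alphas = torch.rand(N, K, S, S)  # per-point weights in [0, 1]
points_idx = torch.randint(-1, P, (N, K, S, S))  # -1 marks pixels not covered by a point

out_alpha = alpha_composite(points_idx, alphas, features)  # (N, C, S, S)
out_norm = norm_weighted_sum(points_idx, alphas, features)  # (N, C, S, S)
out_sum = weighted_sum(points_idx, alphas, features)  # (N, C, S, S)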


@@ -1 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .compositor import AlphaCompositor, NormWeightedCompositor
from .rasterize_points import rasterize_points
from .rasterizer import PointsRasterizationSettings, PointsRasterizer
from .renderer import PointsRenderer
__all__ = [k for k in globals().keys() if not k.startswith("_")]


@@ -0,0 +1,51 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
from ..compositing import CompositeParams, alpha_composite, norm_weighted_sum
# A compositor should take as input 3D points and some corresponding information.
# Given this information, the compositor can:
# - blend colors across the top K vertices at a pixel
class AlphaCompositor(nn.Module):
"""
Accumulate points using alpha compositing.
"""
def __init__(self, composite_params=None):
super().__init__()
self.composite_params = (
composite_params
if composite_params is not None
else CompositeParams()
)
def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
images = alpha_composite(
fragments, alphas, ptclds, self.composite_params
)
return images
class NormWeightedCompositor(nn.Module):
"""
Accumulate points using a normalized weighted sum.
"""
def __init__(self, composite_params=None):
super().__init__()
self.composite_params = (
composite_params
if composite_params is not None
else CompositeParams()
)
def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
images = norm_weighted_sum(
fragments, alphas, ptclds, self.composite_params
)
return images


@@ -0,0 +1,103 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import NamedTuple, Optional
import torch
import torch.nn as nn
from ..cameras import get_world_to_view_transform
from .rasterize_points import rasterize_points
# Class to store the outputs of point rasterization
class PointFragments(NamedTuple):
idx: torch.Tensor
zbuf: torch.Tensor
dists: torch.Tensor
# Class to store the point rasterization params with defaults
class PointsRasterizationSettings(NamedTuple):
image_size: int = 256
radius: float = 0.01
points_per_pixel: int = 8
bin_size: Optional[int] = None
max_points_per_bin: Optional[int] = None
class PointsRasterizer(nn.Module):
"""
This class implements methods for rasterizing a batch of pointclouds.
"""
def __init__(self, cameras, raster_settings=None):
"""
cameras: A cameras object which has a `transform_points` method
which returns the transformed points after applying the
world-to-view and view-to-screen
transformations.
raster_settings: the parameters for rasterization. This should be a
named tuple.
All these initial settings can be overridden by passing keyword
arguments to the forward function.
"""
super().__init__()
if raster_settings is None:
raster_settings = PointsRasterizationSettings()
self.cameras = cameras
self.raster_settings = raster_settings
def transform(self, point_clouds, **kwargs) -> torch.Tensor:
"""
Args:
point_clouds: a set of point clouds
Returns:
points_screen: the points with the vertex positions in screen
space
NOTE: keeping this as a separate function for readability but it could
be moved into forward.
"""
cameras = kwargs.get("cameras", self.cameras)
pts_world = point_clouds.points_padded()
pts_world_packed = point_clouds.points_packed()
pts_screen = cameras.transform_points(pts_world, **kwargs)
# NOTE: Retaining view space z coordinate for now.
# TODO: Remove this line when the convention for the z coordinate in
# the rasterizer is decided. i.e. retain z in view space or transform
# to a different range.
view_transform = get_world_to_view_transform(R=cameras.R, T=cameras.T)
verts_view = view_transform.transform_points(pts_world)
pts_screen[..., 2] = verts_view[..., 2]
# Offset points of input pointcloud to reuse cached padded/packed calculations.
pad_to_packed_idx = point_clouds.padded_to_packed_idx()
pts_screen_packed = pts_screen.view(-1, 3)[pad_to_packed_idx, :]
pts_packed_offset = pts_screen_packed - pts_world_packed
point_clouds = point_clouds.offset(pts_packed_offset)
return point_clouds
def forward(self, point_clouds, **kwargs) -> PointFragments:
"""
Args:
point_clouds: a set of point clouds with coordinates in world space.
Returns:
PointFragments: Rasterization outputs as a named tuple.
"""
points_screen = self.transform(point_clouds, **kwargs)
raster_settings = kwargs.get("raster_settings", self.raster_settings)
idx, zbuf, dists2 = rasterize_points(
points_screen,
image_size=raster_settings.image_size,
radius=raster_settings.radius,
points_per_pixel=raster_settings.points_per_pixel,
bin_size=raster_settings.bin_size,
max_points_per_bin=raster_settings.max_points_per_bin,
)
return PointFragments(idx=idx, zbuf=zbuf, dists=dists2)


@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
# A renderer class should be initialized with a
# function for rasterization and a function for compositing.
# The rasterizer should:
# - transform inputs from world -> screen space
# - rasterize inputs
# - return fragments
# The compositor can take fragments as input along with any other properties of
# the scene and generate images.
# E.g. rasterize inputs and then shade
#
# fragments = self.rasterize(point_clouds)
# images = self.compositor(fragments, point_clouds)
# return images
class PointsRenderer(nn.Module):
"""
A class for rendering a batch of points. The class should be initialized
with a rasterizer and a compositor class, each of which has a forward
function.
"""
def __init__(self, rasterizer, compositor):
super().__init__()
self.rasterizer = rasterizer
self.compositor = compositor
def forward(self, point_clouds, **kwargs) -> torch.Tensor:
fragments = self.rasterizer(point_clouds, **kwargs)
# Construct weights based on the distance of each point to the true point.
# However, this could be done differently: e.g. predicted as opposed
# to a function of the distance.
r = self.rasterizer.raster_settings.radius
dists2 = fragments.dists.permute(0, 3, 1, 2)
weights = 1 - dists2 / (r * r)
images = self.compositor(
fragments.idx.long().permute(0, 3, 1, 2),
weights,
point_clouds.features_packed().permute(1, 0),
**kwargs
)
# permute so that the channel dimension comes last: (N, H, W, C)
images = images.permute(0, 2, 3, 1)
return images
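# Editorial sketch continuing the example above (not part of this commit). The
# AlphaCompositor wrapper is hypothetical; the renderer only requires a
# compositor whose forward accepts (points_idx, weights, features).
import torch.nn as nn
from pytorch3d.renderer.compositing import alpha_composite

class AlphaCompositor(nn.Module):
    # Thin module wrapper around the alpha_composite function from this diff.
    def forward(self, points_idx, weights, features, **kwargs):
        return alpha_composite(points_idx, weights, features)

renderer = PointsRenderer(rasterizer=rasterizer, compositor=AlphaCompositor())
images = renderer(point_clouds)  # (N, H, W, C), C = number of feature channels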

View File

@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .meshes import Meshes, join_meshes
from .pointclouds import Pointclouds
from .textures import Textures
from .utils import (
list_to_packed,

tests/test_compositing.py Normal file
View File

@ -0,0 +1,442 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
from pytorch3d.renderer.compositing import (
alpha_composite,
norm_weighted_sum,
weighted_sum,
)
class TestAccumulatePoints(unittest.TestCase):
# NAIVE PYTHON IMPLEMENTATIONS (USED FOR TESTING)
@staticmethod
def accumulate_alphacomposite_python(points_idx, alphas, features):
"""
Naive pure PyTorch implementation of alpha_composite.
Inputs / Outputs: same as the alpha_composite function.
"""
B, K, H, W = points_idx.size()
C = features.size(0)
output = torch.zeros(B, C, H, W, dtype=alphas.dtype)
for b in range(0, B):
for c in range(0, C):
for i in range(0, W):
for j in range(0, H):
t_alpha = 1
for k in range(0, K):
n_idx = points_idx[b, k, j, i]
if n_idx < 0:
continue
alpha = alphas[b, k, j, i]
output[b, c, j, i] += (
features[c, n_idx] * alpha * t_alpha
)
t_alpha = (1 - alpha) * t_alpha
return output
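# Editorial sketch (not part of this commit): a vectorized equivalent of the
# loop above, for intuition only. Padded entries (idx < 0) contribute nothing
# because their alphas are masked to zero before the cumulative product.
import torch

def alpha_composite_vectorized_sketch(points_idx, alphas, features):
    mask = (points_idx >= 0).to(alphas.dtype)  # (B, K, H, W)
    a = alphas * mask
    # Transmittance in front of layer k: product over m < k of (1 - a_m)
    trans = torch.cumprod(1 - a, dim=1)
    trans = torch.cat([torch.ones_like(trans[:, :1]), trans[:, :-1]], dim=1)
    gathered = features[:, points_idx.clamp(min=0)]  # (C, B, K, H, W)
    gathered = gathered.permute(1, 0, 2, 3, 4)  # (B, C, K, H, W)
    weights = (a * trans).unsqueeze(1)  # (B, 1, K, H, W)
    return (gathered * weights).sum(dim=2)  # (B, C, H, W)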
@staticmethod
def accumulate_weightedsum_python(points_idx, alphas, features):
"""
Naive pure PyTorch implementation of weighted_sum compositing.
Inputs / Outputs: same as the weighted_sum function.
"""
B, K, H, W = points_idx.size()
C = features.size(0)
output = torch.zeros(B, C, H, W, dtype=alphas.dtype)
for b in range(0, B):
for c in range(0, C):
for i in range(0, W):
for j in range(0, H):
for k in range(0, K):
n_idx = points_idx[b, k, j, i]
if n_idx < 0:
continue
alpha = alphas[b, k, j, i]
output[b, c, j, i] += features[c, n_idx] * alpha
return output
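# Editorial sketch (not part of this commit): vectorized form of the weighted
# sum above. Each visible point adds features * alpha; padded entries (idx < 0)
# are masked out.
import torch

def weighted_sum_vectorized_sketch(points_idx, alphas, features):
    mask = (points_idx >= 0).to(alphas.dtype)  # (B, K, H, W)
    w = (alphas * mask).unsqueeze(1)  # (B, 1, K, H, W)
    gathered = features[:, points_idx.clamp(min=0)]  # (C, B, K, H, W)
    gathered = gathered.permute(1, 0, 2, 3, 4)  # (B, C, K, H, W)
    return (gathered * w).sum(dim=2)  # (B, C, H, W)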
@staticmethod
def accumulate_weightedsumnorm_python(points_idx, alphas, features):
"""
Naive pure PyTorch implementation of norm_weighted_sum.
Inputs / Outputs: same as the norm_weighted_sum function.
"""
B, K, H, W = points_idx.size()
C = features.size(0)
output = torch.zeros(B, C, H, W, dtype=alphas.dtype)
for b in range(0, B):
for c in range(0, C):
for i in range(0, W):
for j in range(0, H):
t_alpha = 0
for k in range(0, K):
n_idx = points_idx[b, k, j, i]
if n_idx < 0:
continue
t_alpha += alphas[b, k, j, i]
t_alpha = max(t_alpha, 1e-4)
for k in range(0, K):
n_idx = points_idx[b, k, j, i]
if n_idx < 0:
continue
alpha = alphas[b, k, j, i]
output[b, c, j, i] += (
features[c, n_idx] * alpha / t_alpha
)
return output
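# Editorial sketch (not part of this commit): the normalized variant only
# differs from the weighted sum by dividing each pixel's weights by the
# clamped sum of its visible alphas.
import torch

def norm_weighted_sum_vectorized_sketch(points_idx, alphas, features):
    mask = (points_idx >= 0).to(alphas.dtype)
    a = alphas * mask
    denom = a.sum(dim=1, keepdim=True).clamp(min=1e-4)  # (B, 1, H, W)
    w = (a / denom).unsqueeze(1)  # (B, 1, K, H, W)
    gathered = features[:, points_idx.clamp(min=0)].permute(1, 0, 2, 3, 4)
    return (gathered * w).sum(dim=2)  # (B, C, H, W)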
def test_python(self):
device = torch.device("cpu")
self._simple_alphacomposite(
self.accumulate_alphacomposite_python, device
)
self._simple_wsum(self.accumulate_weightedsum_python, device)
self._simple_wsumnorm(self.accumulate_weightedsumnorm_python, device)
def test_cpu(self):
device = torch.device("cpu")
self._simple_alphacomposite(alpha_composite, device)
self._simple_wsum(weighted_sum, device)
self._simple_wsumnorm(norm_weighted_sum, device)
def test_cuda(self):
device = torch.device("cuda:0")
self._simple_alphacomposite(alpha_composite, device)
self._simple_wsum(weighted_sum, device)
self._simple_wsumnorm(norm_weighted_sum, device)
def test_python_vs_cpu_vs_cuda(self):
self._python_vs_cpu_vs_cuda(
self.accumulate_alphacomposite_python, alpha_composite
)
self._python_vs_cpu_vs_cuda(
self.accumulate_weightedsumnorm_python, norm_weighted_sum
)
self._python_vs_cpu_vs_cuda(
self.accumulate_weightedsum_python, weighted_sum
)
def _python_vs_cpu_vs_cuda(self, accumulate_func_python, accumulate_func):
torch.manual_seed(231)
device = torch.device("cpu")
W = 8
C = 3
P = 32
for d in ["cpu", "cuda"]:
# TODO(gkioxari) add torch.float64 to types after double precision
# support is added to atomicAdd
for t in [torch.float32]:
device = torch.device(d)
# Create values
alphas = torch.rand(2, 4, W, W, dtype=t).to(device)
alphas.requires_grad = True
alphas_cpu = alphas.detach().cpu()
alphas_cpu.requires_grad = True
features = torch.randn(C, P, dtype=t).to(device)
features.requires_grad = True
features_cpu = features.detach().cpu()
features_cpu.requires_grad = True
inds = torch.randint(P + 1, size=(2, 4, W, W)).to(device) - 1
inds_cpu = inds.detach().cpu()
args_cuda = (inds, alphas, features)
args_cpu = (inds_cpu, alphas_cpu, features_cpu)
self._compare_impls(
accumulate_func_python,
accumulate_func,
args_cpu,
args_cuda,
(alphas_cpu, features_cpu),
(alphas, features),
compare_grads=True,
)
def _compare_impls(
self, fn1, fn2, args1, args2, grads1, grads2, compare_grads=False
):
res1 = fn1(*args1)
res2 = fn2(*args2)
self.assertTrue(torch.allclose(res1.cpu(), res2.cpu(), atol=1e-6))
if not compare_grads:
return
# Compare gradients
torch.manual_seed(231)
grad_res = torch.randn_like(res1)
loss1 = (res1 * grad_res).sum()
loss1.backward()
grads1 = [gradsi.grad.data.clone().cpu() for gradsi in grads1]
grad_res = grad_res.to(res2)
loss2 = (res2 * grad_res).sum()
loss2.backward()
grads2 = [gradsi.grad.data.clone().cpu() for gradsi in grads2]
for i in range(0, len(grads1)):
self.assertTrue(
torch.allclose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6)
)
def _simple_wsum(self, accum_func, device):
# Initialise variables
features = torch.Tensor(
[[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]
).to(device)
alphas = torch.Tensor(
[
[
[
[0.5, 0.5, 0.5, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 0.5, 0.5, 0.5],
],
[
[0.5, 0.5, 0.5, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 0.5, 0.5, 0.5],
],
]
]
).to(device)
points_idx = (
torch.Tensor(
[
[
# fmt: off
[
[0, 0, 0, 0], # noqa: E241, E201
[0, -1, -1, -1], # noqa: E241, E201
[0, 1, 1, 0], # noqa: E241, E201
[0, 0, 0, 0], # noqa: E241, E201
],
[
[2, 2, 2, 2], # noqa: E241, E201
[2, 3, 3, 2], # noqa: E241, E201
[2, 3, 3, 2], # noqa: E241, E201
[2, 2, -1, 2], # noqa: E241, E201
],
# fmt: on
]
]
)
.long()
.to(device)
)
result = accum_func(points_idx, alphas, features)
self.assertTrue(result.shape == (1, 2, 4, 4))
true_result = torch.Tensor(
[
[
[
[0.35, 0.35, 0.35, 0.35],
[0.35, 0.90, 0.90, 0.30],
[0.35, 1.30, 1.30, 0.35],
[0.35, 0.35, 0.05, 0.35],
],
[
[0.35, 0.35, 0.35, 0.35],
[0.35, 0.90, 0.90, 0.30],
[0.35, 1.30, 1.30, 0.35],
[0.35, 0.35, 0.05, 0.35],
],
]
]
).to(device)
self.assertTrue(
torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3)
)
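# Editorial hand check (not part of the test): two entries of true_result
# above, channel 0, with features [0.1, 0.4, 0.6, 0.9].
#   pixel (j=0, i=0): idx 0 and 2 with alphas 0.5 and 0.5 -> 0.1*0.5 + 0.6*0.5 = 0.35
#   pixel (j=1, i=1): idx -1 is skipped, then idx 3 with alpha 1.0 -> 0.90
assert abs(0.1 * 0.5 + 0.6 * 0.5 - 0.35) < 1e-6
assert abs(0.9 * 1.0 - 0.90) < 1e-6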
def _simple_wsumnorm(self, accum_func, device):
# Initialise variables
features = torch.Tensor(
[[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]
).to(device)
alphas = torch.Tensor(
[
[
[
[0.5, 0.5, 0.5, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 0.5, 0.5, 0.5],
],
[
[0.5, 0.5, 0.5, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 0.5, 0.5, 0.5],
],
]
]
).to(device)
# fmt: off
points_idx = (
torch.Tensor(
[
[
[
[0, 0, 0, 0], # noqa: E241, E201
[0, -1, -1, -1], # noqa: E241, E201
[0, 1, 1, 0], # noqa: E241, E201
[0, 0, 0, 0], # noqa: E241, E201
],
[
[2, 2, 2, 2], # noqa: E241, E201
[2, 3, 3, 2], # noqa: E241, E201
[2, 3, 3, 2], # noqa: E241, E201
[2, 2, -1, 2], # noqa: E241, E201
],
]
]
)
.long()
.to(device)
)
# fmt: on
result = accum_func(points_idx, alphas, features)
self.assertTrue(result.shape == (1, 2, 4, 4))
true_result = torch.Tensor(
[
[
[
[0.35, 0.35, 0.35, 0.35],
[0.35, 0.90, 0.90, 0.60],
[0.35, 0.65, 0.65, 0.35],
[0.35, 0.35, 0.10, 0.35],
],
[
[0.35, 0.35, 0.35, 0.35],
[0.35, 0.90, 0.90, 0.60],
[0.35, 0.65, 0.65, 0.35],
[0.35, 0.35, 0.10, 0.35],
],
]
]
).to(device)
self.assertTrue(
torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3)
)
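# Editorial hand check (not part of the test): the normalization divides by the
# clamped sum of visible alphas, channel 0.
#   pixel (j=2, i=1): both layers visible with alpha 1.0 -> (0.4 + 0.9) / 2 = 0.65
#   pixel (j=3, i=2): only idx 0 visible with alpha 0.5 -> 0.1 * 0.5 / 0.5 = 0.10
assert abs((0.4 + 0.9) / 2.0 - 0.65) < 1e-6
assert abs(0.1 * 0.5 / 0.5 - 0.10) < 1e-6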
def _simple_alphacomposite(self, accum_func, device):
# Initialise variables
features = torch.Tensor(
[[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]
).to(device)
alphas = torch.Tensor(
[
[
[
[0.5, 0.5, 0.5, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 0.5, 0.5, 0.5],
],
[
[0.5, 0.5, 0.5, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5],
[0.5, 0.5, 0.5, 0.5],
],
]
]
).to(device)
# fmt: off
points_idx = (
torch.Tensor(
[
[
[
[0, 0, 0, 0], # noqa: E241, E201
[0, -1, -1, -1], # noqa: E241, E201
[0, 1, 1, 0], # noqa: E241, E201
[0, 0, 0, 0], # noqa: E241, E201
],
[
[2, 2, 2, 2], # noqa: E241, E201
[2, 3, 3, 2], # noqa: E241, E201
[2, 3, 3, 2], # noqa: E241, E201
[2, 2, -1, 2], # noqa: E241, E201
],
]
]
)
.long()
.to(device)
)
# fmt: on
result = accum_func(points_idx, alphas, features)
self.assertTrue(result.shape == (1, 2, 4, 4))
true_result = torch.Tensor(
[
[
[
[0.20, 0.20, 0.20, 0.20],
[0.20, 0.90, 0.90, 0.30],
[0.20, 0.40, 0.40, 0.20],
[0.20, 0.20, 0.05, 0.20],
],
[
[0.20, 0.20, 0.20, 0.20],
[0.20, 0.90, 0.90, 0.30],
[0.20, 0.40, 0.40, 0.20],
[0.20, 0.20, 0.05, 0.20],
],
]
]
).to(device)
self.assertTrue(
torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3)
)
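# Editorial hand check (not part of the test): alpha compositing accumulates
# transmittance front to back, channel 0.
#   pixel (j=0, i=0): 0.1*0.5 + 0.6*0.5*(1 - 0.5) = 0.20
#   pixel (j=2, i=1): the first layer has alpha 1.0, so the second layer is
#   fully occluded -> 0.4*1.0 = 0.40
assert abs(0.1 * 0.5 + 0.6 * 0.5 * (1 - 0.5) - 0.20) < 1e-6
assert abs(0.4 * 1.0 - 0.40) < 1e-6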

View File

@ -76,7 +76,10 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
verts_torch = verts.detach().clone().to(dtype)
verts_torch.requires_grad = True
faces_torch = faces.detach().clone()
areas_torch, normals_torch = TestFaceAreasNormals.face_areas_normals_python(
(
areas_torch,
normals_torch,
) = TestFaceAreasNormals.face_areas_normals_python(
verts_torch, faces_torch
)
self.assertClose(areas_torch, areas, atol=1e-7)