mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00

Accumulate points (#4)

Summary: Code for accumulating points in the z-buffer in three ways:
1. weighted sum
2. normalised weighted sum
3. alpha compositing

Pull Request resolved: https://github.com/fairinternal/pytorch3d/pull/4

Reviewed By: nikhilaravi

Differential Revision: D20522422

Pulled By: gkioxari

fbshipit-source-id: 5023baa05f15e338f3821ef08f5552c2dcbfc06c
parent 5218f45c2c
commit 53599770dd
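For reference, the three accumulation rules this commit adds, written per pixel over the K z-sorted points in the buffer (notation follows the header files below: $f_{c,k}$ is the feature of the k-th nearest point at the pixel and $\alpha_k$ its weight):

$$w^{\text{sum}}_c = \sum_k \alpha_k f_{c,k}, \qquad w^{\text{norm}}_c = \frac{\sum_k \alpha_k f_{c,k}}{\sum_k \alpha_k}, \qquad w^{\text{alpha}}_c = \sum_k \alpha_k f_{c,k} \prod_{l<k} (1 - \alpha_l).$$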
@@ -673,9 +673,9 @@
   "provenance": []
  },
  "kernelspec": {
-  "display_name": "pytorch3d (local)",
+  "display_name": "Python 3",
   "language": "python",
-  "name": "pytorch3d_local"
+  "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
@@ -687,7 +687,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.7.5+"
+  "version": "3.7.6"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
303 docs/tutorials/render_coloured_points.ipynb Normal file
File diff suppressed because one or more lines are too long
187 pytorch3d/csrc/compositing/alpha_composite.cu Normal file
@@ -0,0 +1,187 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <stdio.h>
#include <vector>

// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void alphaCompositeCudaForwardKernel(
    // clang-format off
    torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> result,
    const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
    const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
  // clang-format on
  const int64_t batch_size = result.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
  const int batch = blockIdx.x;

  const int num_pixels = C * W * H;
  const int num_threads = gridDim.y * blockDim.x;
  const int tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Iterate over each feature in each pixel
  for (int pid = tid; pid < num_pixels; pid += num_threads) {
    int ch = pid / (W * H);
    int j = (pid % (W * H)) / H;
    int i = (pid % (W * H)) % H;

    // alphacomposite the different values
    float cum_alpha = 1.;
    // Iterate through the closest K points for this pixel
    for (int k = 0; k < points_idx.size(1); ++k) {
      int n_idx = points_idx[batch][k][j][i];

      // Sentinel value is -1 indicating no point overlaps the pixel
      if (n_idx < 0) {
        continue;
      }

      float alpha = alphas[batch][k][j][i];
      // TODO(gkioxari) It might be more efficient to have threads write in a
      // local variable, and move atomicAdd outside of the loop such that
      // atomicAdd is executed once per thread.
      atomicAdd(
          &result[batch][ch][j][i], features[ch][n_idx] * cum_alpha * alpha);
      cum_alpha = cum_alpha * (1 - alpha);
    }
  }
}

// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void alphaCompositeCudaBackwardKernel(
    // clang-format off
    torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> grad_features,
    torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_alphas,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_outputs,
    const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
    const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
  // clang-format on
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
  const int batch = blockIdx.x;

  const int num_pixels = C * W * H;
  const int num_threads = gridDim.y * blockDim.x;
  const int tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Parallelize over each feature in each pixel in images of size H * W,
  // for each image in the batch of size batch_size
  for (int pid = tid; pid < num_pixels; pid += num_threads) {
    int ch = pid / (W * H);
    int j = (pid % (W * H)) / H;
    int i = (pid % (W * H)) % H;

    // alphacomposite the different values
    float cum_alpha = 1.;
    // Iterate through the closest K points for this pixel
    for (int k = 0; k < points_idx.size(1); ++k) {
      int n_idx = points_idx[batch][k][j][i];

      // Sentinel value is -1 indicating no point overlaps the pixel
      if (n_idx < 0) {
        continue;
      }
      float alpha = alphas[batch][k][j][i];

      // TODO(gkioxari) It might be more efficient to have threads write in a
      // local variable, and move atomicAdd outside of the loop such that
      // atomicAdd is executed once per thread.
      atomicAdd(
          &grad_alphas[batch][k][j][i],
          cum_alpha * features[ch][n_idx] * grad_outputs[batch][ch][j][i]);
      atomicAdd(
          &grad_features[ch][n_idx],
          cum_alpha * alpha * grad_outputs[batch][ch][j][i]);

      // Iterate over all (K-1) nearest points to update gradient
      for (int t = 0; t < k; ++t) {
        int t_idx = points_idx[batch][t][j][i];
        // Sentinel value is -1, indicating no point overlaps this pixel
        if (t_idx < 0) {
          continue;
        }
        float alpha_tvalue = alphas[batch][t][j][i];
        // TODO(gkioxari) It might be more efficient to have threads write in a
        // local variable, and move atomicAdd outside of the loop such that
        // atomicAdd is executed once per thread.
        atomicAdd(
            &grad_alphas[batch][t][j][i],
            -grad_outputs[batch][ch][j][i] * features[ch][n_idx] * cum_alpha *
                alpha / (1 - alpha_tvalue));
      }

      cum_alpha = cum_alpha * (1 - alphas[batch][k][j][i]);
    }
  }
}

torch::Tensor alphaCompositeCudaForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  auto result = torch::zeros({batch_size, C, H, W}, features.options());

  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(batch_size, 1024 / batch_size + 1);

  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
  alphaCompositeCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
      // clang-format off
      result.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
  // clang-format on

  return result;
}

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  auto grad_features = torch::zeros_like(features);
  auto grad_alphas = torch::zeros_like(alphas);

  const int64_t bs = alphas.size(0);

  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(bs, 1024 / bs + 1);

  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
  alphaCompositeCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
      // clang-format off
      grad_features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      grad_alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      grad_outputs.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
  // clang-format on

  return std::make_tuple(grad_features, grad_alphas);
}
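For readers who prefer Python, here is a pure-PyTorch loop that computes the same quantity as the forward kernel above. This is a reference sketch for intuition only (the function name is ours, and it ignores the batch/thread decomposition); it is not part of the library.

import torch

def alpha_composite_reference(features, alphas, points_idx):
    # features: (C, P); alphas, points_idx: (N, K, H, W), z-sorted along K
    N, K, H, W = points_idx.shape
    C = features.shape[0]
    result = torch.zeros(N, C, H, W)
    for b in range(N):
        for j in range(H):
            for i in range(W):
                cum_alpha = 1.0  # transmittance left over from closer points
                for k in range(K):
                    n_idx = points_idx[b, k, j, i].item()
                    if n_idx < 0:  # sentinel: slot k is empty
                        continue
                    alpha = alphas[b, k, j, i].item()
                    result[b, :, j, i] += cum_alpha * alpha * features[:, n_idx]
                    cum_alpha *= 1.0 - alpha
    return result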
110 pytorch3d/csrc/compositing/alpha_composite.h Normal file
@@ -0,0 +1,110 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>
#include "pytorch3d_cutils.h"

#include <vector>

// Perform alpha compositing of points in a z-buffer.
//
// Inputs:
//     features: FloatTensor of shape (C, P) which gives the features
//            of each point where C is the size of the feature and
//            P the number of points.
//     alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
//            points_per_pixel is the number of points in the z-buffer
//            sorted in z-order, and W is the image size.
//     points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
//            indices of the nearest points at each pixel, sorted in z-order.
// Returns:
//     weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
//            feature for each point. Concretely, it gives:
//                 weighted_fs[b,c,i,j] = sum_k cum_alpha_k *
//                    features[c,points_idx[b,k,i,j]]
//                 where cum_alpha_k =
//                    alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j])

// CUDA declarations
#ifdef WITH_CUDA
torch::Tensor alphaCompositeCudaForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCudaBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);
#endif

// C++ declarations
torch::Tensor alphaCompositeCpuForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCpuBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

torch::Tensor alphaCompositeForward(
    torch::Tensor& features,
    torch::Tensor& alphas,
    torch::Tensor& points_idx) {
  features = features.contiguous();
  alphas = alphas.contiguous();
  points_idx = points_idx.contiguous();

  if (features.type().is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CONTIGUOUS_CUDA(features);
    CHECK_CONTIGUOUS_CUDA(alphas);
    CHECK_CONTIGUOUS_CUDA(points_idx);
#else
    AT_ERROR("Not compiled with GPU support");
#endif
    return alphaCompositeCudaForward(features, alphas, points_idx);
  } else {
    CHECK_CONTIGUOUS(features);
    CHECK_CONTIGUOUS(alphas);
    CHECK_CONTIGUOUS(points_idx);

    return alphaCompositeCpuForward(features, alphas, points_idx);
  }
}

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeBackward(
    torch::Tensor& grad_outputs,
    torch::Tensor& features,
    torch::Tensor& alphas,
    torch::Tensor& points_idx) {
  grad_outputs = grad_outputs.contiguous();
  features = features.contiguous();
  alphas = alphas.contiguous();
  points_idx = points_idx.contiguous();

  if (grad_outputs.type().is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CONTIGUOUS_CUDA(grad_outputs);
    CHECK_CONTIGUOUS_CUDA(features);
    CHECK_CONTIGUOUS_CUDA(alphas);
    CHECK_CONTIGUOUS_CUDA(points_idx);
#else
    AT_ERROR("Not compiled with GPU support");
#endif

    return alphaCompositeCudaBackward(
        grad_outputs, features, alphas, points_idx);
  } else {
    CHECK_CONTIGUOUS(grad_outputs);
    CHECK_CONTIGUOUS(features);
    CHECK_CONTIGUOUS(alphas);
    CHECK_CONTIGUOUS(points_idx);

    return alphaCompositeCpuBackward(
        grad_outputs, features, alphas, points_idx);
  }
}
114 pytorch3d/csrc/compositing/alpha_composite_cpu.cpp Normal file
@@ -0,0 +1,114 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>

#include <cmath>
#include <vector>

torch::Tensor alphaCompositeCpuForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  const int64_t B = points_idx.size(0);
  const int64_t K = points_idx.size(1);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);
  const int64_t C = features.size(0);

  torch::Tensor result = torch::zeros({B, C, H, W}, features.options());

  auto features_a = features.accessor<float, 2>();
  auto alphas_a = alphas.accessor<float, 4>();
  auto points_idx_a = points_idx.accessor<int64_t, 4>();
  auto result_a = result.accessor<float, 4>();

  // Iterate over the batch
  for (int b = 0; b < B; ++b) {
    // Iterate over the features
    for (int c = 0; c < C; ++c) {
      // Iterate through the horizontal lines of the image from top to bottom
      for (int j = 0; j < H; ++j) {
        // Iterate over pixels in a horizontal line, left to right
        for (int i = 0; i < W; ++i) {
          float cum_alpha = 1.;
          // Iterate through the closest K points for this pixel
          for (int k = 0; k < K; ++k) {
            int64_t n_idx = points_idx_a[b][k][j][i];
            // Sentinel value is -1 indicating no point overlaps the pixel
            if (n_idx < 0) {
              continue;
            }
            float alpha = alphas_a[b][k][j][i];
            result_a[b][c][j][i] += cum_alpha * alpha * features_a[c][n_idx];
            cum_alpha = cum_alpha * (1 - alpha);
          }
        }
      }
    }
  }
  return result;
}

std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCpuBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  torch::Tensor grad_features = torch::zeros_like(features);
  torch::Tensor grad_alphas = torch::zeros_like(alphas);

  const int64_t B = points_idx.size(0);
  const int64_t K = points_idx.size(1);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);
  const int64_t C = features.size(0);

  auto grad_outputs_a = grad_outputs.accessor<float, 4>();
  auto features_a = features.accessor<float, 2>();
  auto alphas_a = alphas.accessor<float, 4>();
  auto points_idx_a = points_idx.accessor<int64_t, 4>();
  auto grad_features_a = grad_features.accessor<float, 2>();
  auto grad_alphas_a = grad_alphas.accessor<float, 4>();

  // Iterate over the batch
  for (int b = 0; b < B; ++b) {
    // Iterate over the features
    for (int c = 0; c < C; ++c) {
      // Iterate through the horizontal lines of the image from top to bottom
      for (int j = 0; j < H; ++j) {
        // Iterate over pixels in a horizontal line, left to right
        for (int i = 0; i < W; ++i) {
          float cum_alpha = 1.;
          // Iterate through the closest K points for this pixel
          for (int k = 0; k < K; ++k) {
            int64_t n_idx = points_idx_a[b][k][j][i];
            // Sentinel value is -1, indicating no point overlaps this pixel
            if (n_idx < 0) {
              continue;
            }
            float alpha = alphas_a[b][k][j][i];
            grad_alphas_a[b][k][j][i] +=
                grad_outputs_a[b][c][j][i] * features_a[c][n_idx] * cum_alpha;
            grad_features_a[c][n_idx] +=
                grad_outputs_a[b][c][j][i] * cum_alpha * alpha;

            // Iterate over all (K-1) nearer points to update gradient
            for (int t = 0; t < k; t++) {
              int64_t t_idx = points_idx_a[b][t][j][i];
              // Sentinel value is -1, indicating no point overlaps this pixel
              if (t_idx < 0) {
                continue;
              }
              float alpha_tvalue = alphas_a[b][t][j][i];
              grad_alphas_a[b][t][j][i] -= grad_outputs_a[b][c][j][i] *
                  features_a[c][n_idx] * cum_alpha * alpha / (1 - alpha_tvalue);
            }

            cum_alpha = cum_alpha * (1 - alpha);
          }
        }
      }
    }
  }
  return std::make_tuple(grad_features, grad_alphas);
}
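The inner loop over t < k in both backward implementations comes from differentiating the compositing sum with respect to an earlier alpha. A sketch of the derivation, writing $w_c$ for weighted_fs[b,c,i,j]:

$$w_c = \sum_k \alpha_k f_{c,k} \prod_{l<k}(1-\alpha_l) \;\Longrightarrow\; \frac{\partial w_c}{\partial \alpha_t} = f_{c,t} \prod_{l<t}(1-\alpha_l) \;-\; \sum_{k>t} \frac{\alpha_k f_{c,k}}{1-\alpha_t} \prod_{l<k}(1-\alpha_l).$$

The first term is the direct update applied at k = t; each summand in the second term is exactly the `-cum_alpha * alpha / (1 - alpha_tvalue)` correction accumulated inside the t loop.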
202 pytorch3d/csrc/compositing/norm_weighted_sum.cu Normal file
@@ -0,0 +1,202 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <stdio.h>
#include <vector>

__constant__ const float kEpsilon = 1e-4;

// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void weightedSumNormCudaForwardKernel(
    // clang-format off
    torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> result,
    const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
    const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
  // clang-format on
  const int64_t batch_size = result.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
  const int batch = blockIdx.x;

  const int num_pixels = C * W * H;
  const int num_threads = gridDim.y * blockDim.x;
  const int tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Parallelize over each feature in each pixel in images of size H * W,
  // for each image in the batch of size batch_size
  for (int pid = tid; pid < num_pixels; pid += num_threads) {
    int ch = pid / (W * H);
    int j = (pid % (W * H)) / H;
    int i = (pid % (W * H)) % H;

    // Store the accumulated alpha value
    float cum_alpha = 0.;
    // Iterate through the closest K points for this pixel
    for (int k = 0; k < points_idx.size(1); ++k) {
      int n_idx = points_idx[batch][k][j][i];
      // Sentinel value is -1 indicating no point overlaps the pixel
      if (n_idx < 0) {
        continue;
      }

      cum_alpha += alphas[batch][k][j][i];
    }

    if (cum_alpha < kEpsilon) {
      cum_alpha = kEpsilon;
    }

    // Iterate through the closest K points for this pixel
    for (int k = 0; k < points_idx.size(1); ++k) {
      int n_idx = points_idx[batch][k][j][i];
      // Sentinel value is -1 indicating no point overlaps the pixel
      if (n_idx < 0) {
        continue;
      }
      float alpha = alphas[batch][k][j][i];
      // TODO(gkioxari) It might be more efficient to have threads write in a
      // local variable, and move atomicAdd outside of the loop such that
      // atomicAdd is executed once per thread.
      atomicAdd(
          &result[batch][ch][j][i], features[ch][n_idx] * alpha / cum_alpha);
    }
  }
}

// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void weightedSumNormCudaBackwardKernel(
    // clang-format off
    torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> grad_features,
    torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_alphas,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_outputs,
    const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
    const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
  // clang-format on
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
  const int batch = blockIdx.x;

  const int num_pixels = C * W * H;
  const int num_threads = gridDim.y * blockDim.x;
  const int tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Parallelize over each feature in each pixel in images of size H * W,
  // for each image in the batch of size batch_size
  for (int pid = tid; pid < num_pixels; pid += num_threads) {
    int ch = pid / (W * H);
    int j = (pid % (W * H)) / H;
    int i = (pid % (W * H)) % H;

    float sum_alpha = 0.;
    float sum_alphafs = 0.;
    // Iterate through the closest K points for this pixel to calculate the
    // cumulative sum of the alphas for this pixel
    for (int k = 0; k < points_idx.size(1); ++k) {
      int n_idx = points_idx[batch][k][j][i];
      // Sentinel value is -1 indicating no point overlaps the pixel
      if (n_idx < 0) {
        continue;
      }

      sum_alpha += alphas[batch][k][j][i];
      sum_alphafs += alphas[batch][k][j][i] * features[ch][n_idx];
    }

    if (sum_alpha < kEpsilon) {
      sum_alpha = kEpsilon;
    }

    // Iterate again through the closest K points for this pixel to calculate
    // the gradient.
    for (int k = 0; k < points_idx.size(1); ++k) {
      int n_idx = points_idx[batch][k][j][i];

      // Sentinel value is -1 indicating no point overlaps the pixel
      if (n_idx < 0) {
        continue;
      }
      float alpha = alphas[batch][k][j][i];

      // TODO(gkioxari) It might be more efficient to have threads write in a
      // local variable, and move atomicAdd outside of the loop such that
      // atomicAdd is executed once per thread.
      atomicAdd(
          &grad_alphas[batch][k][j][i],
          (features[ch][n_idx] * sum_alpha - sum_alphafs) /
              (sum_alpha * sum_alpha) * grad_outputs[batch][ch][j][i]);
      atomicAdd(
          &grad_features[ch][n_idx],
          alpha * grad_outputs[batch][ch][j][i] / sum_alpha);
    }
  }
}

torch::Tensor weightedSumNormCudaForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  auto result = torch::zeros({batch_size, C, H, W}, features.options());

  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(batch_size, 1024 / batch_size + 1);

  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
  // clang-format off
  weightedSumNormCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
      result.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
  // clang-format on

  return result;
}

std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCudaBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  auto grad_features = torch::zeros_like(features);
  auto grad_alphas = torch::zeros_like(alphas);

  const int64_t bs = points_idx.size(0);

  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(bs, 1024 / bs + 1);

  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
  weightedSumNormCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
      // clang-format off
      grad_features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      grad_alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      grad_outputs.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
  // clang-format on

  return std::make_tuple(grad_features, grad_alphas);
}
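Because the normalized weighted sum has no ordering dependence, the forward pass also vectorizes cleanly. A pure-PyTorch sketch of the same computation (the helper name is ours; it assumes the shapes documented in the header below and applies the same epsilon clamp):

import torch

def norm_weighted_sum_reference(features, alphas, points_idx, eps=1e-4):
    # features: (C, P); alphas, points_idx: (N, K, H, W)
    mask = (points_idx >= 0).float()            # zero out empty (-1) slots
    idx = points_idx.clamp(min=0)               # make the gather index valid
    gathered = features[:, idx]                 # (C, N, K, H, W)
    gathered = gathered.permute(1, 0, 2, 3, 4)  # (N, C, K, H, W)
    w = (alphas * mask).unsqueeze(1)            # (N, 1, K, H, W)
    denom = w.sum(dim=2).clamp(min=eps)         # (N, 1, H, W)
    return (w * gathered).sum(dim=2) / denom    # (N, C, H, W)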
109 pytorch3d/csrc/compositing/norm_weighted_sum.h Normal file
@@ -0,0 +1,109 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>
#include "pytorch3d_cutils.h"

#include <vector>

// Perform normalized weighted sum compositing of points in a z-buffer.
//
// Inputs:
//     features: FloatTensor of shape (C, P) which gives the features
//            of each point where C is the size of the feature and
//            P the number of points.
//     alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
//            points_per_pixel is the number of points in the z-buffer
//            sorted in z-order, and W is the image size.
//     points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
//            indices of the nearest points at each pixel, sorted in z-order.
// Returns:
//     weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
//            feature in each point. Concretely, it gives:
//                 weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] *
//                    features[c,points_idx[b,k,i,j]] / sum_k alphas[b,k,i,j]

// CUDA declarations
#ifdef WITH_CUDA
torch::Tensor weightedSumNormCudaForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCudaBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);
#endif

// C++ declarations
torch::Tensor weightedSumNormCpuForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCpuBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

torch::Tensor weightedSumNormForward(
    torch::Tensor& features,
    torch::Tensor& alphas,
    torch::Tensor& points_idx) {
  features = features.contiguous();
  alphas = alphas.contiguous();
  points_idx = points_idx.contiguous();

  if (features.type().is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CONTIGUOUS_CUDA(features);
    CHECK_CONTIGUOUS_CUDA(alphas);
    CHECK_CONTIGUOUS_CUDA(points_idx);
#else
    AT_ERROR("Not compiled with GPU support");
#endif

    return weightedSumNormCudaForward(features, alphas, points_idx);
  } else {
    CHECK_CONTIGUOUS(features);
    CHECK_CONTIGUOUS(alphas);
    CHECK_CONTIGUOUS(points_idx);

    return weightedSumNormCpuForward(features, alphas, points_idx);
  }
}

std::tuple<torch::Tensor, torch::Tensor> weightedSumNormBackward(
    torch::Tensor& grad_outputs,
    torch::Tensor& features,
    torch::Tensor& alphas,
    torch::Tensor& points_idx) {
  grad_outputs = grad_outputs.contiguous();
  features = features.contiguous();
  alphas = alphas.contiguous();
  points_idx = points_idx.contiguous();

  if (grad_outputs.type().is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CONTIGUOUS_CUDA(grad_outputs);
    CHECK_CONTIGUOUS_CUDA(features);
    CHECK_CONTIGUOUS_CUDA(alphas);
    CHECK_CONTIGUOUS_CUDA(points_idx);
#else
    AT_ERROR("Not compiled with GPU support");
#endif

    return weightedSumNormCudaBackward(
        grad_outputs, features, alphas, points_idx);
  } else {
    CHECK_CONTIGUOUS(grad_outputs);
    CHECK_CONTIGUOUS(features);
    CHECK_CONTIGUOUS(alphas);
    CHECK_CONTIGUOUS(points_idx);

    return weightedSumNormCpuBackward(
        grad_outputs, features, alphas, points_idx);
  }
}
134 pytorch3d/csrc/compositing/norm_weighted_sum_cpu.cpp Normal file
@@ -0,0 +1,134 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>

#include <cmath>
#include <vector>

// Epsilon float
const float kEps = 1e-4;

torch::Tensor weightedSumNormCpuForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  const int64_t B = points_idx.size(0);
  const int64_t K = points_idx.size(1);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);
  const int64_t C = features.size(0);

  torch::Tensor result = torch::zeros({B, C, H, W}, features.options());

  auto features_a = features.accessor<float, 2>();
  auto alphas_a = alphas.accessor<float, 4>();
  auto points_idx_a = points_idx.accessor<int64_t, 4>();
  auto result_a = result.accessor<float, 4>();

  // Iterate over the batch
  for (int b = 0; b < B; ++b) {
    // Iterate over the features
    for (int c = 0; c < C; ++c) {
      // Iterate through the horizontal lines of the image from top to bottom
      for (int j = 0; j < H; ++j) {
        // Iterate over pixels in a horizontal line, left to right
        for (int i = 0; i < W; ++i) {
          float t_alpha = 0.;
          for (int k = 0; k < K; ++k) {
            int64_t n_idx = points_idx_a[b][k][j][i];
            if (n_idx < 0) {
              continue;
            }

            t_alpha += alphas_a[b][k][j][i];
          }

          if (t_alpha < kEps) {
            t_alpha = kEps;
          }

          // Iterate over the different zs to combine
          for (int k = 0; k < K; ++k) {
            int64_t n_idx = points_idx_a[b][k][j][i];
            // Sentinel value is -1 indicating no point overlaps the pixel
            if (n_idx < 0) {
              continue;
            }
            float alpha = alphas_a[b][k][j][i];
            result_a[b][c][j][i] += alpha * features_a[c][n_idx] / t_alpha;
          }
        }
      }
    }
  }
  return result;
}

std::tuple<torch::Tensor, torch::Tensor> weightedSumNormCpuBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  torch::Tensor grad_features = torch::zeros_like(features);
  torch::Tensor grad_alphas = torch::zeros_like(alphas);

  const int64_t B = points_idx.size(0);
  const int64_t K = points_idx.size(1);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);
  const int64_t C = features.size(0);

  auto grad_outputs_a = grad_outputs.accessor<float, 4>();
  auto features_a = features.accessor<float, 2>();
  auto alphas_a = alphas.accessor<float, 4>();
  auto points_idx_a = points_idx.accessor<int64_t, 4>();
  auto grad_features_a = grad_features.accessor<float, 2>();
  auto grad_alphas_a = grad_alphas.accessor<float, 4>();

  // Iterate over the batch
  for (int b = 0; b < B; ++b) {
    // Iterate over the features
    for (int c = 0; c < C; ++c) {
      // Iterate through the horizontal lines of the image from top to bottom
      for (int j = 0; j < H; ++j) {
        // Iterate over pixels in a horizontal line, left to right
        for (int i = 0; i < W; ++i) {
          float t_alpha = 0.;
          float t_alphafs = 0.;
          // Iterate through the closest K points for this pixel
          for (int k = 0; k < K; ++k) {
            int64_t n_idx = points_idx_a[b][k][j][i];
            // Sentinel value is -1, indicating no point overlaps this pixel
            if (n_idx < 0) {
              continue;
            }

            t_alpha += alphas_a[b][k][j][i];
            t_alphafs += alphas_a[b][k][j][i] * features_a[c][n_idx];
          }

          if (t_alpha < kEps) {
            t_alpha = kEps;
          }

          // Iterate through the closest K points for this pixel ordered by z
          // distance.
          for (int k = 0; k < K; ++k) {
            int64_t n_idx = points_idx_a[b][k][j][i];
            // Sentinel value is -1 indicating no point overlaps the pixel
            if (n_idx < 0) {
              continue;
            }
            float alpha = alphas_a[b][k][j][i];
            grad_alphas_a[b][k][j][i] += grad_outputs_a[b][c][j][i] *
                (features_a[c][n_idx] * t_alpha - t_alphafs) /
                (t_alpha * t_alpha);
            grad_features_a[c][n_idx] +=
                grad_outputs_a[b][c][j][i] * alpha / t_alpha;
          }
        }
      }
    }
  }
  return std::make_tuple(grad_features, grad_alphas);
}
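The grad_alphas expression above is the quotient rule applied to the normalized sum. With $S = \sum_l \alpha_l$ (clamped to kEps) and $F = \sum_l \alpha_l f_{c,l}$:

$$\frac{\partial}{\partial \alpha_k}\left(\frac{F}{S}\right) = \frac{f_{c,k}\,S - F}{S^2}, \qquad \frac{\partial}{\partial f_{c,p}}\left(\frac{F}{S}\right) = \frac{1}{S}\sum_{k:\ \mathrm{idx}_k = p} \alpha_k.$$

When S falls below kEps the clamp makes these gradients approximate, which is the price of keeping the division stable for pixels covered by little or no alpha.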
161 pytorch3d/csrc/compositing/weighted_sum.cu Normal file
@@ -0,0 +1,161 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <stdio.h>
#include <vector>

// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void weightedSumCudaForwardKernel(
    // clang-format off
    torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> result,
    const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
    const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
  // clang-format on
  const int64_t batch_size = result.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
  const int batch = blockIdx.x;

  const int num_pixels = C * W * H;
  const int num_threads = gridDim.y * blockDim.x;
  const int tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Parallelize over each feature in each pixel in images of size H * W,
  // for each image in the batch of size batch_size
  for (int pid = tid; pid < num_pixels; pid += num_threads) {
    int ch = pid / (W * H);
    int j = (pid % (W * H)) / H;
    int i = (pid % (W * H)) % H;

    // Iterate through the closest K points for this pixel
    for (int k = 0; k < points_idx.size(1); ++k) {
      int n_idx = points_idx[batch][k][j][i];
      // Sentinel value is -1 indicating no point overlaps the pixel
      if (n_idx < 0) {
        continue;
      }

      // Accumulate the values
      float alpha = alphas[batch][k][j][i];
      // TODO(gkioxari) It might be more efficient to have threads write in a
      // local variable, and move atomicAdd outside of the loop such that
      // atomicAdd is executed once per thread.
      atomicAdd(&result[batch][ch][j][i], features[ch][n_idx] * alpha);
    }
  }
}

// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
// Currently, support is for floats only.
__global__ void weightedSumCudaBackwardKernel(
    // clang-format off
    torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> grad_features,
    torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_alphas,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> grad_outputs,
    const torch::PackedTensorAccessor<float, 2, torch::RestrictPtrTraits, size_t> features,
    const torch::PackedTensorAccessor<float, 4, torch::RestrictPtrTraits, size_t> alphas,
    const torch::PackedTensorAccessor<int64_t, 4, torch::RestrictPtrTraits, size_t> points_idx) {
  // clang-format on
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  // Get the batch and index
  const int batch = blockIdx.x;

  const int num_pixels = C * W * H;
  const int num_threads = gridDim.y * blockDim.x;
  const int tid = blockIdx.y * blockDim.x + threadIdx.x;

  // Iterate over each pixel to compute the contribution to the
  // gradient for the features and weights
  for (int pid = tid; pid < num_pixels; pid += num_threads) {
    int ch = pid / (W * H);
    int j = (pid % (W * H)) / H;
    int i = (pid % (W * H)) % H;

    // Iterate through the closest K points for this pixel
    for (int k = 0; k < points_idx.size(1); ++k) {
      int n_idx = points_idx[batch][k][j][i];
      // Sentinel value is -1 indicating no point overlaps the pixel
      if (n_idx < 0) {
        continue;
      }
      float alpha = alphas[batch][k][j][i];

      // TODO(gkioxari) It might be more efficient to have threads write in a
      // local variable, and move atomicAdd outside of the loop such that
      // atomicAdd is executed once per thread.
      atomicAdd(
          &grad_alphas[batch][k][j][i],
          features[ch][n_idx] * grad_outputs[batch][ch][j][i]);
      atomicAdd(
          &grad_features[ch][n_idx], alpha * grad_outputs[batch][ch][j][i]);
    }
  }
}

torch::Tensor weightedSumCudaForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  const int64_t batch_size = points_idx.size(0);
  const int64_t C = features.size(0);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);

  auto result = torch::zeros({batch_size, C, H, W}, features.options());

  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(batch_size, 1024 / batch_size + 1);

  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
  weightedSumCudaForwardKernel<<<numBlocks, threadsPerBlock>>>(
      // clang-format off
      result.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
  // clang-format on

  return result;
}

std::tuple<torch::Tensor, torch::Tensor> weightedSumCudaBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  auto grad_features = torch::zeros_like(features);
  auto grad_alphas = torch::zeros_like(alphas);

  const int64_t bs = points_idx.size(0);

  const dim3 threadsPerBlock(64);
  const dim3 numBlocks(bs, 1024 / bs + 1);

  // TODO(gkioxari) add AT_DISPATCH_FLOATING_TYPES once atomicAdd supports
  // doubles. Currently, support is for floats only.
  weightedSumCudaBackwardKernel<<<numBlocks, threadsPerBlock>>>(
      // clang-format off
      grad_features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      grad_alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      grad_outputs.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      features.packed_accessor<float, 2, torch::RestrictPtrTraits, size_t>(),
      alphas.packed_accessor<float, 4, torch::RestrictPtrTraits, size_t>(),
      points_idx.packed_accessor<int64_t, 4, torch::RestrictPtrTraits, size_t>());
  // clang-format on

  return std::make_tuple(grad_features, grad_alphas);
}
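For the plain weighted sum the gradients are linear, which is why this backward kernel needs no second pass over the z-buffer. With $w_c = \sum_k \alpha_k f_{c,k}$:

$$\frac{\partial w_c}{\partial \alpha_k} = f_{c,k}, \qquad \frac{\partial w_c}{\partial f_{c,p}} = \sum_{k:\ \mathrm{idx}_k = p} \alpha_k,$$

matching the two atomicAdd calls above (the sum over slots pointing at the same point p materializes through the atomic accumulation into grad_features).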
107 pytorch3d/csrc/compositing/weighted_sum.h Normal file
@@ -0,0 +1,107 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>
#include "pytorch3d_cutils.h"

#include <vector>

// Perform weighted sum compositing of points in a z-buffer.
//
// Inputs:
//     features: FloatTensor of shape (C, P) which gives the features
//            of each point where C is the size of the feature and
//            P the number of points.
//     alphas: FloatTensor of shape (N, points_per_pixel, W, W) where
//            points_per_pixel is the number of points in the z-buffer
//            sorted in z-order, and W is the image size.
//     points_idx: IntTensor of shape (N, points_per_pixel, W, W) giving the
//            indices of the nearest points at each pixel, sorted in z-order.
// Returns:
//     weighted_fs: FloatTensor of shape (N, C, W, W) giving the accumulated
//            feature in each point. Concretely, it gives:
//                 weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] *
//                    features[c,points_idx[b,k,i,j]]

// CUDA declarations
#ifdef WITH_CUDA
torch::Tensor weightedSumCudaForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

std::tuple<torch::Tensor, torch::Tensor> weightedSumCudaBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);
#endif

// C++ declarations
torch::Tensor weightedSumCpuForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

std::tuple<torch::Tensor, torch::Tensor> weightedSumCpuBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx);

torch::Tensor weightedSumForward(
    torch::Tensor& features,
    torch::Tensor& alphas,
    torch::Tensor& points_idx) {
  features = features.contiguous();
  alphas = alphas.contiguous();
  points_idx = points_idx.contiguous();

  if (features.type().is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CONTIGUOUS_CUDA(features);
    CHECK_CONTIGUOUS_CUDA(alphas);
    CHECK_CONTIGUOUS_CUDA(points_idx);
#else
    AT_ERROR("Not compiled with GPU support");
#endif

    return weightedSumCudaForward(features, alphas, points_idx);
  } else {
    CHECK_CONTIGUOUS(features);
    CHECK_CONTIGUOUS(alphas);
    CHECK_CONTIGUOUS(points_idx);

    return weightedSumCpuForward(features, alphas, points_idx);
  }
}

std::tuple<torch::Tensor, torch::Tensor> weightedSumBackward(
    torch::Tensor& grad_outputs,
    torch::Tensor& features,
    torch::Tensor& alphas,
    torch::Tensor& points_idx) {
  grad_outputs = grad_outputs.contiguous();
  features = features.contiguous();
  alphas = alphas.contiguous();
  points_idx = points_idx.contiguous();

  if (grad_outputs.type().is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CONTIGUOUS_CUDA(grad_outputs);
    CHECK_CONTIGUOUS_CUDA(features);
    CHECK_CONTIGUOUS_CUDA(alphas);
    CHECK_CONTIGUOUS_CUDA(points_idx);
#else
    AT_ERROR("Not compiled with GPU support");
#endif

    return weightedSumCudaBackward(grad_outputs, features, alphas, points_idx);
  } else {
    CHECK_CONTIGUOUS(grad_outputs);
    CHECK_CONTIGUOUS(features);
    CHECK_CONTIGUOUS(alphas);
    CHECK_CONTIGUOUS(points_idx);

    return weightedSumCpuBackward(grad_outputs, features, alphas, points_idx);
  }
}
98 pytorch3d/csrc/compositing/weighted_sum_cpu.cpp Normal file
@@ -0,0 +1,98 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

#include <torch/extension.h>

#include <cmath>
#include <vector>

torch::Tensor weightedSumCpuForward(
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  const int64_t B = points_idx.size(0);
  const int64_t K = points_idx.size(1);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);
  const int64_t C = features.size(0);

  torch::Tensor result = torch::zeros({B, C, H, W}, features.options());

  auto features_a = features.accessor<float, 2>();
  auto alphas_a = alphas.accessor<float, 4>();
  auto points_idx_a = points_idx.accessor<int64_t, 4>();
  auto result_a = result.accessor<float, 4>();

  // Iterate over the batch
  for (int b = 0; b < B; ++b) {
    // Iterate over the features
    for (int c = 0; c < C; ++c) {
      // Iterate through the horizontal lines of the image from top to bottom
      for (int j = 0; j < H; ++j) {
        // Iterate over pixels in a horizontal line, left to right
        for (int i = 0; i < W; ++i) {
          // Iterate through the closest K points for this pixel
          for (int k = 0; k < K; ++k) {
            int64_t n_idx = points_idx_a[b][k][j][i];
            // Sentinel value is -1 indicating no point overlaps the pixel
            if (n_idx < 0) {
              continue;
            }

            float alpha = alphas_a[b][k][j][i];
            result_a[b][c][j][i] += alpha * features_a[c][n_idx];
          }
        }
      }
    }
  }
  return result;
}

std::tuple<torch::Tensor, torch::Tensor> weightedSumCpuBackward(
    const torch::Tensor& grad_outputs,
    const torch::Tensor& features,
    const torch::Tensor& alphas,
    const torch::Tensor& points_idx) {
  const int64_t B = points_idx.size(0);
  const int64_t K = points_idx.size(1);
  const int64_t H = points_idx.size(2);
  const int64_t W = points_idx.size(3);
  const int64_t C = features.size(0);

  torch::Tensor grad_features = torch::zeros_like(features);
  torch::Tensor grad_alphas = torch::zeros_like(alphas);

  auto grad_outputs_a = grad_outputs.accessor<float, 4>();
  auto features_a = features.accessor<float, 2>();
  auto alphas_a = alphas.accessor<float, 4>();
  auto points_idx_a = points_idx.accessor<int64_t, 4>();
  auto grad_features_a = grad_features.accessor<float, 2>();
  auto grad_alphas_a = grad_alphas.accessor<float, 4>();

  // Iterate over the batch
  for (int b = 0; b < B; ++b) {
    // Iterate over the features
    for (int c = 0; c < C; ++c) {
      // Iterate through the horizontal lines of the image from top to bottom
      for (int j = 0; j < H; ++j) {
        // Iterate over pixels in a horizontal line, left to right
        for (int i = 0; i < W; ++i) {
          // Iterate through the closest K points for this pixel
          for (int k = 0; k < K; ++k) {
            int64_t n_idx = points_idx_a[b][k][j][i];
            // Sentinel value is -1, indicating no point overlaps this pixel
            if (n_idx < 0) {
              continue;
            }

            float alpha = alphas_a[b][k][j][i];
            grad_alphas_a[b][k][j][i] +=
                grad_outputs_a[b][c][j][i] * features_a[c][n_idx];
            grad_features_a[c][n_idx] += grad_outputs_a[b][c][j][i] * alpha;
          }
        }
      }
    }
  }
  return std::make_tuple(grad_features, grad_alphas);
}
@@ -1,6 +1,9 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

 #include <torch/extension.h>
+#include "compositing/alpha_composite.h"
+#include "compositing/norm_weighted_sum.h"
+#include "compositing/weighted_sum.h"
 #include "face_areas_normals/face_areas_normals.h"
 #include "gather_scatter/gather_scatter.h"
 #include "nearest_neighbor_points/nearest_neighbor_points.h"
@@ -20,6 +23,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("rasterize_meshes_backward", &RasterizeMeshesBackward);
   m.def("rasterize_meshes", &RasterizeMeshes);

+  // Accumulation functions
+  m.def("accum_weightedsumnorm", &weightedSumNormForward);
+  m.def("accum_weightedsum", &weightedSumForward);
+  m.def("accum_alphacomposite", &alphaCompositeForward);
+  m.def("accum_weightedsumnorm_backward", &weightedSumNormBackward);
+  m.def("accum_weightedsum_backward", &weightedSumBackward);
+  m.def("accum_alphacomposite_backward", &alphaCompositeBackward);
+
   // These are only visible for testing; users should not call them directly
   m.def("_rasterize_points_coarse", &RasterizePointsCoarse);
   m.def("_rasterize_points_naive", &RasterizePointsNaive);
@@ -34,6 +34,14 @@ from .mesh import (
     phong_shading,
     rasterize_meshes,
 )
+from .points import (
+    AlphaCompositor,
+    NormWeightedCompositor,
+    PointsRasterizationSettings,
+    PointsRasterizer,
+    PointsRenderer,
+    rasterize_points,
+)
 from .utils import TensorProperties, convert_to_tensors_and_broadcast

 __all__ = [k for k in globals().keys() if not k.startswith("_")]
255 pytorch3d/renderer/compositing.py Normal file
@@ -0,0 +1,255 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


from typing import NamedTuple
import torch

from pytorch3d import _C

# Example functions for blending the top K features per pixel using the outputs
# from rasterization.
# NOTE: All blending functions should return a (N, H, W, C) tensor per batch element.
# This can be an image (C=3) or a set of features.


# Data class to store blending params with defaults
class CompositeParams(NamedTuple):
    radius: float = 4.0 / 256.0


class _CompositeAlphaPoints(torch.autograd.Function):
    """
    Composite features within a z-buffer using alpha compositing. Given a zbuffer
    with corresponding features and weights, these values are accumulated according
    to their weights such that features nearer in depth contribute more to the final
    feature than ones further away.

    Concretely this means:
        weighted_fs[b,c,i,j] = sum_k cum_alpha_k * features[c,pointsidx[b,k,i,j]]
        cum_alpha_k = alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j])

    Args:
        features: Packed Tensor of shape (C, P) giving the features of each point.
        alphas: float32 Tensor of shape (N, points_per_pixel, image_size,
            image_size) giving the weight of each point in the z-buffer.
            Values should be in the interval [0, 1].
        pointsidx: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
            giving the indices of the nearest points at each pixel, sorted in z-order.
            Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the
            feature of the kth closest point (along the z-direction) to pixel (y, x)
            in batch element n. This is weighted by alphas[n, k, y, x].

    Returns:
        weighted_fs: Tensor of shape (N, C, image_size, image_size)
            giving the accumulated features at each point.
    """

    @staticmethod
    def forward(ctx, features, alphas, points_idx):
        pt_cld = _C.accum_alphacomposite(features, alphas, points_idx)

        ctx.save_for_backward(
            features.clone(), alphas.clone(), points_idx.clone()
        )
        return pt_cld

    @staticmethod
    def backward(ctx, grad_output):
        grad_features = None
        grad_alphas = None
        grad_points_idx = None
        features, alphas, points_idx = ctx.saved_tensors

        grad_features, grad_alphas = _C.accum_alphacomposite_backward(
            grad_output, features, alphas, points_idx
        )

        return grad_features, grad_alphas, grad_points_idx, None


def alpha_composite(
    pointsidx, alphas, pt_clds, blend_params=None
) -> torch.Tensor:
    """
    Composite features within a z-buffer using alpha compositing. Given a zbuffer
    with corresponding features and weights, these values are accumulated according
    to their weights such that features nearer in depth contribute more to the final
    feature than ones further away.

    Concretely this means:
        weighted_fs[b,c,i,j] = sum_k cum_alpha_k * features[c,pointsidx[b,k,i,j]]
        cum_alpha_k = alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j])


    Args:
        pt_clds: Tensor of shape (N, C, P) giving the features of each point
            (can use RGB for example).
        alphas: float32 Tensor of shape (N, points_per_pixel, image_size,
            image_size) giving the weight of each point in the z-buffer.
            Values should be in the interval [0, 1].
        pointsidx: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
            giving the indices of the nearest points at each pixel, sorted in z-order.
            Concretely pointsidx[n, k, y, x] = p means that features[n, :, p] is the
            feature of the kth closest point (along the z-direction) to pixel (y, x)
            in batch element n. This is weighted by alphas[n, k, y, x].

    Returns:
        Combined features: Tensor of shape (N, C, image_size, image_size)
            giving the accumulated features at each point.
    """
    return _CompositeAlphaPoints.apply(pt_clds, alphas, pointsidx)
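A hypothetical call, to make the expected dtypes and shapes concrete (all names and sizes here are illustrative, and running it requires the compiled _C extension from this commit):

import torch
from pytorch3d.renderer.compositing import alpha_composite

C, P, N, K, S = 3, 1000, 2, 8, 64                # feature dim, #points, batch, K, image size
features = torch.rand(C, P)                       # packed per-point features (e.g. RGB)
alphas = torch.rand(N, K, S, S)                   # per-point weights in [0, 1]
points_idx = torch.randint(-1, P, (N, K, S, S))   # z-sorted indices; -1 marks empty slots
images = alpha_composite(points_idx, alphas, features)  # -> (N, C, S, S)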
|
||||
|
||||
class _CompositeNormWeightedSumPoints(torch.autograd.Function):
|
||||
"""
|
||||
Composite features within a z-buffer using normalized weighted sum. Given a zbuffer
|
||||
with corresponding features and weights, these values are accumulated
|
||||
according to their weights such that depth is ignored; the weights are used to perform
|
||||
a weighted sum.
|
||||
|
||||
Concretely this means:
|
||||
weighted_fs[b,c,i,j] =
|
||||
sum_k alphas[b,k,i,j] * features[c,pointsidx[b,k,i,j]] / sum_k alphas[b,k,i,j]
|
||||
|
||||
Args:
|
||||
features: Packed Tensor of shape (C, P) giving the features of each point.
|
||||
alphas: float32 Tensor of shape (N, points_per_pixel, image_size,
|
||||
image_size) giving the weight of each point in the z-buffer.
|
||||
Values should be in the interval [0, 1].
|
||||
pointsidx: int64 Tensor of shape (N, points_per_pixel, image_size, image_size)
|
||||
giving the indices of the nearest points at each pixel, sorted in z-order.
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
|
||||
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
|
||||
This is weighted by alphas[n, k, y, x].
|
||||
|
||||
Returns:
|
||||
weighted_fs: Tensor of shape (N, C, image_size, image_size)
|
||||
giving the accumulated features at each point.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, features, alphas, points_idx):
|
||||
pt_cld = _C.accum_weightedsumnorm(features, alphas, points_idx)
|
||||
|
||||
ctx.save_for_backward(
|
||||
features.clone(), alphas.clone(), points_idx.clone()
|
||||
)
|
||||
return pt_cld
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_features = None
|
||||
grad_alphas = None
|
||||
grad_points_idx = None
|
||||
features, alphas, points_idx = ctx.saved_tensors
|
||||
|
||||
grad_features, grad_alphas = _C.accum_weightedsumnorm_backward(
|
||||
grad_output, features, alphas, points_idx
|
||||
)
|
||||
|
||||
return grad_features, grad_alphas, grad_points_idx, None
|
||||
|
||||
|
||||
def norm_weighted_sum(
|
||||
pointsidx, alphas, pt_clds, blend_params=None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Composite features within a z-buffer using normalized weighted sum. Given a zbuffer
|
||||
with corresponding features and weights, these values are accumulated
|
||||
according to their weights such that depth is ignored; the weights are used to perform
|
||||
a weighted sum.
|
||||
|
||||
Concretely this means:
|
||||
weighted_fs[b,c,i,j] =
|
||||
sum_k alphas[b,k,i,j] * features[c,pointsidx[b,k,i,j]] / sum_k alphas[b,k,i,j]
|
||||
|
||||
Args:
|
||||
pt_clds: Packed feature tensor of shape (C, P) giving the features of each point
|
||||
(can use RGB for example).
|
||||
alphas: float32 Tensor of shape (N, points_per_pixel, image_size,
|
||||
image_size) giving the weight of each point in the z-buffer.
|
||||
Values should be in the interval [0, 1].
|
||||
pointsidx: int64 Tensor of shape (N, points_per_pixel, image_size, image_size)
|
||||
giving the indices of the nearest points at each pixel, sorted in z-order.
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
|
||||
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
|
||||
This is weighted by alphas[n, k, y, x].
|
||||
|
||||
Returns:
|
||||
Combined features: Tensor of shape (N, C, image_size, image_size)
|
||||
giving the accumulated features at each point.
|
||||
"""
|
||||
return _CompositeNormWeightedSumPoints.apply(pt_clds, alphas, pointsidx)
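# Illustrative sketch (editor's addition): because norm_weighted_sum divides by the
# per-pixel total alpha, scaling all alphas by a constant should leave the output
# unchanged (up to the small clamp on the denominator). Helper name is hypothetical.
def _example_norm_weighted_sum_scale_invariance():
    import torch

    features = torch.rand(3, 8)
    alphas = torch.rand(1, 2, 4, 4)
    idx = torch.randint(0, 8, (1, 2, 4, 4), dtype=torch.int64)     # all indices valid here
    out1 = norm_weighted_sum(idx, alphas, features)
    out2 = norm_weighted_sum(idx, 0.5 * alphas, features)
    return torch.allclose(out1, out2, atol=1e-5)                   # expected: True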
|
||||
|
||||
|
||||
class _CompositeWeightedSumPoints(torch.autograd.Function):
|
||||
"""
|
||||
Composite features within a z-buffer using a weighted sum. Given a zbuffer
|
||||
with corresponding features and weights, these values are accumulated
|
||||
according to their weights such that depth is ignored; the weights are used to
|
||||
perform a weighted sum. As opposed to norm weighted sum, the weights are not
|
||||
normalized to sum to 1.
|
||||
|
||||
Concretely this means:
|
||||
weighted_fs[b,c,i,j] = sum_k alphas[b,k,i,j] * features[c,pointsidx[b,k,i,j]]
|
||||
|
||||
Args:
|
||||
features: Packed Tensor of shape (C, P) giving the features of each point.
|
||||
alphas: float32 Tensor of shape (N, points_per_pixel, image_size,
|
||||
image_size) giving the weight of each point in the z-buffer.
|
||||
Values should be in the interval [0, 1].
|
||||
pointsidx: int64 Tensor of shape (N, points_per_pixel, image_size, image_size)
|
||||
giving the indices of the nearest points at each pixel, sorted in z-order.
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
|
||||
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
|
||||
This is weighted by alphas[n, k, y, x].
|
||||
|
||||
Returns:
|
||||
weighted_fs: Tensor of shape (N, C, image_size, image_size)
|
||||
giving the accumulated features at each point.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, features, alphas, points_idx):
|
||||
pt_cld = _C.accum_weightedsum(features, alphas, points_idx)
|
||||
|
||||
ctx.save_for_backward(
|
||||
features.clone(), alphas.clone(), points_idx.clone()
|
||||
)
|
||||
return pt_cld
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
grad_features = None
|
||||
grad_alphas = None
|
||||
grad_points_idx = None
|
||||
features, alphas, points_idx = ctx.saved_tensors
|
||||
|
||||
grad_features, grad_alphas = _C.accum_weightedsum_backward(
|
||||
grad_output, features, alphas, points_idx
|
||||
)
|
||||
|
||||
return grad_features, grad_alphas, grad_points_idx, None
|
||||
|
||||
|
||||
def weighted_sum(pointsidx, alphas, pt_clds, blend_params=None) -> torch.Tensor:
|
||||
"""
|
||||
Composite features within a z-buffer using a weighted sum. Unlike norm_weighted_sum, the weights are not normalized to sum to 1 before they are applied.
|
||||
|
||||
Args:
|
||||
pt_clds: Packed Tensor of shape (C, P) giving the features of each point
|
||||
(can use RGB for example).
|
||||
alphas: float32 Tensor of shape (N, points_per_pixel, image_size,
|
||||
image_size) giving the weight of each point in the z-buffer.
|
||||
Values should be in the interval [0, 1].
|
||||
pointsidx: int64 Tensor of shape (N, points_per_pixel, image_size, image_size)
|
||||
giving the indices of the nearest points at each pixel, sorted in z-order.
|
||||
Concretely pointsidx[n, k, y, x] = p means that features[:, p] is the feature of
|
||||
the kth closest point (along the z-direction) to pixel (y, x) in batch element n.
|
||||
This is weighted by alphas[n, k, y, x].
|
||||
|
||||
Returns:
|
||||
Combined features: Tensor of shape (N, C, image_size, image_size)
|
||||
giving the accumulated features at each point.
|
||||
"""
|
||||
return _CompositeWeightedSumPoints.apply(pt_clds, alphas, pointsidx)
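# Illustrative sketch (editor's addition): weighted_sum is the un-normalized
# counterpart of norm_weighted_sum, so dividing its output by the per-pixel sum of
# alphas should recover the normalized result when every index is valid. Helper
# name is hypothetical.
def _example_weighted_sum_vs_norm():
    import torch

    features = torch.rand(3, 8)
    alphas = torch.rand(1, 2, 4, 4)
    idx = torch.randint(0, 8, (1, 2, 4, 4), dtype=torch.int64)
    ws = weighted_sum(idx, alphas, features)                       # (1, 3, 4, 4)
    nws = norm_weighted_sum(idx, alphas, features)
    return torch.allclose(ws / alphas.sum(dim=1, keepdim=True), nws, atol=1e-5)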
|
@ -1 +1,8 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .compositor import AlphaCompositor, NormWeightedCompositor
|
||||
from .rasterize_points import rasterize_points
|
||||
from .rasterizer import PointsRasterizationSettings, PointsRasterizer
|
||||
from .renderer import PointsRenderer
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
51
pytorch3d/renderer/points/compositor.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..compositing import CompositeParams, alpha_composite, norm_weighted_sum
|
||||
|
||||
# A compositor should take as input 3D points and some corresponding information.
|
||||
# Given this information, the compositor can:
|
||||
# - blend colors across the top K vertices at a pixel
|
||||
|
||||
|
||||
class AlphaCompositor(nn.Module):
|
||||
"""
|
||||
Accumulate points using alpha compositing.
|
||||
"""
|
||||
|
||||
def __init__(self, composite_params=None):
|
||||
super().__init__()
|
||||
|
||||
self.composite_params = (
|
||||
composite_params
|
||||
if composite_params is not None
|
||||
else CompositeParams()
|
||||
)
|
||||
|
||||
def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
|
||||
images = alpha_composite(
|
||||
fragments, alphas, ptclds, self.composite_params
|
||||
)
|
||||
return images
|
||||
|
||||
|
||||
class NormWeightedCompositor(nn.Module):
|
||||
"""
|
||||
Accumulate points using a normalized weighted sum.
|
||||
"""
|
||||
|
||||
def __init__(self, composite_params=None):
|
||||
super().__init__()
|
||||
self.composite_params = (
|
||||
composite_params
|
||||
if composite_params is not None
|
||||
else CompositeParams()
|
||||
)
|
||||
|
||||
def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
|
||||
images = norm_weighted_sum(
|
||||
fragments, alphas, ptclds, self.composite_params
|
||||
)
|
||||
return images
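# Illustrative sketch (editor's addition): both compositors are thin nn.Module
# wrappers around the functional ops above. As PointsRenderer later in this diff
# shows, `fragments` here is the (N, K, H, W) long tensor of z-sorted point
# indices, `alphas` the matching weights and `ptclds` the packed (C, P) features.
# Helper name is hypothetical.
def _example_compositor():
    import torch

    idx = torch.randint(-1, 8, (1, 4, 16, 16), dtype=torch.int64)
    weights = torch.rand(1, 4, 16, 16)
    feats = torch.rand(3, 8)
    compositor = AlphaCompositor()           # or NormWeightedCompositor()
    return compositor(idx, weights, feats)   # (1, 3, 16, 16)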
|
103
pytorch3d/renderer/points/rasterizer.py
Normal file
@ -0,0 +1,103 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from typing import NamedTuple, Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..cameras import get_world_to_view_transform
|
||||
from .rasterize_points import rasterize_points
|
||||
|
||||
|
||||
# Class to store the outputs of point rasterization
|
||||
class PointFragments(NamedTuple):
|
||||
idx: torch.Tensor
|
||||
zbuf: torch.Tensor
|
||||
dists: torch.Tensor
|
||||
|
||||
|
||||
# Class to store the point rasterization params with defaults
|
||||
class PointsRasterizationSettings(NamedTuple):
|
||||
image_size: int = 256
|
||||
radius: float = 0.01
|
||||
points_per_pixel: int = 8
|
||||
bin_size: Optional[int] = None
|
||||
max_points_per_bin: Optional[int] = None
|
||||
|
||||
|
||||
class PointsRasterizer(nn.Module):
|
||||
"""
|
||||
This class implements methods for rasterizing a batch of pointclouds.
|
||||
"""
|
||||
|
||||
def __init__(self, cameras, raster_settings=None):
|
||||
"""
|
||||
cameras: A cameras object which has a `transform_points` method
|
||||
which returns the transformed points after applying the
|
||||
world-to-view and view-to-screen
|
||||
transformations.
|
||||
raster_settings: the parameters for rasterization. This should be a
|
||||
named tuple.
|
||||
|
||||
All these initial settings can be overridden by passing keyword
|
||||
arguments to the forward function.
|
||||
"""
|
||||
super().__init__()
|
||||
if raster_settings is None:
|
||||
raster_settings = PointsRasterizationSettings()
|
||||
|
||||
self.cameras = cameras
|
||||
self.raster_settings = raster_settings
|
||||
|
||||
def transform(self, point_clouds, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
point_clouds: a set of point clouds
|
||||
|
||||
Returns:
|
||||
point_clouds: the input point clouds with the vertex positions
|
||||
transformed to screen space.
|
||||
|
||||
NOTE: keeping this as a separate function for readability but it could
|
||||
be moved into forward.
|
||||
"""
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
|
||||
pts_world = point_clouds.points_padded()
|
||||
pts_world_packed = point_clouds.points_packed()
|
||||
pts_screen = cameras.transform_points(pts_world, **kwargs)
|
||||
|
||||
# NOTE: Retaining view space z coordinate for now.
|
||||
# TODO: Remove this line when the convention for the z coordinate in
|
||||
# the rasterizer is decided. i.e. retain z in view space or transform
|
||||
# to a different range.
|
||||
view_transform = get_world_to_view_transform(R=cameras.R, T=cameras.T)
|
||||
verts_view = view_transform.transform_points(pts_world)
|
||||
pts_screen[..., 2] = verts_view[..., 2]
|
||||
|
||||
# Offset points of input pointcloud to reuse cached padded/packed calculations.
|
||||
pad_to_packed_idx = point_clouds.padded_to_packed_idx()
|
||||
pts_screen_packed = pts_screen.view(-1, 3)[pad_to_packed_idx, :]
|
||||
pts_packed_offset = pts_screen_packed - pts_world_packed
|
||||
point_clouds = point_clouds.offset(pts_packed_offset)
|
||||
return point_clouds
|
||||
|
||||
def forward(self, point_clouds, **kwargs) -> PointFragments:
|
||||
"""
|
||||
Args:
|
||||
point_clouds: a set of point clouds with coordinates in world space.
|
||||
Returns:
|
||||
PointFragments: Rasterization outputs as a named tuple.
|
||||
"""
|
||||
points_screen = self.transform(point_clouds, **kwargs)
|
||||
raster_settings = kwargs.get("raster_settings", self.raster_settings)
|
||||
idx, zbuf, dists2 = rasterize_points(
|
||||
points_screen,
|
||||
image_size=raster_settings.image_size,
|
||||
radius=raster_settings.radius,
|
||||
points_per_pixel=raster_settings.points_per_pixel,
|
||||
bin_size=raster_settings.bin_size,
|
||||
max_points_per_bin=raster_settings.max_points_per_bin,
|
||||
)
|
||||
return PointFragments(idx=idx, zbuf=zbuf, dists=dists2)
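# Illustrative sketch (editor's addition): constructing the rasterizer. The camera
# class used below is an assumption about what the library's cameras module
# provides and is not part of this diff; any object with a `transform_points`
# method and R, T attributes should work. Helper name is hypothetical.
def _example_rasterizer(point_clouds):
    from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras  # assumed import path

    cameras = OpenGLPerspectiveCameras()
    raster_settings = PointsRasterizationSettings(
        image_size=128,         # square output image
        radius=0.01,            # screen-space point radius
        points_per_pixel=8,     # K nearest points kept per pixel
    )
    rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
    return rasterizer(point_clouds)   # PointFragments(idx, zbuf, dists)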
|
56
pytorch3d/renderer/points/renderer.py
Normal file
@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# A renderer class should be initialized with a
|
||||
# function for rasterization and a function for compositing.
|
||||
# The rasterizer should:
|
||||
# - transform inputs from world -> screen space
|
||||
# - rasterize inputs
|
||||
# - return fragments
|
||||
# The compositor can take fragments as input along with any other properties of
|
||||
# the scene and generate images.
|
||||
|
||||
# E.g. rasterize inputs and then shade
|
||||
#
|
||||
# fragments = self.rasterize(point_clouds)
|
||||
# images = self.compositor(fragments, point_clouds)
|
||||
# return images
|
||||
|
||||
|
||||
class PointsRenderer(nn.Module):
|
||||
"""
|
||||
A class for rendering a batch of points. The class should
|
||||
be initialized with a rasterizer and compositor class which each have a forward
|
||||
function.
|
||||
"""
|
||||
|
||||
def __init__(self, rasterizer, compositor):
|
||||
super().__init__()
|
||||
self.rasterizer = rasterizer
|
||||
self.compositor = compositor
|
||||
|
||||
def forward(self, point_clouds, **kwargs) -> torch.Tensor:
|
||||
fragments = self.rasterizer(point_clouds, **kwargs)
|
||||
|
||||
# Construct weights based on the distance of a point to the true point.
|
||||
# However, this could be done differently: e.g. predicted as opposed
|
||||
# to a function of the distances.
|
||||
r = self.rasterizer.raster_settings.radius
|
||||
|
||||
dists2 = fragments.dists.permute(0, 3, 1, 2)
|
||||
weights = 1 - dists2 / (r * r)
|
||||
images = self.compositor(
|
||||
fragments.idx.long().permute(0, 3, 1, 2),
|
||||
weights,
|
||||
point_clouds.features_packed().permute(1, 0),
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# permute so image comes at the end
|
||||
images = images.permute(0, 2, 3, 1)
|
||||
|
||||
return images
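# Illustrative sketch (editor's addition): wiring the classes in this diff into an
# end-to-end render of a random coloured point cloud. The Pointclouds constructor
# signature and the camera class are assumptions about the rest of the library and
# are not part of this diff. Helper name is hypothetical.
def _example_render_points():
    import torch
    from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras  # assumed import path
    from pytorch3d.renderer.points import (
        AlphaCompositor,
        PointsRasterizationSettings,
        PointsRasterizer,
        PointsRenderer,
    )
    from pytorch3d.structures import Pointclouds

    verts = torch.rand(1, 1000, 3) - 0.5    # one cloud of 1000 points
    rgb = torch.rand(1, 1000, 3)            # per-point colours used as features
    point_clouds = Pointclouds(points=verts, features=rgb)

    renderer = PointsRenderer(
        rasterizer=PointsRasterizer(
            cameras=OpenGLPerspectiveCameras(),
            raster_settings=PointsRasterizationSettings(
                image_size=256, radius=0.01, points_per_pixel=8
            ),
        ),
        compositor=AlphaCompositor(),
    )
    return renderer(point_clouds)           # (1, 256, 256, 3) image tensor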
|
@ -1,6 +1,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .meshes import Meshes, join_meshes
|
||||
from .pointclouds import Pointclouds
|
||||
from .textures import Textures
|
||||
from .utils import (
|
||||
list_to_packed,
|
||||
|
442
tests/test_compositing.py
Normal file
@ -0,0 +1,442 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from pytorch3d.renderer.compositing import (
|
||||
alpha_composite,
|
||||
norm_weighted_sum,
|
||||
weighted_sum,
|
||||
)
|
||||
|
||||
|
||||
class TestAccumulatePoints(unittest.TestCase):
|
||||
|
||||
# NAIVE PYTHON IMPLEMENTATIONS (USED FOR TESTING)
|
||||
@staticmethod
|
||||
def accumulate_alphacomposite_python(points_idx, alphas, features):
|
||||
"""
|
||||
Naive pure PyTorch implementation of alpha_composite.
|
||||
Inputs / Outputs: Same as function
|
||||
"""
|
||||
|
||||
B, K, H, W = points_idx.size()
|
||||
C = features.size(0)
|
||||
|
||||
output = torch.zeros(B, C, H, W, dtype=alphas.dtype)
|
||||
|
||||
for b in range(0, B):
|
||||
for c in range(0, C):
|
||||
for i in range(0, W):
|
||||
for j in range(0, H):
|
||||
t_alpha = 1
|
||||
for k in range(0, K):
|
||||
n_idx = points_idx[b, k, j, i]
|
||||
|
||||
if n_idx < 0:
|
||||
continue
|
||||
|
||||
alpha = alphas[b, k, j, i]
|
||||
output[b, c, j, i] += (
|
||||
features[c, n_idx] * alpha * t_alpha
|
||||
)
|
||||
t_alpha = (1 - alpha) * t_alpha
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def accumulate_weightedsum_python(points_idx, alphas, features):
|
||||
"""
|
||||
Naive pure PyTorch implementation of weighted_sum rasterization.
|
||||
Inputs / Outputs: Same as function
|
||||
"""
|
||||
B, K, H, W = points_idx.size()
|
||||
C = features.size(0)
|
||||
|
||||
output = torch.zeros(B, C, H, W, dtype=alphas.dtype)
|
||||
|
||||
for b in range(0, B):
|
||||
for c in range(0, C):
|
||||
for i in range(0, W):
|
||||
for j in range(0, H):
|
||||
|
||||
for k in range(0, K):
|
||||
n_idx = points_idx[b, k, j, i]
|
||||
|
||||
if n_idx < 0:
|
||||
continue
|
||||
|
||||
alpha = alphas[b, k, j, i]
|
||||
output[b, c, j, i] += features[c, n_idx] * alpha
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def accumulate_weightedsumnorm_python(points_idx, alphas, features):
|
||||
"""
|
||||
Naive pure PyTorch implementation of norm_weighted_sum.
|
||||
Inputs / Outputs: Same as function
|
||||
"""
|
||||
|
||||
B, K, H, W = points_idx.size()
|
||||
C = features.size(0)
|
||||
|
||||
output = torch.zeros(B, C, H, W, dtype=alphas.dtype)
|
||||
|
||||
for b in range(0, B):
|
||||
for c in range(0, C):
|
||||
for i in range(0, W):
|
||||
for j in range(0, H):
|
||||
t_alpha = 0
|
||||
for k in range(0, K):
|
||||
n_idx = points_idx[b, k, j, i]
|
||||
|
||||
if n_idx < 0:
|
||||
continue
|
||||
|
||||
t_alpha += alphas[b, k, j, i]
|
||||
|
||||
t_alpha = max(t_alpha, 1e-4)
|
||||
|
||||
for k in range(0, K):
|
||||
n_idx = points_idx[b, k, j, i]
|
||||
|
||||
if n_idx < 0:
|
||||
continue
|
||||
|
||||
alpha = alphas[b, k, j, i]
|
||||
output[b, c, j, i] += (
|
||||
features[c, n_idx] * alpha / t_alpha
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def test_python(self):
|
||||
device = torch.device("cpu")
|
||||
self._simple_alphacomposite(
|
||||
self.accumulate_alphacomposite_python, device
|
||||
)
|
||||
self._simple_wsum(self.accumulate_weightedsum_python, device)
|
||||
self._simple_wsumnorm(self.accumulate_weightedsumnorm_python, device)
|
||||
|
||||
def test_cpu(self):
|
||||
device = torch.device("cpu")
|
||||
self._simple_alphacomposite(alpha_composite, device)
|
||||
self._simple_wsum(weighted_sum, device)
|
||||
self._simple_wsumnorm(norm_weighted_sum, device)
|
||||
|
||||
def test_cuda(self):
|
||||
device = torch.device("cuda:0")
|
||||
self._simple_alphacomposite(alpha_composite, device)
|
||||
self._simple_wsum(weighted_sum, device)
|
||||
self._simple_wsumnorm(norm_weighted_sum, device)
|
||||
|
||||
def test_python_vs_cpu_vs_cuda(self):
|
||||
self._python_vs_cpu_vs_cuda(
|
||||
self.accumulate_alphacomposite_python, alpha_composite
|
||||
)
|
||||
self._python_vs_cpu_vs_cuda(
|
||||
self.accumulate_weightedsumnorm_python, norm_weighted_sum
|
||||
)
|
||||
self._python_vs_cpu_vs_cuda(
|
||||
self.accumulate_weightedsum_python, weighted_sum
|
||||
)
|
||||
|
||||
def _python_vs_cpu_vs_cuda(self, accumulate_func_python, accumulate_func):
|
||||
torch.manual_seed(231)
|
||||
device = torch.device("cpu")
|
||||
|
||||
W = 8
|
||||
C = 3
|
||||
P = 32
|
||||
|
||||
for d in ["cpu", "cuda"]:
|
||||
# TODO(gkioxari) add torch.float64 to types after double precision
|
||||
# support is added to atomicAdd
|
||||
for t in [torch.float32]:
|
||||
device = torch.device(d)
|
||||
|
||||
# Create values
|
||||
alphas = torch.rand(2, 4, W, W, dtype=t).to(device)
|
||||
alphas.requires_grad = True
|
||||
alphas_cpu = alphas.detach().cpu()
|
||||
alphas_cpu.requires_grad = True
|
||||
|
||||
features = torch.randn(C, P, dtype=t).to(device)
|
||||
features.requires_grad = True
|
||||
features_cpu = features.detach().cpu()
|
||||
features_cpu.requires_grad = True
|
||||
|
||||
inds = torch.randint(P + 1, size=(2, 4, W, W)).to(device) - 1
|
||||
inds_cpu = inds.detach().cpu()
|
||||
|
||||
args_cuda = (inds, alphas, features)
|
||||
args_cpu = (inds_cpu, alphas_cpu, features_cpu)
|
||||
|
||||
self._compare_impls(
|
||||
accumulate_func_python,
|
||||
accumulate_func,
|
||||
args_cpu,
|
||||
args_cuda,
|
||||
(alphas_cpu, features_cpu),
|
||||
(alphas, features),
|
||||
compare_grads=True,
|
||||
)
|
||||
|
||||
def _compare_impls(
|
||||
self, fn1, fn2, args1, args2, grads1, grads2, compare_grads=False
|
||||
):
|
||||
res1 = fn1(*args1)
|
||||
res2 = fn2(*args2)
|
||||
|
||||
self.assertTrue(torch.allclose(res1.cpu(), res2.cpu(), atol=1e-6))
|
||||
|
||||
if not compare_grads:
|
||||
return
|
||||
|
||||
# Compare gradients
|
||||
torch.manual_seed(231)
|
||||
grad_res = torch.randn_like(res1)
|
||||
loss1 = (res1 * grad_res).sum()
|
||||
loss1.backward()
|
||||
|
||||
grads1 = [gradsi.grad.data.clone().cpu() for gradsi in grads1]
|
||||
grad_res = grad_res.to(res2)
|
||||
|
||||
loss2 = (res2 * grad_res).sum()
|
||||
loss2.backward()
|
||||
grads2 = [gradsi.grad.data.clone().cpu() for gradsi in grads2]
|
||||
|
||||
for i in range(0, len(grads1)):
|
||||
self.assertTrue(
|
||||
torch.allclose(grads1[i].cpu(), grads2[i].cpu(), atol=1e-6)
|
||||
)
|
||||
|
||||
def _simple_wsum(self, accum_func, device):
|
||||
# Initialise variables
|
||||
features = torch.Tensor(
|
||||
[[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]
|
||||
).to(device)
|
||||
|
||||
alphas = torch.Tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.5, 0.5, 0.5, 0.5],
|
||||
[0.5, 1.0, 1.0, 0.5],
|
||||
[0.5, 1.0, 1.0, 0.5],
|
||||
[0.5, 0.5, 0.5, 0.5],
|
||||
],
|
||||
[
|
||||
[0.5, 0.5, 0.5, 0.5],
|
||||
[0.5, 1.0, 1.0, 0.5],
|
||||
[0.5, 1.0, 1.0, 0.5],
|
||||
[0.5, 0.5, 0.5, 0.5],
|
||||
],
|
||||
]
|
||||
]
|
||||
).to(device)
|
||||
|
||||
points_idx = (
|
||||
torch.Tensor(
|
||||
[
|
||||
[
|
||||
# fmt: off
|
||||
[
|
||||
[0, 0, 0, 0], # noqa: E241, E201
|
||||
[0, -1, -1, -1], # noqa: E241, E201
|
||||
[0, 1, 1, 0], # noqa: E241, E201
|
||||
[0, 0, 0, 0], # noqa: E241, E201
|
||||
],
|
||||
[
|
||||
[2, 2, 2, 2], # noqa: E241, E201
|
||||
[2, 3, 3, 2], # noqa: E241, E201
|
||||
[2, 3, 3, 2], # noqa: E241, E201
|
||||
[2, 2, -1, 2], # noqa: E241, E201
|
||||
],
|
||||
# fmt: on
|
||||
]
|
||||
]
|
||||
)
|
||||
.long()
|
||||
.to(device)
|
||||
)
|
||||
|
||||
result = accum_func(points_idx, alphas, features)
|
||||
|
||||
self.assertTrue(result.shape == (1, 2, 4, 4))
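# (Editor's note) The expected values below follow directly from the weighted-sum
# formula. E.g. top-left pixel: point 0 (feature 0.1, alpha 0.5) and point 2
# (feature 0.6, alpha 0.5) give 0.5 * 0.1 + 0.5 * 0.6 = 0.35; the centre pixels
# use points 1 and 3 with alpha 1.0, giving 1.0 * 0.4 + 1.0 * 0.9 = 1.30.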
|
||||
|
||||
true_result = torch.Tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.35, 0.35, 0.35, 0.35],
|
||||
[0.35, 0.90, 0.90, 0.30],
|
||||
[0.35, 1.30, 1.30, 0.35],
|
||||
[0.35, 0.35, 0.05, 0.35],
|
||||
],
|
||||
[
|
||||
[0.35, 0.35, 0.35, 0.35],
|
||||
[0.35, 0.90, 0.90, 0.30],
|
||||
[0.35, 1.30, 1.30, 0.35],
|
||||
[0.35, 0.35, 0.05, 0.35],
|
||||
],
|
||||
]
|
||||
]
|
||||
).to(device)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3)
|
||||
)
|
||||
|
||||
def _simple_wsumnorm(self, accum_func, device):
|
||||
# Initialise variables
|
||||
features = torch.Tensor(
|
||||
[[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]
|
||||
).to(device)
|
||||
|
||||
alphas = torch.Tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.5, 0.5, 0.5, 0.5],
|
||||
[0.5, 1.0, 1.0, 0.5],
|
||||
[0.5, 1.0, 1.0, 0.5],
|
||||
[0.5, 0.5, 0.5, 0.5],
|
||||
],
|
||||
[
|
||||
[0.5, 0.5, 0.5, 0.5],
|
||||
[0.5, 1.0, 1.0, 0.5],
|
||||
[0.5, 1.0, 1.0, 0.5],
|
||||
[0.5, 0.5, 0.5, 0.5],
|
||||
],
|
||||
]
|
||||
]
|
||||
).to(device)
|
||||
|
||||
# fmt: off
|
||||
points_idx = (
|
||||
torch.Tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0, 0, 0, 0], # noqa: E241, E201
|
||||
[0, -1, -1, -1], # noqa: E241, E201
|
||||
[0, 1, 1, 0], # noqa: E241, E201
|
||||
[0, 0, 0, 0], # noqa: E241, E201
|
||||
],
|
||||
[
|
||||
[2, 2, 2, 2], # noqa: E241, E201
|
||||
[2, 3, 3, 2], # noqa: E241, E201
|
||||
[2, 3, 3, 2], # noqa: E241, E201
|
||||
[2, 2, -1, 2], # noqa: E241, E201
|
||||
],
|
||||
]
|
||||
]
|
||||
)
|
||||
.long()
|
||||
.to(device)
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
result = accum_func(points_idx, alphas, features)
|
||||
|
||||
self.assertTrue(result.shape == (1, 2, 4, 4))
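# (Editor's note) Expected values below: each pixel is normalized by the total
# alpha of its valid entries. Top-left: (0.5 * 0.1 + 0.5 * 0.6) / (0.5 + 0.5) = 0.35;
# centre: (1.0 * 0.4 + 1.0 * 0.9) / 2.0 = 0.65; pixel (1, 3) has a single valid
# entry (point 2, alpha 0.5), giving 0.5 * 0.6 / 0.5 = 0.60.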
|
||||
|
||||
true_result = torch.Tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.35, 0.35, 0.35, 0.35],
|
||||
[0.35, 0.90, 0.90, 0.60],
|
||||
[0.35, 0.65, 0.65, 0.35],
|
||||
[0.35, 0.35, 0.10, 0.35],
|
||||
],
|
||||
[
|
||||
[0.35, 0.35, 0.35, 0.35],
|
||||
[0.35, 0.90, 0.90, 0.60],
|
||||
[0.35, 0.65, 0.65, 0.35],
|
||||
[0.35, 0.35, 0.10, 0.35],
|
||||
],
|
||||
]
|
||||
]
|
||||
).to(device)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(result.cpu(), true_result.cpu(), rtol=1e-3)
|
||||
)
|
||||
|
||||
def _simple_alphacomposite(self, accum_func, device):
|
||||
# Initialise variables
|
||||
features = torch.Tensor(
|
||||
[[0.1, 0.4, 0.6, 0.9], [0.1, 0.4, 0.6, 0.9]]
|
||||
).to(device)
|
||||
|
||||
alphas = torch.Tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.5, 0.5, 0.5, 0.5],
|
||||
[0.5, 1.0, 1.0, 0.5],
|
||||
[0.5, 1.0, 1.0, 0.5],
|
||||
[0.5, 0.5, 0.5, 0.5],
|
||||
],
|
||||
[
|
||||
[0.5, 0.5, 0.5, 0.5],
|
||||
[0.5, 1.0, 1.0, 0.5],
|
||||
[0.5, 1.0, 1.0, 0.5],
|
||||
[0.5, 0.5, 0.5, 0.5],
|
||||
],
|
||||
]
|
||||
]
|
||||
).to(device)
|
||||
|
||||
# fmt: off
|
||||
points_idx = (
|
||||
torch.Tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0, 0, 0, 0], # noqa: E241, E201
|
||||
[0, -1, -1, -1], # noqa: E241, E201
|
||||
[0, 1, 1, 0], # noqa: E241, E201
|
||||
[0, 0, 0, 0], # noqa: E241, E201
|
||||
],
|
||||
[
|
||||
[2, 2, 2, 2], # noqa: E241, E201
|
||||
[2, 3, 3, 2], # noqa: E241, E201
|
||||
[2, 3, 3, 2], # noqa: E241, E201
|
||||
[2, 2, -1, 2], # noqa: E241, E201
|
||||
],
|
||||
]
|
||||
]
|
||||
)
|
||||
.long()
|
||||
.to(device)
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
result = accum_func(points_idx, alphas, features)
|
||||
|
||||
self.assertTrue(result.shape == (1, 2, 4, 4))
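# (Editor's note) Expected values below follow the alpha-compositing recurrence:
# point k is attenuated by prod_{l<k} (1 - alpha_l). Top-left:
# 0.5 * 0.1 + (1 - 0.5) * 0.5 * 0.6 = 0.20; at the centre the first point has
# alpha 1.0, fully occluding the second, so the result is 1.0 * 0.4 = 0.40.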
|
||||
|
||||
true_result = torch.Tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.20, 0.20, 0.20, 0.20],
|
||||
[0.20, 0.90, 0.90, 0.30],
|
||||
[0.20, 0.40, 0.40, 0.20],
|
||||
[0.20, 0.20, 0.05, 0.20],
|
||||
],
|
||||
[
|
||||
[0.20, 0.20, 0.20, 0.20],
|
||||
[0.20, 0.90, 0.90, 0.30],
|
||||
[0.20, 0.40, 0.40, 0.20],
|
||||
[0.20, 0.20, 0.05, 0.20],
|
||||
],
|
||||
]
|
||||
]
|
||||
).to(device)
|
||||
|
||||
self.assertTrue((result == true_result).all().item())
|
@ -76,7 +76,10 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
|
||||
verts_torch = verts.detach().clone().to(dtype)
|
||||
verts_torch.requires_grad = True
|
||||
faces_torch = faces.detach().clone()
|
||||
areas_torch, normals_torch = TestFaceAreasNormals.face_areas_normals_python(
|
||||
(
|
||||
areas_torch,
|
||||
normals_torch,
|
||||
) = TestFaceAreasNormals.face_areas_normals_python(
|
||||
verts_torch, faces_torch
|
||||
)
|
||||
self.assertClose(areas_torch, areas, atol=1e-7)
|
||||
|