pytorch3d/pytorch3d/csrc/compositing/weighted_sum_cpu.cpp
Olivia 53599770dd Accumulate points (#4)
Summary:
Code for accumulating points in the z-buffer in three ways:
1. weighted sum
2. normalised weighted sum
3. alpha compositing

Pull Request resolved: https://github.com/fairinternal/pytorch3d/pull/4

Reviewed By: nikhilaravi

Differential Revision: D20522422

Pulled By: gkioxari

fbshipit-source-id: 5023baa05f15e338f3821ef08f5552c2dcbfc06c
2020-03-19 11:23:12 -07:00

99 lines
3.3 KiB
C++

// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <cmath>
#include <vector>
torch::Tensor weightedSumCpuForward(
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
const int64_t B = points_idx.size(0);
const int64_t K = points_idx.size(1);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
const int64_t C = features.size(0);
torch::Tensor result = torch::zeros({B, C, H, W}, features.options());
auto features_a = features.accessor<float, 2>();
auto alphas_a = alphas.accessor<float, 4>();
auto points_idx_a = points_idx.accessor<int64_t, 4>();
auto result_a = result.accessor<float, 4>();
// Iterate over the batch
for (int b = 0; b < B; ++b) {
// Iterate over the features
for (int c = 0; c < C; ++c) {
// Iterate through the horizontal lines of the image from top to bottom
for (int j = 0; j < H; ++j) {
// Iterate over pixels in a horizontal line, left to right
for (int i = 0; i < W; ++i) {
// Iterate through the closest K points for this pixel
for (int k = 0; k < K; ++k) {
int64_t n_idx = points_idx_a[b][k][j][i];
// Sentinel value is -1 indicating no point overlaps the pixel
if (n_idx < 0) {
continue;
}
float alpha = alphas_a[b][k][j][i];
result_a[b][c][j][i] += alpha * features_a[c][n_idx];
}
}
}
}
}
return result;
}
std::tuple<torch::Tensor, torch::Tensor> weightedSumCpuBackward(
const torch::Tensor& grad_outputs,
const torch::Tensor& features,
const torch::Tensor& alphas,
const torch::Tensor& points_idx) {
const int64_t B = points_idx.size(0);
const int64_t K = points_idx.size(1);
const int64_t H = points_idx.size(2);
const int64_t W = points_idx.size(3);
const int64_t C = features.size(0);
torch::Tensor grad_features = torch::zeros_like(features);
torch::Tensor grad_alphas = torch::zeros_like(alphas);
auto grad_outputs_a = grad_outputs.accessor<float, 4>();
auto features_a = features.accessor<float, 2>();
auto alphas_a = alphas.accessor<float, 4>();
auto points_idx_a = points_idx.accessor<int64_t, 4>();
auto grad_features_a = grad_features.accessor<float, 2>();
auto grad_alphas_a = grad_alphas.accessor<float, 4>();
// Iterate over the batch
for (int b = 0; b < B; ++b) {
// Iterate oer the features
for (int c = 0; c < C; ++c) {
// Iterate through the horizontal lines of the image from top to bottom
for (int j = 0; j < H; ++j) {
// Iterate over pixels in a horizontal line, left to right
for (int i = 0; i < W; ++i) {
// Iterate through the closest K points for this pixel
for (int k = 0; k < K; ++k) {
int64_t n_idx = points_idx_a[b][k][j][i];
// Sentinal value is -1, indicating no point overlaps this pixel
if (n_idx < 0) {
continue;
}
float alpha = alphas_a[b][k][j][i];
grad_alphas_a[b][k][j][i] +=
grad_outputs_a[b][c][j][i] * features_a[c][n_idx];
grad_features_a[c][n_idx] += grad_outputs_a[b][c][j][i] * alpha;
}
}
}
}
}
return std::make_tuple(grad_features, grad_alphas);
}