mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
sample_pdf CUDA and C++ implementations.
Summary: Implement the sample_pdf function from the NeRF project as compiled operators.. The binary search (in searchsorted) is replaced with a low tech linear search, but this is not a problem for the envisaged numbers of bins. Reviewed By: gkioxari Differential Revision: D26312535 fbshipit-source-id: df1c3119cd63d944380ed1b2657b6ad81d743e49
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7d7d00f288
commit
1ea2b7272a
@@ -26,6 +26,7 @@
|
||||
#include "point_mesh/point_mesh_cuda.h"
|
||||
#include "rasterize_meshes/rasterize_meshes.h"
|
||||
#include "rasterize_points/rasterize_points.h"
|
||||
#include "sample_pdf/sample_pdf.h"
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
|
||||
@@ -83,6 +84,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("point_face_array_dist_forward", &PointFaceArrayDistanceForward);
|
||||
m.def("point_face_array_dist_backward", &PointFaceArrayDistanceBackward);
|
||||
|
||||
// Sample PDF
|
||||
m.def("sample_pdf", &SamplePdf);
|
||||
|
||||
// Pulsar.
|
||||
#ifdef PULSAR_LOGGING_ENABLED
|
||||
c10::ShowLogInfoToStderr();
|
||||
|
||||
153
pytorch3d/csrc/sample_pdf/sample_pdf.cu
Normal file
153
pytorch3d/csrc/sample_pdf/sample_pdf.cu
Normal file
@@ -0,0 +1,153 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
// There is no intermediate memory, so no reason not to have blocksize=32.
|
||||
// 256 is a reasonable number of blocks.
|
||||
|
||||
// DESIGN
|
||||
// We exploit the fact that n_samples is not tiny.
|
||||
// A chunk of work is T*blocksize many samples from
|
||||
// a single batch elememt.
|
||||
// For each batch element there will be
|
||||
// chunks_per_batch = 1 + (n_samples-1)/(T*blocksize) of them.
|
||||
// The number of potential chunks to do is
|
||||
// n_chunks = chunks_per_batch * n_batches.
|
||||
// These chunks are divided among the gridSize-many blocks.
|
||||
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
|
||||
// In chunk i, we work on batch_element i/chunks_per_batch
|
||||
// on samples starting from (i%chunks_per_batch) * (T*blocksize)
|
||||
|
||||
// BEGIN HYPOTHETICAL
|
||||
// Another option (not implemented) if batch_size was always large
|
||||
// would be as follows.
|
||||
|
||||
// A chunk of work is S samples from each of blocksize-many
|
||||
// batch elements.
|
||||
// For each batch element there will be
|
||||
// chunks_per_batch = (1+(n_samples-1)/S) of them.
|
||||
// The number of potential chunks to do is
|
||||
// n_chunks = chunks_per_batch * (1+(n_batches-1)/blocksize)
|
||||
// These chunks are divided among the gridSize-many blocks.
|
||||
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
|
||||
// In chunk i, we work on samples starting from S*(i%chunks_per_batch)
|
||||
// on batch elements starting from blocksize*(i/chunks_per_batch).
|
||||
// END HYPOTHETICAL
|
||||
|
||||
__global__ void SamplePdfCudaKernel(
|
||||
const float* __restrict__ bins,
|
||||
const float* __restrict__ weights,
|
||||
float* __restrict__ outputs,
|
||||
float eps,
|
||||
const int T,
|
||||
const int64_t batch_size,
|
||||
const int64_t n_bins,
|
||||
const int64_t n_samples) {
|
||||
const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * blockDim.x);
|
||||
const int64_t n_chunks = chunks_per_batch * batch_size;
|
||||
|
||||
for (int64_t i_chunk = blockIdx.x; i_chunk < n_chunks; i_chunk += gridDim.x) {
|
||||
// Loop over the chunks.
|
||||
int64_t i_batch_element = i_chunk / chunks_per_batch;
|
||||
int64_t sample_start = (i_chunk % chunks_per_batch) * (T * blockDim.x);
|
||||
const float* const weight_startp = weights + n_bins * i_batch_element;
|
||||
const float* const bin_startp = bins + (1 + n_bins) * i_batch_element;
|
||||
|
||||
// Each chunk looks at a single batch element, so we do the preprocessing
|
||||
// which depends on the batch element, namely finding the total weight.
|
||||
// Idenntical work is being done in sync here by every thread of the block.
|
||||
float total_weight = eps;
|
||||
for (int64_t i_bin = 0; i_bin < n_bins; ++i_bin) {
|
||||
total_weight += weight_startp[i_bin];
|
||||
}
|
||||
|
||||
float* const output_startp =
|
||||
outputs + n_samples * i_batch_element + sample_start;
|
||||
|
||||
for (int t = 0; t < T; ++t) {
|
||||
// Loop over T, which is the number of samples each thread makes within
|
||||
// the chunk.
|
||||
const int64_t i_sample_within_chunk = threadIdx.x + t * blockDim.x;
|
||||
if (sample_start + i_sample_within_chunk >= n_samples) {
|
||||
// Some threads need to exit early because the sample they would
|
||||
// make is unwanted.
|
||||
continue;
|
||||
}
|
||||
// output_startp[i_sample_within_chunk] contains the quantile we (i.e.
|
||||
// this thread) are calcvulating.
|
||||
float uniform = total_weight * output_startp[i_sample_within_chunk];
|
||||
int64_t i_bin = 0;
|
||||
// We find the bin containing the quantile by walking along the weights.
|
||||
// This loop must be thread dependent. I.e. the whole warp will wait until
|
||||
// every thread has found the bin for its quantile.
|
||||
// It may be best to write it differently.
|
||||
while (i_bin + 1 < n_bins && uniform > weight_startp[i_bin]) {
|
||||
uniform -= weight_startp[i_bin];
|
||||
++i_bin;
|
||||
}
|
||||
|
||||
// Now we know which bin to look in, we use linear interpolation
|
||||
// to find the location of the quantile within the bin, and
|
||||
// write the answer back.
|
||||
float bin_start = bin_startp[i_bin];
|
||||
float bin_end = bin_startp[i_bin + 1];
|
||||
float bin_weight = weight_startp[i_bin];
|
||||
float output_value = bin_start;
|
||||
if (uniform > bin_weight) {
|
||||
output_value = bin_end;
|
||||
} else if (bin_weight > eps) {
|
||||
output_value += (uniform / bin_weight) * (bin_end - bin_start);
|
||||
}
|
||||
output_startp[i_sample_within_chunk] = output_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SamplePdfCuda(
|
||||
const at::Tensor& bins,
|
||||
const at::Tensor& weights,
|
||||
const at::Tensor& outputs,
|
||||
float eps) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg bins_t{bins, "bins", 1}, weights_t{weights, "weights", 2},
|
||||
outputs_t{outputs, "outputs", 3};
|
||||
at::CheckedFrom c = "SamplePdfCuda";
|
||||
at::checkAllSameGPU(c, {bins_t, weights_t, outputs_t});
|
||||
at::checkAllSameType(c, {bins_t, weights_t, outputs_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(bins.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t batch_size = bins.size(0);
|
||||
const int64_t n_bins = weights.size(1);
|
||||
const int64_t n_samples = outputs.size(1);
|
||||
|
||||
const int64_t threads = 32;
|
||||
const int64_t T = n_samples <= threads ? 1 : 2;
|
||||
const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * threads);
|
||||
const int64_t n_chunks = chunks_per_batch * batch_size;
|
||||
|
||||
const int64_t max_blocks = 1024;
|
||||
const int64_t blocks = n_chunks < max_blocks ? n_chunks : max_blocks;
|
||||
|
||||
SamplePdfCudaKernel<<<blocks, threads, 0, stream>>>(
|
||||
bins.contiguous().data_ptr<float>(),
|
||||
weights.contiguous().data_ptr<float>(),
|
||||
outputs.data_ptr<float>(), // Checked contiguous in header file.
|
||||
eps,
|
||||
T,
|
||||
batch_size,
|
||||
n_bins,
|
||||
n_samples);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
74
pytorch3d/csrc/sample_pdf/sample_pdf.h
Normal file
74
pytorch3d/csrc/sample_pdf/sample_pdf.h
Normal file
@@ -0,0 +1,74 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
#include <cstdio>
|
||||
#include <tuple>
|
||||
#include "utils/pytorch3d_cutils.h"
|
||||
|
||||
// ****************************************************************************
|
||||
// * SamplePdf *
|
||||
// ****************************************************************************
|
||||
|
||||
// Samples a probability density functions defined by bin edges `bins` and
|
||||
// the non-negative per-bin probabilities `weights`.
|
||||
|
||||
// Args:
|
||||
// bins: FloatTensor of shape `(batch_size, n_bins+1)` denoting the edges
|
||||
// of the sampling bins.
|
||||
|
||||
// weights: FloatTensor of shape `(batch_size, n_bins)` containing
|
||||
// non-negative numbers representing the probability of sampling the
|
||||
// corresponding bin.
|
||||
|
||||
// uniforms: The quantiles to draw, FloatTensor of shape
|
||||
// `(batch_size, n_samples)`.
|
||||
|
||||
// outputs: On call, this contains the quantiles to draw. It is overwritten
|
||||
// with the drawn samples. FloatTensor of shape
|
||||
// `(batch_size, n_samples), where `n_samples are drawn from each
|
||||
// distribution.
|
||||
|
||||
// eps: A constant preventing division by zero in case empty bins are
|
||||
// present.
|
||||
|
||||
// Not differentiable
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
void SamplePdfCuda(
|
||||
const torch::Tensor& bins,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& outputs,
|
||||
float eps);
|
||||
#endif
|
||||
|
||||
void SamplePdfCpu(
|
||||
const torch::Tensor& bins,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& outputs,
|
||||
float eps);
|
||||
|
||||
inline void SamplePdf(
|
||||
const torch::Tensor& bins,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& outputs,
|
||||
float eps) {
|
||||
if (bins.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(weights);
|
||||
CHECK_CONTIGUOUS_CUDA(outputs);
|
||||
SamplePdfCuda(bins, weights, outputs, eps);
|
||||
return;
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CONTIGUOUS(outputs);
|
||||
SamplePdfCpu(bins, weights, outputs, eps);
|
||||
}
|
||||
141
pytorch3d/csrc/sample_pdf/sample_pdf_cpu.cpp
Normal file
141
pytorch3d/csrc/sample_pdf/sample_pdf_cpu.cpp
Normal file
@@ -0,0 +1,141 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <algorithm>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
// If the number of bins is the typical 64, it is
|
||||
// quicker to use binary search than linear scan.
|
||||
// With more bins, it is more important.
|
||||
// There is no equivalent CUDA implementation yet.
|
||||
#define USE_BINARY_SEARCH
|
||||
|
||||
namespace {
|
||||
// This worker function does the job of SamplePdf but only on
|
||||
// batch elements in [start_batch, end_batch).
|
||||
void SamplePdfCpu_worker(
|
||||
const torch::Tensor& bins,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& outputs,
|
||||
float eps,
|
||||
int64_t start_batch,
|
||||
int64_t end_batch) {
|
||||
const int64_t n_bins = weights.size(1);
|
||||
const int64_t n_samples = outputs.size(1);
|
||||
|
||||
auto bins_a = bins.accessor<float, 2>();
|
||||
auto weights_a = weights.accessor<float, 2>();
|
||||
float* __restrict__ output_p =
|
||||
outputs.data_ptr<float>() + start_batch * n_samples;
|
||||
|
||||
#ifdef USE_BINARY_SEARCH
|
||||
std::vector<float> partial_sums(n_bins);
|
||||
#endif
|
||||
|
||||
for (int64_t i_batch_elt = start_batch; i_batch_elt < end_batch;
|
||||
++i_batch_elt) {
|
||||
auto bin_a = bins_a[i_batch_elt];
|
||||
auto weight_a = weights_a[i_batch_elt];
|
||||
|
||||
// Here we do the work which has to be done once per batch element.
|
||||
// i.e. (1) finding the total weight. (2) If using binary search,
|
||||
// precompute the partial sums of the weights.
|
||||
|
||||
float total_weight = 0;
|
||||
for (int64_t i_bin = 0; i_bin < n_bins; ++i_bin) {
|
||||
total_weight += weight_a[i_bin];
|
||||
#ifdef USE_BINARY_SEARCH
|
||||
partial_sums[i_bin] = total_weight;
|
||||
#endif
|
||||
}
|
||||
total_weight += eps;
|
||||
|
||||
for (int64_t i_sample = 0; i_sample < n_samples; ++i_sample) {
|
||||
// Here we are taking a single random quantile (which is stored
|
||||
// in *output_p) and using it to make a single sample, which we
|
||||
// write back to the same location. First we find which bin
|
||||
// the quantile lives in, either by binary search in the
|
||||
// precomputed partial sums, or by scanning through the weights.
|
||||
|
||||
float uniform = total_weight * *output_p;
|
||||
#ifdef USE_BINARY_SEARCH
|
||||
int64_t i_bin = std::lower_bound(
|
||||
partial_sums.begin(), --partial_sums.end(), uniform) -
|
||||
partial_sums.begin();
|
||||
if (i_bin > 0) {
|
||||
uniform -= partial_sums[i_bin - 1];
|
||||
}
|
||||
#else
|
||||
int64_t i_bin = 0;
|
||||
while (i_bin + 1 < n_bins && uniform > weight_a[i_bin]) {
|
||||
uniform -= weight_a[i_bin];
|
||||
++i_bin;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Now i_bin identifies the bin the quantile lives in, we use
|
||||
// straight line interpolation to find the position of the
|
||||
// quantile within the bin, and write it to *output_p.
|
||||
|
||||
float bin_start = bin_a[i_bin];
|
||||
float bin_end = bin_a[i_bin + 1];
|
||||
float bin_weight = weight_a[i_bin];
|
||||
float output_value = bin_start;
|
||||
if (uniform > bin_weight) {
|
||||
output_value = bin_end;
|
||||
} else if (bin_weight > eps) {
|
||||
output_value += (uniform / bin_weight) * (bin_end - bin_start);
|
||||
}
|
||||
*output_p = output_value;
|
||||
++output_p;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
void SamplePdfCpu(
|
||||
const torch::Tensor& bins,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& outputs,
|
||||
float eps) {
|
||||
const int64_t batch_size = bins.size(0);
|
||||
const int64_t max_threads = std::min(4, at::get_num_threads());
|
||||
const int64_t n_threads = std::min(max_threads, batch_size);
|
||||
if (batch_size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// SamplePdfCpu_worker does the work of this function. We send separate ranges
|
||||
// of batch elements to that function in nThreads-1 separate threads.
|
||||
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(n_threads - 1);
|
||||
const int64_t batch_elements_per_thread = 1 + (batch_size - 1) / n_threads;
|
||||
int64_t start_batch = 0;
|
||||
for (int iThread = 0; iThread < n_threads - 1; ++iThread) {
|
||||
threads.emplace_back(
|
||||
SamplePdfCpu_worker,
|
||||
bins,
|
||||
weights,
|
||||
outputs,
|
||||
eps,
|
||||
start_batch,
|
||||
start_batch + batch_elements_per_thread);
|
||||
start_batch += batch_elements_per_thread;
|
||||
}
|
||||
|
||||
// The remaining batch elements are calculated in this threads. If nThreads is
|
||||
// 1 then all the work happens in this line.
|
||||
SamplePdfCpu_worker(bins, weights, outputs, eps, start_batch, batch_size);
|
||||
for (auto&& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user