sample_pdf CUDA and C++ implementations.

Summary: Implement the sample_pdf function from the NeRF project as compiled operators.. The binary search (in searchsorted) is replaced with a low tech linear search, but this is not a problem for the envisaged numbers of bins.

Reviewed By: gkioxari

Differential Revision: D26312535

fbshipit-source-id: df1c3119cd63d944380ed1b2657b6ad81d743e49
This commit is contained in:
Jeremy Reizenstein 2021-08-17 08:06:48 -07:00 committed by Facebook GitHub Bot
parent 7d7d00f288
commit 1ea2b7272a
7 changed files with 488 additions and 3 deletions

View File

@ -26,6 +26,7 @@
#include "point_mesh/point_mesh_cuda.h"
#include "rasterize_meshes/rasterize_meshes.h"
#include "rasterize_points/rasterize_points.h"
#include "sample_pdf/sample_pdf.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
@ -83,6 +84,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("point_face_array_dist_forward", &PointFaceArrayDistanceForward);
m.def("point_face_array_dist_backward", &PointFaceArrayDistanceBackward);
// Sample PDF
m.def("sample_pdf", &SamplePdf);
// Pulsar.
#ifdef PULSAR_LOGGING_ENABLED
c10::ShowLogInfoToStderr();

View File

@ -0,0 +1,153 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
// There is no intermediate memory, so no reason not to have blocksize=32.
// 256 is a reasonable number of blocks.
// DESIGN
// We exploit the fact that n_samples is not tiny.
// A chunk of work is T*blocksize many samples from
// a single batch elememt.
// For each batch element there will be
// chunks_per_batch = 1 + (n_samples-1)/(T*blocksize) of them.
// The number of potential chunks to do is
// n_chunks = chunks_per_batch * n_batches.
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
// In chunk i, we work on batch_element i/chunks_per_batch
// on samples starting from (i%chunks_per_batch) * (T*blocksize)
// BEGIN HYPOTHETICAL
// Another option (not implemented) if batch_size was always large
// would be as follows.
// A chunk of work is S samples from each of blocksize-many
// batch elements.
// For each batch element there will be
// chunks_per_batch = (1+(n_samples-1)/S) of them.
// The number of potential chunks to do is
// n_chunks = chunks_per_batch * (1+(n_batches-1)/blocksize)
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
// In chunk i, we work on samples starting from S*(i%chunks_per_batch)
// on batch elements starting from blocksize*(i/chunks_per_batch).
// END HYPOTHETICAL
__global__ void SamplePdfCudaKernel(
const float* __restrict__ bins,
const float* __restrict__ weights,
float* __restrict__ outputs,
float eps,
const int T,
const int64_t batch_size,
const int64_t n_bins,
const int64_t n_samples) {
const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * blockDim.x);
const int64_t n_chunks = chunks_per_batch * batch_size;
for (int64_t i_chunk = blockIdx.x; i_chunk < n_chunks; i_chunk += gridDim.x) {
// Loop over the chunks.
int64_t i_batch_element = i_chunk / chunks_per_batch;
int64_t sample_start = (i_chunk % chunks_per_batch) * (T * blockDim.x);
const float* const weight_startp = weights + n_bins * i_batch_element;
const float* const bin_startp = bins + (1 + n_bins) * i_batch_element;
// Each chunk looks at a single batch element, so we do the preprocessing
// which depends on the batch element, namely finding the total weight.
// Idenntical work is being done in sync here by every thread of the block.
float total_weight = eps;
for (int64_t i_bin = 0; i_bin < n_bins; ++i_bin) {
total_weight += weight_startp[i_bin];
}
float* const output_startp =
outputs + n_samples * i_batch_element + sample_start;
for (int t = 0; t < T; ++t) {
// Loop over T, which is the number of samples each thread makes within
// the chunk.
const int64_t i_sample_within_chunk = threadIdx.x + t * blockDim.x;
if (sample_start + i_sample_within_chunk >= n_samples) {
// Some threads need to exit early because the sample they would
// make is unwanted.
continue;
}
// output_startp[i_sample_within_chunk] contains the quantile we (i.e.
// this thread) are calcvulating.
float uniform = total_weight * output_startp[i_sample_within_chunk];
int64_t i_bin = 0;
// We find the bin containing the quantile by walking along the weights.
// This loop must be thread dependent. I.e. the whole warp will wait until
// every thread has found the bin for its quantile.
// It may be best to write it differently.
while (i_bin + 1 < n_bins && uniform > weight_startp[i_bin]) {
uniform -= weight_startp[i_bin];
++i_bin;
}
// Now we know which bin to look in, we use linear interpolation
// to find the location of the quantile within the bin, and
// write the answer back.
float bin_start = bin_startp[i_bin];
float bin_end = bin_startp[i_bin + 1];
float bin_weight = weight_startp[i_bin];
float output_value = bin_start;
if (uniform > bin_weight) {
output_value = bin_end;
} else if (bin_weight > eps) {
output_value += (uniform / bin_weight) * (bin_end - bin_start);
}
output_startp[i_sample_within_chunk] = output_value;
}
}
}
void SamplePdfCuda(
const at::Tensor& bins,
const at::Tensor& weights,
const at::Tensor& outputs,
float eps) {
// Check inputs are on the same device
at::TensorArg bins_t{bins, "bins", 1}, weights_t{weights, "weights", 2},
outputs_t{outputs, "outputs", 3};
at::CheckedFrom c = "SamplePdfCuda";
at::checkAllSameGPU(c, {bins_t, weights_t, outputs_t});
at::checkAllSameType(c, {bins_t, weights_t, outputs_t});
// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(bins.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const int64_t batch_size = bins.size(0);
const int64_t n_bins = weights.size(1);
const int64_t n_samples = outputs.size(1);
const int64_t threads = 32;
const int64_t T = n_samples <= threads ? 1 : 2;
const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * threads);
const int64_t n_chunks = chunks_per_batch * batch_size;
const int64_t max_blocks = 1024;
const int64_t blocks = n_chunks < max_blocks ? n_chunks : max_blocks;
SamplePdfCudaKernel<<<blocks, threads, 0, stream>>>(
bins.contiguous().data_ptr<float>(),
weights.contiguous().data_ptr<float>(),
outputs.data_ptr<float>(), // Checked contiguous in header file.
eps,
T,
batch_size,
n_bins,
n_samples);
AT_CUDA_CHECK(cudaGetLastError());
}

View File

@ -0,0 +1,74 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "utils/pytorch3d_cutils.h"
// ****************************************************************************
// * SamplePdf *
// ****************************************************************************
// Samples a probability density functions defined by bin edges `bins` and
// the non-negative per-bin probabilities `weights`.
// Args:
// bins: FloatTensor of shape `(batch_size, n_bins+1)` denoting the edges
// of the sampling bins.
// weights: FloatTensor of shape `(batch_size, n_bins)` containing
// non-negative numbers representing the probability of sampling the
// corresponding bin.
// uniforms: The quantiles to draw, FloatTensor of shape
// `(batch_size, n_samples)`.
// outputs: On call, this contains the quantiles to draw. It is overwritten
// with the drawn samples. FloatTensor of shape
// `(batch_size, n_samples), where `n_samples are drawn from each
// distribution.
// eps: A constant preventing division by zero in case empty bins are
// present.
// Not differentiable
#ifdef WITH_CUDA
void SamplePdfCuda(
const torch::Tensor& bins,
const torch::Tensor& weights,
const torch::Tensor& outputs,
float eps);
#endif
void SamplePdfCpu(
const torch::Tensor& bins,
const torch::Tensor& weights,
const torch::Tensor& outputs,
float eps);
inline void SamplePdf(
const torch::Tensor& bins,
const torch::Tensor& weights,
const torch::Tensor& outputs,
float eps) {
if (bins.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(weights);
CHECK_CONTIGUOUS_CUDA(outputs);
SamplePdfCuda(bins, weights, outputs, eps);
return;
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CONTIGUOUS(outputs);
SamplePdfCpu(bins, weights, outputs, eps);
}

View File

@ -0,0 +1,141 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h>
#include <algorithm>
#include <thread>
#include <vector>
// If the number of bins is the typical 64, it is
// quicker to use binary search than linear scan.
// With more bins, it is more important.
// There is no equivalent CUDA implementation yet.
#define USE_BINARY_SEARCH
namespace {
// This worker function does the job of SamplePdf but only on
// batch elements in [start_batch, end_batch).
void SamplePdfCpu_worker(
const torch::Tensor& bins,
const torch::Tensor& weights,
const torch::Tensor& outputs,
float eps,
int64_t start_batch,
int64_t end_batch) {
const int64_t n_bins = weights.size(1);
const int64_t n_samples = outputs.size(1);
auto bins_a = bins.accessor<float, 2>();
auto weights_a = weights.accessor<float, 2>();
float* __restrict__ output_p =
outputs.data_ptr<float>() + start_batch * n_samples;
#ifdef USE_BINARY_SEARCH
std::vector<float> partial_sums(n_bins);
#endif
for (int64_t i_batch_elt = start_batch; i_batch_elt < end_batch;
++i_batch_elt) {
auto bin_a = bins_a[i_batch_elt];
auto weight_a = weights_a[i_batch_elt];
// Here we do the work which has to be done once per batch element.
// i.e. (1) finding the total weight. (2) If using binary search,
// precompute the partial sums of the weights.
float total_weight = 0;
for (int64_t i_bin = 0; i_bin < n_bins; ++i_bin) {
total_weight += weight_a[i_bin];
#ifdef USE_BINARY_SEARCH
partial_sums[i_bin] = total_weight;
#endif
}
total_weight += eps;
for (int64_t i_sample = 0; i_sample < n_samples; ++i_sample) {
// Here we are taking a single random quantile (which is stored
// in *output_p) and using it to make a single sample, which we
// write back to the same location. First we find which bin
// the quantile lives in, either by binary search in the
// precomputed partial sums, or by scanning through the weights.
float uniform = total_weight * *output_p;
#ifdef USE_BINARY_SEARCH
int64_t i_bin = std::lower_bound(
partial_sums.begin(), --partial_sums.end(), uniform) -
partial_sums.begin();
if (i_bin > 0) {
uniform -= partial_sums[i_bin - 1];
}
#else
int64_t i_bin = 0;
while (i_bin + 1 < n_bins && uniform > weight_a[i_bin]) {
uniform -= weight_a[i_bin];
++i_bin;
}
#endif
// Now i_bin identifies the bin the quantile lives in, we use
// straight line interpolation to find the position of the
// quantile within the bin, and write it to *output_p.
float bin_start = bin_a[i_bin];
float bin_end = bin_a[i_bin + 1];
float bin_weight = weight_a[i_bin];
float output_value = bin_start;
if (uniform > bin_weight) {
output_value = bin_end;
} else if (bin_weight > eps) {
output_value += (uniform / bin_weight) * (bin_end - bin_start);
}
*output_p = output_value;
++output_p;
}
}
}
} // anonymous namespace
void SamplePdfCpu(
const torch::Tensor& bins,
const torch::Tensor& weights,
const torch::Tensor& outputs,
float eps) {
const int64_t batch_size = bins.size(0);
const int64_t max_threads = std::min(4, at::get_num_threads());
const int64_t n_threads = std::min(max_threads, batch_size);
if (batch_size == 0) {
return;
}
// SamplePdfCpu_worker does the work of this function. We send separate ranges
// of batch elements to that function in nThreads-1 separate threads.
std::vector<std::thread> threads;
threads.reserve(n_threads - 1);
const int64_t batch_elements_per_thread = 1 + (batch_size - 1) / n_threads;
int64_t start_batch = 0;
for (int iThread = 0; iThread < n_threads - 1; ++iThread) {
threads.emplace_back(
SamplePdfCpu_worker,
bins,
weights,
outputs,
eps,
start_batch,
start_batch + batch_elements_per_thread);
start_batch += batch_elements_per_thread;
}
// The remaining batch elements are calculated in this threads. If nThreads is
// 1 then all the work happens in this line.
SamplePdfCpu_worker(bins, weights, outputs, eps, start_batch, batch_size);
for (auto&& thread : threads) {
thread.join();
}
}

View File

@ -6,6 +6,62 @@
import torch
from pytorch3d import _C
def sample_pdf(
bins: torch.Tensor,
weights: torch.Tensor,
n_samples: int,
det: bool = False,
eps: float = 1e-5,
) -> torch.Tensor:
"""
Samples probability density functions defined by bin edges `bins` and
the non-negative per-bin probabilities `weights`.
Args:
bins: Tensor of shape `(..., n_bins+1)` denoting the edges of the sampling bins.
weights: Tensor of shape `(..., n_bins)` containing non-negative numbers
representing the probability of sampling the corresponding bin.
n_samples: The number of samples to draw from each set of bins.
det: If `False`, the sampling is random. `True` yields deterministic
uniformly-spaced sampling from the inverse cumulative density function.
eps: A constant preventing division by zero in case empty bins are present.
Returns:
samples: Tensor of shape `(..., n_samples)` containing `n_samples` samples
drawn from each probability distribution.
Refs:
[1] https://github.com/bmild/nerf/blob/55d8b00244d7b5178f4d003526ab6667683c9da9/run_nerf_helpers.py#L183 # noqa E501
"""
if torch.is_grad_enabled() and (bins.requires_grad or weights.requires_grad):
raise NotImplementedError("sample_pdf differentiability.")
if weights.min() <= -eps:
raise ValueError("Negative weights provided.")
batch_shape = bins.shape[:-1]
n_bins = weights.shape[-1]
if n_bins + 1 != bins.shape[-1] or weights.shape[:-1] != batch_shape:
shapes = f"{bins.shape}{weights.shape}"
raise ValueError("Inconsistent shapes of bins and weights: " + shapes)
output_shape = batch_shape + (n_samples,)
if det:
u = torch.linspace(0.0, 1.0, n_samples, device=bins.device, dtype=torch.float32)
output = u.expand(output_shape).contiguous()
else:
output = torch.rand(output_shape, dtype=torch.float32, device=bins.device)
# pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
_C.sample_pdf(
bins.reshape(-1, n_bins + 1),
weights.reshape(-1, n_bins),
output.reshape(-1, n_samples),
eps,
)
return output
def sample_pdf_python(
@ -16,6 +72,12 @@ def sample_pdf_python(
eps: float = 1e-5,
) -> torch.Tensor:
"""
This is a pure python implementation of the `sample_pdf` function.
It may be faster than sample_pdf when the number of bins is very large,
because it behaves as O(batchsize * [n_bins + log(n_bins) * n_samples] )
whereas sample_pdf behaves as O(batchsize * n_bins * n_samples).
For 64 bins sample_pdf is much faster.
Samples probability density functions defined by bin edges `bins` and
the non-negative per-bin probabilities `weights`.

View File

@ -12,7 +12,7 @@ from test_sample_pdf import TestSamplePDF
def bm_sample_pdf() -> None:
backends = ["python_cuda", "python_cpu"]
backends = ["python_cuda", "cuda", "python_cpu", "cpu"]
kwargs_list = []
sample_counts = [64]

View File

@ -5,10 +5,11 @@
# LICENSE file in the root directory of this source tree.
import unittest
from itertools import product
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf_python
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf, sample_pdf_python
class TestSamplePDF(TestCaseMixin, unittest.TestCase):
@ -23,9 +24,59 @@ class TestSamplePDF(TestCaseMixin, unittest.TestCase):
calc = torch.linspace(17, 18, 100).expand(5, -1)
self.assertClose(output, calc)
def test_simple_det(self):
for n_bins, n_samples, batch in product(
[7, 20], [2, 7, 31, 32, 33], [(), (1, 4), (31,), (32,), (33,)]
):
weights = torch.rand(size=(batch + (n_bins,)))
bins = torch.cumsum(torch.rand(size=(batch + (n_bins + 1,))), dim=-1)
python = sample_pdf_python(bins, weights, n_samples, det=True)
cpp = sample_pdf(bins, weights, n_samples, det=True)
self.assertClose(cpp, python, atol=2e-3)
nthreads = torch.get_num_threads()
torch.set_num_threads(1)
cpp_singlethread = sample_pdf(bins, weights, n_samples, det=True)
self.assertClose(cpp_singlethread, python, atol=2e-3)
torch.set_num_threads(nthreads)
device = torch.device("cuda:0")
cuda = sample_pdf(
bins.to(device), weights.to(device), n_samples, det=True
).cpu()
self.assertClose(cuda, python, atol=2e-3)
def test_rand_cpu(self):
n_bins, n_samples, batch_size = 11, 17, 9
weights = torch.rand(size=(batch_size, n_bins))
bins = torch.cumsum(torch.rand(size=(batch_size, n_bins + 1)), dim=-1)
torch.manual_seed(1)
python = sample_pdf_python(bins, weights, n_samples)
torch.manual_seed(1)
cpp = sample_pdf(bins, weights, n_samples)
self.assertClose(cpp, python, atol=2e-3)
def test_rand_nogap(self):
# Case where random is actually deterministic
weights = torch.FloatTensor([0, 10, 0])
bins = torch.FloatTensor([0, 10, 10, 25])
n_samples = 8
predicted = torch.full((n_samples,), 10.0)
python = sample_pdf_python(bins, weights, n_samples)
self.assertClose(python, predicted)
cpp = sample_pdf(bins, weights, n_samples)
self.assertClose(cpp, predicted)
device = torch.device("cuda:0")
cuda = sample_pdf(bins.to(device), weights.to(device), n_samples).cpu()
self.assertClose(cuda, predicted)
@staticmethod
def bm_fn(*, backend: str, n_samples, batch_size, n_bins):
f = sample_pdf_python
f = sample_pdf_python if "python" in backend else sample_pdf
weights = torch.rand(size=(batch_size, n_bins))
bins = torch.cumsum(torch.rand(size=(batch_size, n_bins + 1)), dim=-1)