mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
sample_pdf CUDA and C++ implementations.
Summary: Implement the sample_pdf function from the NeRF project as compiled operators.. The binary search (in searchsorted) is replaced with a low tech linear search, but this is not a problem for the envisaged numbers of bins. Reviewed By: gkioxari Differential Revision: D26312535 fbshipit-source-id: df1c3119cd63d944380ed1b2657b6ad81d743e49
This commit is contained in:
parent
7d7d00f288
commit
1ea2b7272a
@ -26,6 +26,7 @@
|
||||
#include "point_mesh/point_mesh_cuda.h"
|
||||
#include "rasterize_meshes/rasterize_meshes.h"
|
||||
#include "rasterize_points/rasterize_points.h"
|
||||
#include "sample_pdf/sample_pdf.h"
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
|
||||
@ -83,6 +84,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("point_face_array_dist_forward", &PointFaceArrayDistanceForward);
|
||||
m.def("point_face_array_dist_backward", &PointFaceArrayDistanceBackward);
|
||||
|
||||
// Sample PDF
|
||||
m.def("sample_pdf", &SamplePdf);
|
||||
|
||||
// Pulsar.
|
||||
#ifdef PULSAR_LOGGING_ENABLED
|
||||
c10::ShowLogInfoToStderr();
|
||||
|
153
pytorch3d/csrc/sample_pdf/sample_pdf.cu
Normal file
153
pytorch3d/csrc/sample_pdf/sample_pdf.cu
Normal file
@ -0,0 +1,153 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
// There is no intermediate memory, so no reason not to have blocksize=32.
|
||||
// 256 is a reasonable number of blocks.
|
||||
|
||||
// DESIGN
|
||||
// We exploit the fact that n_samples is not tiny.
|
||||
// A chunk of work is T*blocksize many samples from
|
||||
// a single batch elememt.
|
||||
// For each batch element there will be
|
||||
// chunks_per_batch = 1 + (n_samples-1)/(T*blocksize) of them.
|
||||
// The number of potential chunks to do is
|
||||
// n_chunks = chunks_per_batch * n_batches.
|
||||
// These chunks are divided among the gridSize-many blocks.
|
||||
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
|
||||
// In chunk i, we work on batch_element i/chunks_per_batch
|
||||
// on samples starting from (i%chunks_per_batch) * (T*blocksize)
|
||||
|
||||
// BEGIN HYPOTHETICAL
|
||||
// Another option (not implemented) if batch_size was always large
|
||||
// would be as follows.
|
||||
|
||||
// A chunk of work is S samples from each of blocksize-many
|
||||
// batch elements.
|
||||
// For each batch element there will be
|
||||
// chunks_per_batch = (1+(n_samples-1)/S) of them.
|
||||
// The number of potential chunks to do is
|
||||
// n_chunks = chunks_per_batch * (1+(n_batches-1)/blocksize)
|
||||
// These chunks are divided among the gridSize-many blocks.
|
||||
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
|
||||
// In chunk i, we work on samples starting from S*(i%chunks_per_batch)
|
||||
// on batch elements starting from blocksize*(i/chunks_per_batch).
|
||||
// END HYPOTHETICAL
|
||||
|
||||
__global__ void SamplePdfCudaKernel(
|
||||
const float* __restrict__ bins,
|
||||
const float* __restrict__ weights,
|
||||
float* __restrict__ outputs,
|
||||
float eps,
|
||||
const int T,
|
||||
const int64_t batch_size,
|
||||
const int64_t n_bins,
|
||||
const int64_t n_samples) {
|
||||
const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * blockDim.x);
|
||||
const int64_t n_chunks = chunks_per_batch * batch_size;
|
||||
|
||||
for (int64_t i_chunk = blockIdx.x; i_chunk < n_chunks; i_chunk += gridDim.x) {
|
||||
// Loop over the chunks.
|
||||
int64_t i_batch_element = i_chunk / chunks_per_batch;
|
||||
int64_t sample_start = (i_chunk % chunks_per_batch) * (T * blockDim.x);
|
||||
const float* const weight_startp = weights + n_bins * i_batch_element;
|
||||
const float* const bin_startp = bins + (1 + n_bins) * i_batch_element;
|
||||
|
||||
// Each chunk looks at a single batch element, so we do the preprocessing
|
||||
// which depends on the batch element, namely finding the total weight.
|
||||
// Idenntical work is being done in sync here by every thread of the block.
|
||||
float total_weight = eps;
|
||||
for (int64_t i_bin = 0; i_bin < n_bins; ++i_bin) {
|
||||
total_weight += weight_startp[i_bin];
|
||||
}
|
||||
|
||||
float* const output_startp =
|
||||
outputs + n_samples * i_batch_element + sample_start;
|
||||
|
||||
for (int t = 0; t < T; ++t) {
|
||||
// Loop over T, which is the number of samples each thread makes within
|
||||
// the chunk.
|
||||
const int64_t i_sample_within_chunk = threadIdx.x + t * blockDim.x;
|
||||
if (sample_start + i_sample_within_chunk >= n_samples) {
|
||||
// Some threads need to exit early because the sample they would
|
||||
// make is unwanted.
|
||||
continue;
|
||||
}
|
||||
// output_startp[i_sample_within_chunk] contains the quantile we (i.e.
|
||||
// this thread) are calcvulating.
|
||||
float uniform = total_weight * output_startp[i_sample_within_chunk];
|
||||
int64_t i_bin = 0;
|
||||
// We find the bin containing the quantile by walking along the weights.
|
||||
// This loop must be thread dependent. I.e. the whole warp will wait until
|
||||
// every thread has found the bin for its quantile.
|
||||
// It may be best to write it differently.
|
||||
while (i_bin + 1 < n_bins && uniform > weight_startp[i_bin]) {
|
||||
uniform -= weight_startp[i_bin];
|
||||
++i_bin;
|
||||
}
|
||||
|
||||
// Now we know which bin to look in, we use linear interpolation
|
||||
// to find the location of the quantile within the bin, and
|
||||
// write the answer back.
|
||||
float bin_start = bin_startp[i_bin];
|
||||
float bin_end = bin_startp[i_bin + 1];
|
||||
float bin_weight = weight_startp[i_bin];
|
||||
float output_value = bin_start;
|
||||
if (uniform > bin_weight) {
|
||||
output_value = bin_end;
|
||||
} else if (bin_weight > eps) {
|
||||
output_value += (uniform / bin_weight) * (bin_end - bin_start);
|
||||
}
|
||||
output_startp[i_sample_within_chunk] = output_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SamplePdfCuda(
|
||||
const at::Tensor& bins,
|
||||
const at::Tensor& weights,
|
||||
const at::Tensor& outputs,
|
||||
float eps) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg bins_t{bins, "bins", 1}, weights_t{weights, "weights", 2},
|
||||
outputs_t{outputs, "outputs", 3};
|
||||
at::CheckedFrom c = "SamplePdfCuda";
|
||||
at::checkAllSameGPU(c, {bins_t, weights_t, outputs_t});
|
||||
at::checkAllSameType(c, {bins_t, weights_t, outputs_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(bins.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int64_t batch_size = bins.size(0);
|
||||
const int64_t n_bins = weights.size(1);
|
||||
const int64_t n_samples = outputs.size(1);
|
||||
|
||||
const int64_t threads = 32;
|
||||
const int64_t T = n_samples <= threads ? 1 : 2;
|
||||
const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * threads);
|
||||
const int64_t n_chunks = chunks_per_batch * batch_size;
|
||||
|
||||
const int64_t max_blocks = 1024;
|
||||
const int64_t blocks = n_chunks < max_blocks ? n_chunks : max_blocks;
|
||||
|
||||
SamplePdfCudaKernel<<<blocks, threads, 0, stream>>>(
|
||||
bins.contiguous().data_ptr<float>(),
|
||||
weights.contiguous().data_ptr<float>(),
|
||||
outputs.data_ptr<float>(), // Checked contiguous in header file.
|
||||
eps,
|
||||
T,
|
||||
batch_size,
|
||||
n_bins,
|
||||
n_samples);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
74
pytorch3d/csrc/sample_pdf/sample_pdf.h
Normal file
74
pytorch3d/csrc/sample_pdf/sample_pdf.h
Normal file
@ -0,0 +1,74 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
#include <cstdio>
|
||||
#include <tuple>
|
||||
#include "utils/pytorch3d_cutils.h"
|
||||
|
||||
// ****************************************************************************
|
||||
// * SamplePdf *
|
||||
// ****************************************************************************
|
||||
|
||||
// Samples a probability density functions defined by bin edges `bins` and
|
||||
// the non-negative per-bin probabilities `weights`.
|
||||
|
||||
// Args:
|
||||
// bins: FloatTensor of shape `(batch_size, n_bins+1)` denoting the edges
|
||||
// of the sampling bins.
|
||||
|
||||
// weights: FloatTensor of shape `(batch_size, n_bins)` containing
|
||||
// non-negative numbers representing the probability of sampling the
|
||||
// corresponding bin.
|
||||
|
||||
// uniforms: The quantiles to draw, FloatTensor of shape
|
||||
// `(batch_size, n_samples)`.
|
||||
|
||||
// outputs: On call, this contains the quantiles to draw. It is overwritten
|
||||
// with the drawn samples. FloatTensor of shape
|
||||
// `(batch_size, n_samples), where `n_samples are drawn from each
|
||||
// distribution.
|
||||
|
||||
// eps: A constant preventing division by zero in case empty bins are
|
||||
// present.
|
||||
|
||||
// Not differentiable
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
void SamplePdfCuda(
|
||||
const torch::Tensor& bins,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& outputs,
|
||||
float eps);
|
||||
#endif
|
||||
|
||||
void SamplePdfCpu(
|
||||
const torch::Tensor& bins,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& outputs,
|
||||
float eps);
|
||||
|
||||
inline void SamplePdf(
|
||||
const torch::Tensor& bins,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& outputs,
|
||||
float eps) {
|
||||
if (bins.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(weights);
|
||||
CHECK_CONTIGUOUS_CUDA(outputs);
|
||||
SamplePdfCuda(bins, weights, outputs, eps);
|
||||
return;
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
CHECK_CONTIGUOUS(outputs);
|
||||
SamplePdfCpu(bins, weights, outputs, eps);
|
||||
}
|
141
pytorch3d/csrc/sample_pdf/sample_pdf_cpu.cpp
Normal file
141
pytorch3d/csrc/sample_pdf/sample_pdf_cpu.cpp
Normal file
@ -0,0 +1,141 @@
|
||||
/*
|
||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <algorithm>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
// If the number of bins is the typical 64, it is
|
||||
// quicker to use binary search than linear scan.
|
||||
// With more bins, it is more important.
|
||||
// There is no equivalent CUDA implementation yet.
|
||||
#define USE_BINARY_SEARCH
|
||||
|
||||
namespace {
|
||||
// This worker function does the job of SamplePdf but only on
|
||||
// batch elements in [start_batch, end_batch).
|
||||
void SamplePdfCpu_worker(
|
||||
const torch::Tensor& bins,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& outputs,
|
||||
float eps,
|
||||
int64_t start_batch,
|
||||
int64_t end_batch) {
|
||||
const int64_t n_bins = weights.size(1);
|
||||
const int64_t n_samples = outputs.size(1);
|
||||
|
||||
auto bins_a = bins.accessor<float, 2>();
|
||||
auto weights_a = weights.accessor<float, 2>();
|
||||
float* __restrict__ output_p =
|
||||
outputs.data_ptr<float>() + start_batch * n_samples;
|
||||
|
||||
#ifdef USE_BINARY_SEARCH
|
||||
std::vector<float> partial_sums(n_bins);
|
||||
#endif
|
||||
|
||||
for (int64_t i_batch_elt = start_batch; i_batch_elt < end_batch;
|
||||
++i_batch_elt) {
|
||||
auto bin_a = bins_a[i_batch_elt];
|
||||
auto weight_a = weights_a[i_batch_elt];
|
||||
|
||||
// Here we do the work which has to be done once per batch element.
|
||||
// i.e. (1) finding the total weight. (2) If using binary search,
|
||||
// precompute the partial sums of the weights.
|
||||
|
||||
float total_weight = 0;
|
||||
for (int64_t i_bin = 0; i_bin < n_bins; ++i_bin) {
|
||||
total_weight += weight_a[i_bin];
|
||||
#ifdef USE_BINARY_SEARCH
|
||||
partial_sums[i_bin] = total_weight;
|
||||
#endif
|
||||
}
|
||||
total_weight += eps;
|
||||
|
||||
for (int64_t i_sample = 0; i_sample < n_samples; ++i_sample) {
|
||||
// Here we are taking a single random quantile (which is stored
|
||||
// in *output_p) and using it to make a single sample, which we
|
||||
// write back to the same location. First we find which bin
|
||||
// the quantile lives in, either by binary search in the
|
||||
// precomputed partial sums, or by scanning through the weights.
|
||||
|
||||
float uniform = total_weight * *output_p;
|
||||
#ifdef USE_BINARY_SEARCH
|
||||
int64_t i_bin = std::lower_bound(
|
||||
partial_sums.begin(), --partial_sums.end(), uniform) -
|
||||
partial_sums.begin();
|
||||
if (i_bin > 0) {
|
||||
uniform -= partial_sums[i_bin - 1];
|
||||
}
|
||||
#else
|
||||
int64_t i_bin = 0;
|
||||
while (i_bin + 1 < n_bins && uniform > weight_a[i_bin]) {
|
||||
uniform -= weight_a[i_bin];
|
||||
++i_bin;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Now i_bin identifies the bin the quantile lives in, we use
|
||||
// straight line interpolation to find the position of the
|
||||
// quantile within the bin, and write it to *output_p.
|
||||
|
||||
float bin_start = bin_a[i_bin];
|
||||
float bin_end = bin_a[i_bin + 1];
|
||||
float bin_weight = weight_a[i_bin];
|
||||
float output_value = bin_start;
|
||||
if (uniform > bin_weight) {
|
||||
output_value = bin_end;
|
||||
} else if (bin_weight > eps) {
|
||||
output_value += (uniform / bin_weight) * (bin_end - bin_start);
|
||||
}
|
||||
*output_p = output_value;
|
||||
++output_p;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
void SamplePdfCpu(
|
||||
const torch::Tensor& bins,
|
||||
const torch::Tensor& weights,
|
||||
const torch::Tensor& outputs,
|
||||
float eps) {
|
||||
const int64_t batch_size = bins.size(0);
|
||||
const int64_t max_threads = std::min(4, at::get_num_threads());
|
||||
const int64_t n_threads = std::min(max_threads, batch_size);
|
||||
if (batch_size == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// SamplePdfCpu_worker does the work of this function. We send separate ranges
|
||||
// of batch elements to that function in nThreads-1 separate threads.
|
||||
|
||||
std::vector<std::thread> threads;
|
||||
threads.reserve(n_threads - 1);
|
||||
const int64_t batch_elements_per_thread = 1 + (batch_size - 1) / n_threads;
|
||||
int64_t start_batch = 0;
|
||||
for (int iThread = 0; iThread < n_threads - 1; ++iThread) {
|
||||
threads.emplace_back(
|
||||
SamplePdfCpu_worker,
|
||||
bins,
|
||||
weights,
|
||||
outputs,
|
||||
eps,
|
||||
start_batch,
|
||||
start_batch + batch_elements_per_thread);
|
||||
start_batch += batch_elements_per_thread;
|
||||
}
|
||||
|
||||
// The remaining batch elements are calculated in this threads. If nThreads is
|
||||
// 1 then all the work happens in this line.
|
||||
SamplePdfCpu_worker(bins, weights, outputs, eps, start_batch, batch_size);
|
||||
for (auto&& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
@ -6,6 +6,62 @@
|
||||
|
||||
|
||||
import torch
|
||||
from pytorch3d import _C
|
||||
|
||||
|
||||
def sample_pdf(
|
||||
bins: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
n_samples: int,
|
||||
det: bool = False,
|
||||
eps: float = 1e-5,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Samples probability density functions defined by bin edges `bins` and
|
||||
the non-negative per-bin probabilities `weights`.
|
||||
|
||||
Args:
|
||||
bins: Tensor of shape `(..., n_bins+1)` denoting the edges of the sampling bins.
|
||||
weights: Tensor of shape `(..., n_bins)` containing non-negative numbers
|
||||
representing the probability of sampling the corresponding bin.
|
||||
n_samples: The number of samples to draw from each set of bins.
|
||||
det: If `False`, the sampling is random. `True` yields deterministic
|
||||
uniformly-spaced sampling from the inverse cumulative density function.
|
||||
eps: A constant preventing division by zero in case empty bins are present.
|
||||
|
||||
Returns:
|
||||
samples: Tensor of shape `(..., n_samples)` containing `n_samples` samples
|
||||
drawn from each probability distribution.
|
||||
|
||||
Refs:
|
||||
[1] https://github.com/bmild/nerf/blob/55d8b00244d7b5178f4d003526ab6667683c9da9/run_nerf_helpers.py#L183 # noqa E501
|
||||
"""
|
||||
if torch.is_grad_enabled() and (bins.requires_grad or weights.requires_grad):
|
||||
raise NotImplementedError("sample_pdf differentiability.")
|
||||
if weights.min() <= -eps:
|
||||
raise ValueError("Negative weights provided.")
|
||||
batch_shape = bins.shape[:-1]
|
||||
n_bins = weights.shape[-1]
|
||||
if n_bins + 1 != bins.shape[-1] or weights.shape[:-1] != batch_shape:
|
||||
shapes = f"{bins.shape}{weights.shape}"
|
||||
raise ValueError("Inconsistent shapes of bins and weights: " + shapes)
|
||||
output_shape = batch_shape + (n_samples,)
|
||||
|
||||
if det:
|
||||
u = torch.linspace(0.0, 1.0, n_samples, device=bins.device, dtype=torch.float32)
|
||||
output = u.expand(output_shape).contiguous()
|
||||
else:
|
||||
output = torch.rand(output_shape, dtype=torch.float32, device=bins.device)
|
||||
|
||||
# pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
|
||||
_C.sample_pdf(
|
||||
bins.reshape(-1, n_bins + 1),
|
||||
weights.reshape(-1, n_bins),
|
||||
output.reshape(-1, n_samples),
|
||||
eps,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def sample_pdf_python(
|
||||
@ -16,6 +72,12 @@ def sample_pdf_python(
|
||||
eps: float = 1e-5,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This is a pure python implementation of the `sample_pdf` function.
|
||||
It may be faster than sample_pdf when the number of bins is very large,
|
||||
because it behaves as O(batchsize * [n_bins + log(n_bins) * n_samples] )
|
||||
whereas sample_pdf behaves as O(batchsize * n_bins * n_samples).
|
||||
For 64 bins sample_pdf is much faster.
|
||||
|
||||
Samples probability density functions defined by bin edges `bins` and
|
||||
the non-negative per-bin probabilities `weights`.
|
||||
|
||||
|
@ -12,7 +12,7 @@ from test_sample_pdf import TestSamplePDF
|
||||
|
||||
def bm_sample_pdf() -> None:
|
||||
|
||||
backends = ["python_cuda", "python_cpu"]
|
||||
backends = ["python_cuda", "cuda", "python_cpu", "cpu"]
|
||||
|
||||
kwargs_list = []
|
||||
sample_counts = [64]
|
||||
|
@ -5,10 +5,11 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf_python
|
||||
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf, sample_pdf_python
|
||||
|
||||
|
||||
class TestSamplePDF(TestCaseMixin, unittest.TestCase):
|
||||
@ -23,9 +24,59 @@ class TestSamplePDF(TestCaseMixin, unittest.TestCase):
|
||||
calc = torch.linspace(17, 18, 100).expand(5, -1)
|
||||
self.assertClose(output, calc)
|
||||
|
||||
def test_simple_det(self):
|
||||
for n_bins, n_samples, batch in product(
|
||||
[7, 20], [2, 7, 31, 32, 33], [(), (1, 4), (31,), (32,), (33,)]
|
||||
):
|
||||
weights = torch.rand(size=(batch + (n_bins,)))
|
||||
bins = torch.cumsum(torch.rand(size=(batch + (n_bins + 1,))), dim=-1)
|
||||
python = sample_pdf_python(bins, weights, n_samples, det=True)
|
||||
|
||||
cpp = sample_pdf(bins, weights, n_samples, det=True)
|
||||
self.assertClose(cpp, python, atol=2e-3)
|
||||
|
||||
nthreads = torch.get_num_threads()
|
||||
torch.set_num_threads(1)
|
||||
cpp_singlethread = sample_pdf(bins, weights, n_samples, det=True)
|
||||
self.assertClose(cpp_singlethread, python, atol=2e-3)
|
||||
torch.set_num_threads(nthreads)
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
cuda = sample_pdf(
|
||||
bins.to(device), weights.to(device), n_samples, det=True
|
||||
).cpu()
|
||||
|
||||
self.assertClose(cuda, python, atol=2e-3)
|
||||
|
||||
def test_rand_cpu(self):
|
||||
n_bins, n_samples, batch_size = 11, 17, 9
|
||||
weights = torch.rand(size=(batch_size, n_bins))
|
||||
bins = torch.cumsum(torch.rand(size=(batch_size, n_bins + 1)), dim=-1)
|
||||
torch.manual_seed(1)
|
||||
python = sample_pdf_python(bins, weights, n_samples)
|
||||
torch.manual_seed(1)
|
||||
cpp = sample_pdf(bins, weights, n_samples)
|
||||
|
||||
self.assertClose(cpp, python, atol=2e-3)
|
||||
|
||||
def test_rand_nogap(self):
|
||||
# Case where random is actually deterministic
|
||||
weights = torch.FloatTensor([0, 10, 0])
|
||||
bins = torch.FloatTensor([0, 10, 10, 25])
|
||||
n_samples = 8
|
||||
predicted = torch.full((n_samples,), 10.0)
|
||||
python = sample_pdf_python(bins, weights, n_samples)
|
||||
self.assertClose(python, predicted)
|
||||
cpp = sample_pdf(bins, weights, n_samples)
|
||||
self.assertClose(cpp, predicted)
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
cuda = sample_pdf(bins.to(device), weights.to(device), n_samples).cpu()
|
||||
self.assertClose(cuda, predicted)
|
||||
|
||||
@staticmethod
|
||||
def bm_fn(*, backend: str, n_samples, batch_size, n_bins):
|
||||
f = sample_pdf_python
|
||||
f = sample_pdf_python if "python" in backend else sample_pdf
|
||||
weights = torch.rand(size=(batch_size, n_bins))
|
||||
bins = torch.cumsum(torch.rand(size=(batch_size, n_bins + 1)), dim=-1)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user