diff --git a/pytorch3d/renderer/implicit/sample_pdf.py b/pytorch3d/renderer/implicit/sample_pdf.py new file mode 100644 index 00000000..e3d7fedf --- /dev/null +++ b/pytorch3d/renderer/implicit/sample_pdf.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch + + +def sample_pdf_python( + bins: torch.Tensor, + weights: torch.Tensor, + N_samples: int, + det: bool = False, + eps: float = 1e-5, +) -> torch.Tensor: + """ + Samples probability density functions defined by bin edges `bins` and + the non-negative per-bin probabilities `weights`. + + Note: This is a direct conversion of the TensorFlow function from the original + release [1] to PyTorch. + + Args: + bins: Tensor of shape `(..., n_bins+1)` denoting the edges of the sampling bins. + weights: Tensor of shape `(..., n_bins)` containing non-negative numbers + representing the probability of sampling the corresponding bin. + N_samples: The number of samples to draw from each set of bins. + det: If `False`, the sampling is random. `True` yields deterministic + uniformly-spaced sampling from the inverse cumulative density function. + eps: A constant preventing division by zero in case empty bins are present. + + Returns: + samples: Tensor of shape `(..., N_samples)` containing `N_samples` samples + drawn from each probability distribution. + + Refs: + [1] https://github.com/bmild/nerf/blob/55d8b00244d7b5178f4d003526ab6667683c9da9/run_nerf_helpers.py#L183 # noqa E501 + """ + + # Get pdf + weights = weights + eps # prevent nans + if weights.min() <= 0: + raise ValueError("Negative weights provided.") + pdf = weights / weights.sum(dim=-1, keepdim=True) + cdf = torch.cumsum(pdf, -1) + cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) + + # Take uniform samples u of shape (..., N_samples) + if det: + u = torch.linspace(0.0, 1.0, N_samples, device=cdf.device, dtype=cdf.dtype) + u = u.expand(list(cdf.shape[:-1]) + [N_samples]).contiguous() + else: + u = torch.rand( + list(cdf.shape[:-1]) + [N_samples], device=cdf.device, dtype=cdf.dtype + ) + + # Invert CDF + inds = torch.searchsorted(cdf, u, right=True) + # inds has shape (..., N_samples) identifying the bin of each sample. + below = (inds - 1).clamp(0) + above = inds.clamp(max=cdf.shape[-1] - 1) + # Below and above are of shape (..., N_samples), identifying the bin + # edges surrounding each sample. + + inds_g = torch.stack([below, above], -1).view( + *below.shape[:-1], below.shape[-1] * 2 + ) + cdf_g = torch.gather(cdf, -1, inds_g).view(*below.shape, 2) + bins_g = torch.gather(bins, -1, inds_g).view(*below.shape, 2) + # cdf_g and bins_g are of shape (..., N_samples, 2) and identify + # the cdf and the index of the two bin edges surrounding each sample. + + denom = cdf_g[..., 1] - cdf_g[..., 0] + denom = torch.where(denom < eps, torch.ones_like(denom), denom) + t = (u - cdf_g[..., 0]) / denom + # t is of shape (..., N_samples) and identifies how far through + # each sample is in its bin. + + samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) + + return samples diff --git a/tests/bm_sample_pdf.py b/tests/bm_sample_pdf.py new file mode 100644 index 00000000..b56e62cc --- /dev/null +++ b/tests/bm_sample_pdf.py @@ -0,0 +1,37 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import product + +from fvcore.common.benchmark import benchmark +from test_sample_pdf import TestSamplePDF + + +def bm_sample_pdf() -> None: + + backends = ["python_cuda", "python_cpu"] + + kwargs_list = [] + sample_counts = [64] + batch_sizes = [1024, 10240] + bin_counts = [62, 600] + test_cases = product(backends, sample_counts, batch_sizes, bin_counts) + for case in test_cases: + backend, n_samples, batch_size, n_bins = case + kwargs_list.append( + { + "backend": backend, + "n_samples": n_samples, + "batch_size": batch_size, + "n_bins": n_bins, + } + ) + + benchmark(TestSamplePDF.bm_fn, "SAMPLE_PDF", kwargs_list, warmup_iters=1) + + +if __name__ == "__main__": + bm_sample_pdf() diff --git a/tests/test_sample_pdf.py b/tests/test_sample_pdf.py new file mode 100644 index 00000000..ed76cd83 --- /dev/null +++ b/tests/test_sample_pdf.py @@ -0,0 +1,42 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from common_testing import TestCaseMixin +from pytorch3d.renderer.implicit.sample_pdf import sample_pdf_python + + +class TestSamplePDF(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(1) + + def test_single_bin(self): + bins = torch.arange(2).expand(5, 2) + 17 + weights = torch.ones(5, 1) + output = sample_pdf_python(bins, weights, 100, True) + calc = torch.linspace(17, 18, 100).expand(5, -1) + self.assertClose(output, calc) + + @staticmethod + def bm_fn(*, backend: str, n_samples, batch_size, n_bins): + f = sample_pdf_python + weights = torch.rand(size=(batch_size, n_bins)) + bins = torch.cumsum(torch.rand(size=(batch_size, n_bins + 1)), dim=-1) + + if "cuda" in backend: + weights = weights.cuda() + bins = bins.cuda() + + torch.cuda.synchronize() + + def output(): + f(bins, weights, n_samples) + torch.cuda.synchronize() + + return output