mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Move sample_pdf into PyTorch3D
Summary: Copy the sample_pdf operation from the NeRF project in to PyTorch3D, in preparation for optimizing it. Reviewed By: gkioxari Differential Revision: D27117930 fbshipit-source-id: 20286b007f589a4c4d53ed818c4bc5f2abd22833
This commit is contained in:
parent
b481cfbd01
commit
7d7d00f288
83
pytorch3d/renderer/implicit/sample_pdf.py
Normal file
83
pytorch3d/renderer/implicit/sample_pdf.py
Normal file
@ -0,0 +1,83 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def sample_pdf_python(
|
||||
bins: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
N_samples: int,
|
||||
det: bool = False,
|
||||
eps: float = 1e-5,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Samples probability density functions defined by bin edges `bins` and
|
||||
the non-negative per-bin probabilities `weights`.
|
||||
|
||||
Note: This is a direct conversion of the TensorFlow function from the original
|
||||
release [1] to PyTorch.
|
||||
|
||||
Args:
|
||||
bins: Tensor of shape `(..., n_bins+1)` denoting the edges of the sampling bins.
|
||||
weights: Tensor of shape `(..., n_bins)` containing non-negative numbers
|
||||
representing the probability of sampling the corresponding bin.
|
||||
N_samples: The number of samples to draw from each set of bins.
|
||||
det: If `False`, the sampling is random. `True` yields deterministic
|
||||
uniformly-spaced sampling from the inverse cumulative density function.
|
||||
eps: A constant preventing division by zero in case empty bins are present.
|
||||
|
||||
Returns:
|
||||
samples: Tensor of shape `(..., N_samples)` containing `N_samples` samples
|
||||
drawn from each probability distribution.
|
||||
|
||||
Refs:
|
||||
[1] https://github.com/bmild/nerf/blob/55d8b00244d7b5178f4d003526ab6667683c9da9/run_nerf_helpers.py#L183 # noqa E501
|
||||
"""
|
||||
|
||||
# Get pdf
|
||||
weights = weights + eps # prevent nans
|
||||
if weights.min() <= 0:
|
||||
raise ValueError("Negative weights provided.")
|
||||
pdf = weights / weights.sum(dim=-1, keepdim=True)
|
||||
cdf = torch.cumsum(pdf, -1)
|
||||
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
|
||||
|
||||
# Take uniform samples u of shape (..., N_samples)
|
||||
if det:
|
||||
u = torch.linspace(0.0, 1.0, N_samples, device=cdf.device, dtype=cdf.dtype)
|
||||
u = u.expand(list(cdf.shape[:-1]) + [N_samples]).contiguous()
|
||||
else:
|
||||
u = torch.rand(
|
||||
list(cdf.shape[:-1]) + [N_samples], device=cdf.device, dtype=cdf.dtype
|
||||
)
|
||||
|
||||
# Invert CDF
|
||||
inds = torch.searchsorted(cdf, u, right=True)
|
||||
# inds has shape (..., N_samples) identifying the bin of each sample.
|
||||
below = (inds - 1).clamp(0)
|
||||
above = inds.clamp(max=cdf.shape[-1] - 1)
|
||||
# Below and above are of shape (..., N_samples), identifying the bin
|
||||
# edges surrounding each sample.
|
||||
|
||||
inds_g = torch.stack([below, above], -1).view(
|
||||
*below.shape[:-1], below.shape[-1] * 2
|
||||
)
|
||||
cdf_g = torch.gather(cdf, -1, inds_g).view(*below.shape, 2)
|
||||
bins_g = torch.gather(bins, -1, inds_g).view(*below.shape, 2)
|
||||
# cdf_g and bins_g are of shape (..., N_samples, 2) and identify
|
||||
# the cdf and the index of the two bin edges surrounding each sample.
|
||||
|
||||
denom = cdf_g[..., 1] - cdf_g[..., 0]
|
||||
denom = torch.where(denom < eps, torch.ones_like(denom), denom)
|
||||
t = (u - cdf_g[..., 0]) / denom
|
||||
# t is of shape (..., N_samples) and identifies how far through
|
||||
# each sample is in its bin.
|
||||
|
||||
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
||||
|
||||
return samples
|
37
tests/bm_sample_pdf.py
Normal file
37
tests/bm_sample_pdf.py
Normal file
@ -0,0 +1,37 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from itertools import product
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from test_sample_pdf import TestSamplePDF
|
||||
|
||||
|
||||
def bm_sample_pdf() -> None:
|
||||
|
||||
backends = ["python_cuda", "python_cpu"]
|
||||
|
||||
kwargs_list = []
|
||||
sample_counts = [64]
|
||||
batch_sizes = [1024, 10240]
|
||||
bin_counts = [62, 600]
|
||||
test_cases = product(backends, sample_counts, batch_sizes, bin_counts)
|
||||
for case in test_cases:
|
||||
backend, n_samples, batch_size, n_bins = case
|
||||
kwargs_list.append(
|
||||
{
|
||||
"backend": backend,
|
||||
"n_samples": n_samples,
|
||||
"batch_size": batch_size,
|
||||
"n_bins": n_bins,
|
||||
}
|
||||
)
|
||||
|
||||
benchmark(TestSamplePDF.bm_fn, "SAMPLE_PDF", kwargs_list, warmup_iters=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_sample_pdf()
|
42
tests/test_sample_pdf.py
Normal file
42
tests/test_sample_pdf.py
Normal file
@ -0,0 +1,42 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf_python
|
||||
|
||||
|
||||
class TestSamplePDF(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(1)
|
||||
|
||||
def test_single_bin(self):
|
||||
bins = torch.arange(2).expand(5, 2) + 17
|
||||
weights = torch.ones(5, 1)
|
||||
output = sample_pdf_python(bins, weights, 100, True)
|
||||
calc = torch.linspace(17, 18, 100).expand(5, -1)
|
||||
self.assertClose(output, calc)
|
||||
|
||||
@staticmethod
|
||||
def bm_fn(*, backend: str, n_samples, batch_size, n_bins):
|
||||
f = sample_pdf_python
|
||||
weights = torch.rand(size=(batch_size, n_bins))
|
||||
bins = torch.cumsum(torch.rand(size=(batch_size, n_bins + 1)), dim=-1)
|
||||
|
||||
if "cuda" in backend:
|
||||
weights = weights.cuda()
|
||||
bins = bins.cuda()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def output():
|
||||
f(bins, weights, n_samples)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return output
|
Loading…
x
Reference in New Issue
Block a user