mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
Farthest point sampling CUDA
Summary:
CUDA implementation of the farthest point sampling algorithm.
## Visual comparison
Compared to random sampling, farthest point sampling gives better coverage of the shape.
{F658631262}
## Reduction
Parallelized block reduction to find the max value at each iteration happens as follows:
1. First split the points into two equal-sized halves (e.g. for a list with 8 values):
`[20, 27, 6, 8 | 11, 10, 2, 33]`
2. Use half of the threads (4 threads) to compare pairs of elements, one from each half (e.g. elements [0, 4], [1, 5], etc.), and store the larger value of each pair in the first half of the list:
`[20, 27, 6, 33 | 11, 10, 2, 33]`
3. Now we no longer care about the second half, so we again divide the first half into two:
`[20, 27 | 6, 33 | -, -, -, -]`
Now we can use 2 threads to compare the 4 remaining elements
4. Finally we have gotten down to a single pair
`[20 | 33 | -, - | -, -, -, -]`
Use 1 thread to compare the remaining two elements
5. The max will now be at thread id = 0
`[33 | - | -, - | -, -, -, -]`
The reduction will give the farthest point for the selected batch index at this iteration.
Reviewed By: bottler, jcjohnson
Differential Revision: D30401803
fbshipit-source-id: 525bd5ae27c4b13b501812cfe62306bb003827d2
This commit is contained in:
committed by
Facebook GitHub Bot
parent
d9f7611c4b
commit
bd04ffaf77
@@ -29,8 +29,17 @@ def bm_fps() -> None:
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
kwargs_list = [k for k in kwargs_list if k["device"] == "cpu"]
|
||||
benchmark(TestFPS.sample_farthest_points, "FPS_CPU", kwargs_list, warmup_iters=1)
|
||||
# Add some larger batch sizes and pointcloud sizes
|
||||
Ns = [32]
|
||||
Ps = [2048, 8192, 18384]
|
||||
Ds = [3, 9]
|
||||
Ks = [24, 48]
|
||||
test_cases = product(Ns, Ps, Ds, Ks, backends)
|
||||
for case in test_cases:
|
||||
N, P, D, K, d = case
|
||||
kwargs_list.append({"N": N, "P": P, "D": D, "K": K, "device": d})
|
||||
|
||||
benchmark(TestFPS.sample_farthest_points, "FPS", kwargs_list, warmup_iters=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -6,14 +6,25 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||
from common_testing import (
|
||||
TestCaseMixin,
|
||||
get_random_cuda_device,
|
||||
get_tests_dir,
|
||||
get_pytorch3d_dir,
|
||||
)
|
||||
from pytorch3d.io import load_obj
|
||||
from pytorch3d.ops.sample_farthest_points import (
|
||||
sample_farthest_points_naive,
|
||||
sample_farthest_points,
|
||||
)
|
||||
from pytorch3d.ops.utils import masked_gather
|
||||
|
||||
DATA_DIR = get_tests_dir() / "data"
|
||||
TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
|
||||
DEBUG = False
|
||||
|
||||
|
||||
class TestFPS(TestCaseMixin, unittest.TestCase):
|
||||
def _test_simple(self, fps_func, device="cpu"):
|
||||
@@ -123,22 +134,22 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
def _test_random_start(self, fps_func, device="cpu"):
|
||||
N, P, D, K = 5, 40, 5, 8
|
||||
points = torch.randn((N, P, D), device=device)
|
||||
out_points, out_idxs = sample_farthest_points_naive(
|
||||
points, K=K, random_start_point=True
|
||||
)
|
||||
# Check the first index is not 0 for all batch elements
|
||||
points = torch.randn((N, P, D), dtype=torch.float32, device=device)
|
||||
out_points, out_idxs = fps_func(points, K=K, random_start_point=True)
|
||||
# Check the first index is not 0 or the same number for all batch elements
|
||||
# when random_start_point = True
|
||||
self.assertTrue(out_idxs[:, 0].sum() > 0)
|
||||
self.assertFalse(out_idxs[:, 0].eq(out_idxs[0, 0]).all())
|
||||
|
||||
def _test_gradcheck(self, fps_func, device="cpu"):
|
||||
N, P, D, K = 2, 5, 3, 2
|
||||
N, P, D, K = 2, 10, 3, 2
|
||||
points = torch.randn(
|
||||
(N, P, D), dtype=torch.float32, device=device, requires_grad=True
|
||||
)
|
||||
lengths = torch.randint(low=1, high=P, size=(N,), device=device)
|
||||
torch.autograd.gradcheck(
|
||||
fps_func,
|
||||
(points, None, K),
|
||||
(points, lengths, K),
|
||||
check_undefined_grad=False,
|
||||
eps=2e-3,
|
||||
atol=0.001,
|
||||
@@ -158,6 +169,76 @@ class TestFPS(TestCaseMixin, unittest.TestCase):
|
||||
self._test_random_start(sample_farthest_points, "cpu")
|
||||
self._test_gradcheck(sample_farthest_points, "cpu")
|
||||
|
||||
def test_sample_farthest_points_cuda(self):
|
||||
device = get_random_cuda_device()
|
||||
self._test_simple(sample_farthest_points, device)
|
||||
self._test_errors(sample_farthest_points, device)
|
||||
self._test_compare_random_heterogeneous(device)
|
||||
self._test_random_start(sample_farthest_points, device)
|
||||
self._test_gradcheck(sample_farthest_points, device)
|
||||
|
||||
def test_cuda_vs_cpu(self):
|
||||
"""
|
||||
Compare cuda vs cpu on a complex object
|
||||
"""
|
||||
obj_filename = TUTORIAL_DATA_DIR / "cow_mesh/cow.obj"
|
||||
K = 250
|
||||
|
||||
# Run on CPU
|
||||
device = "cpu"
|
||||
points, _, _ = load_obj(obj_filename, device=device, load_textures=False)
|
||||
points = points[None, ...]
|
||||
out_points_cpu, out_idxs_cpu = sample_farthest_points(points, K=K)
|
||||
|
||||
# Run on GPU
|
||||
device = get_random_cuda_device()
|
||||
points_cuda = points.to(device)
|
||||
out_points_cuda, out_idxs_cuda = sample_farthest_points(points_cuda, K=K)
|
||||
|
||||
# Check that the indices from CUDA and CPU match
|
||||
self.assertClose(out_idxs_cpu, out_idxs_cuda.cpu())
|
||||
|
||||
# Check there are no duplicate indices
|
||||
val_mask = out_idxs_cuda[0].ne(-1)
|
||||
vals, counts = torch.unique(out_idxs_cuda[0][val_mask], return_counts=True)
|
||||
self.assertTrue(counts.le(1).all())
|
||||
|
||||
# Plot all results
|
||||
if DEBUG:
|
||||
# mplot3d is required for 3d projection plots
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits import mplot3d # noqa: F401
|
||||
|
||||
# Move to cpu and convert to numpy for plotting
|
||||
points = points.squeeze()
|
||||
out_points_cpu = out_points_cpu.squeeze().numpy()
|
||||
out_points_cuda = out_points_cuda.squeeze().cpu().numpy()
|
||||
|
||||
# Farthest point sampling CPU
|
||||
fig = plt.figure(figsize=plt.figaspect(1.0 / 3))
|
||||
ax1 = fig.add_subplot(1, 3, 1, projection="3d")
|
||||
ax1.scatter(*points.T, alpha=0.1)
|
||||
ax1.scatter(*out_points_cpu.T, color="black")
|
||||
ax1.set_title("FPS CPU")
|
||||
|
||||
# Farthest point sampling CUDA
|
||||
ax2 = fig.add_subplot(1, 3, 2, projection="3d")
|
||||
ax2.scatter(*points.T, alpha=0.1)
|
||||
ax2.scatter(*out_points_cuda.T, color="red")
|
||||
ax2.set_title("FPS CUDA")
|
||||
|
||||
# Random Sampling
|
||||
random_points = np.random.permutation(points)[:K]
|
||||
ax3 = fig.add_subplot(1, 3, 3, projection="3d")
|
||||
ax3.scatter(*points.T, alpha=0.1)
|
||||
ax3.scatter(*random_points.T, color="green")
|
||||
ax3.set_title("Random")
|
||||
|
||||
# Save image
|
||||
filename = "DEBUG_fps.jpg"
|
||||
filepath = DATA_DIR / filename
|
||||
plt.savefig(filepath)
|
||||
|
||||
@staticmethod
|
||||
def sample_farthest_points_naive(N: int, P: int, D: int, K: int, device: str):
|
||||
device = torch.device(device)
|
||||
|
||||
Reference in New Issue
Block a user