pytorch3d/tests/test_sample_farthest_points.py
Tim Hatch 34bbb3ad32 apply import merging for fbcode/vision/fair (2 of 2)
Summary:
Applies new import merging and sorting from µsort v1.0.

When merging imports, µsort will make a best-effort to move associated
comments to match merged elements, but there are known limitations due to
the dynamic nature of Python and developer tooling. These changes should
not produce any dangerous runtime changes, but may require touch-ups to
satisfy linters and other tooling.

Note that µsort uses case-insensitive, lexicographical sorting, which
results in a different ordering compared to isort. This provides a more
consistent sorting order, matching the case-insensitive order used when
sorting import statements by module name, and ensures that "frog", "FROG",
and "Frog" always sort next to each other.

For details on µsort's sorting and merging semantics, see the user guide:
https://usort.readthedocs.io/en/stable/guide.html#sorting

Reviewed By: bottler

Differential Revision: D35553814

fbshipit-source-id: be49bdb6a4c25264ff8d4db3a601f18736d17be1
2022-04-13 06:51:33 -07:00

276 lines
11 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import numpy as np
import torch
from common_testing import (
get_pytorch3d_dir,
get_random_cuda_device,
get_tests_dir,
TestCaseMixin,
)
from pytorch3d.io import load_obj
from pytorch3d.ops.sample_farthest_points import (
sample_farthest_points,
sample_farthest_points_naive,
)
from pytorch3d.ops.utils import masked_gather
# Directory of test data files; the debug plot of test_cuda_vs_cpu is saved here.
DATA_DIR = get_tests_dir() / "data"
# Directory of tutorial data; test_cuda_vs_cpu loads the cow mesh from it.
TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
# Set to True to make test_cuda_vs_cpu plot and save the sampled points.
DEBUG = False
class TestFPS(TestCaseMixin, unittest.TestCase):
    """
    Tests for farthest point sampling, covering both `sample_farthest_points`
    and `sample_farthest_points_naive`.

    The `_test_*` helpers take the function under test (and a device) so the
    same checks can be run against either implementation. The two static
    methods at the bottom build closures intended for benchmarking.
    """

    def _test_simple(self, fps_func, device="cpu"):
        """
        Check `fps_func` on a small hand-written batch of two 2D point clouds,
        first with a single K for the whole batch and then with a per-cloud K.
        """
        # fmt: off
        points = torch.tensor(
            [
                [
                    [-1.0, -1.0],  # noqa: E241, E201
                    [-1.3,  1.1],  # noqa: E241, E201
                    [ 0.2, -1.1],  # noqa: E241, E201
                    [ 0.0,  0.0],  # noqa: E241, E201
                    [ 1.3,  1.3],  # noqa: E241, E201
                    [ 1.0,  0.5],  # noqa: E241, E201
                    [-1.3,  0.2],  # noqa: E241, E201
                    [ 1.5, -0.5],  # noqa: E241, E201
                ],
                [
                    [-2.2, -2.4],  # noqa: E241, E201
                    [-2.1,  2.0],  # noqa: E241, E201
                    [ 2.2,  2.1],  # noqa: E241, E201
                    [ 2.1, -2.4],  # noqa: E241, E201
                    [ 0.4, -1.0],  # noqa: E241, E201
                    [ 0.3,  0.3],  # noqa: E241, E201
                    [ 1.2,  0.5],  # noqa: E241, E201
                    [ 4.5,  4.5],  # noqa: E241, E201
                ],
            ],
            dtype=torch.float32,
            device=device,
        )
        # fmt: on
        expected_inds = torch.tensor([[0, 4], [0, 7]], dtype=torch.int64, device=device)
        out_points, out_inds = fps_func(points, K=2)
        self.assertClose(out_inds, expected_inds)

        # Gather the points
        expected_inds = expected_inds[..., None].expand(-1, -1, points.shape[-1])
        self.assertClose(out_points, points.gather(dim=1, index=expected_inds))

        # Different number of points sampled for each pointcloud in the batch;
        # -1 marks the unused slot of the cloud with the smaller K.
        expected_inds = torch.tensor(
            [[0, 4, 1], [0, 7, -1]], dtype=torch.int64, device=device
        )
        out_points, out_inds = fps_func(points, K=[3, 2])
        self.assertClose(out_inds, expected_inds)

        # Gather the points
        expected_points = masked_gather(points, expected_inds)
        self.assertClose(out_points, expected_points)

    def _test_compare_random_heterogeneous(self, device="cpu"):
        """
        Compare `sample_farthest_points` with `sample_farthest_points_naive`
        on random inputs: first on a batch without lengths, then on a batch
        with per-cloud lengths and K > P, including the input gradients.
        """
        N, P, D, K = 5, 20, 5, 8
        points = torch.randn((N, P, D), device=device, dtype=torch.float32)
        out_points_naive, out_idxs_naive = sample_farthest_points_naive(points, K=K)
        out_points, out_idxs = sample_farthest_points(points, K=K)
        self.assertTrue(out_idxs.min() >= 0)
        self.assertClose(out_idxs, out_idxs_naive)
        self.assertClose(out_points, out_points_naive)
        for n in range(N):
            self.assertEqual(out_idxs[n].ne(-1).sum(), K)

        # Test case where K > P
        K = 30
        points1 = torch.randn((N, P, D), dtype=torch.float32, device=device)
        points2 = points1.clone()
        points1.requires_grad = True
        points2.requires_grad = True
        lengths = torch.randint(low=1, high=P, size=(N,), device=device)
        out_points_naive, out_idxs_naive = sample_farthest_points_naive(
            points1, lengths, K=K
        )
        out_points, out_idxs = sample_farthest_points(points2, lengths, K=K)
        self.assertClose(out_idxs, out_idxs_naive)
        self.assertClose(out_points, out_points_naive)

        for n in range(N):
            # Check that for heterogeneous batches, the number of selected
            # points does not exceed the length
            self.assertTrue(out_idxs[n].ne(-1).sum() <= lengths[n])
            # Indices are zero-based, so every selected index must be
            # strictly smaller than the length of its cloud
            self.assertTrue(out_idxs[n].max() < lengths[n])

            # Check there are no duplicate indices
            val_mask = out_idxs[n].ne(-1)
            _, counts = torch.unique(out_idxs[n][val_mask], return_counts=True)
            self.assertTrue(counts.le(1).all())

        # Check gradients
        grad_sampled_points = torch.ones((N, K, D), dtype=torch.float32, device=device)
        loss1 = (out_points_naive * grad_sampled_points).sum()
        loss1.backward()
        loss2 = (out_points * grad_sampled_points).sum()
        loss2.backward()
        self.assertClose(points1.grad, points2.grad, atol=5e-6)

    def _test_errors(self, fps_func, device="cpu"):
        """
        Check that `fps_func` raises ValueError when K or lengths has a batch
        dimension that does not match the batch dimension of points.
        """
        N, P, D, K = 5, 40, 5, 8
        points = torch.randn((N, P, D), device=device)
        # Deliberately of size K (8) rather than N (5)
        wrong_batch_dim = torch.randint(low=1, high=K, size=(K,), device=device)

        # K has different batch dimension to points
        with self.assertRaisesRegex(ValueError, "K and points must have"):
            fps_func(points, K=wrong_batch_dim)

        # lengths has different batch dimension to points
        with self.assertRaisesRegex(ValueError, "points and lengths must have"):
            fps_func(points, lengths=wrong_batch_dim, K=K)

    def _test_random_start(self, fps_func, device="cpu"):
        """
        Check that `random_start_point=True` makes `fps_func` start from
        varying indices rather than the same index for every cloud.
        """
        N, P, D, K = 5, 40, 5, 8
        points = torch.randn((N, P, D), dtype=torch.float32, device=device)
        out_points, out_idxs = fps_func(points, K=K, random_start_point=True)
        # Check the first index is not 0 or the same number for all batch elements
        # when random_start_point = True
        self.assertTrue(out_idxs[:, 0].sum() > 0)
        self.assertFalse(out_idxs[:, 0].eq(out_idxs[0, 0]).all())

    def _test_gradcheck(self, fps_func, device="cpu"):
        """
        Run `torch.autograd.gradcheck` on `fps_func` with random points and
        random per-cloud lengths.
        """
        N, P, D, K = 2, 10, 3, 2
        points = torch.randn(
            (N, P, D), dtype=torch.float32, device=device, requires_grad=True
        )
        lengths = torch.randint(low=1, high=P, size=(N,), device=device)
        # The inputs are float32, hence the explicit eps and atol.
        torch.autograd.gradcheck(
            fps_func,
            (points, lengths, K),
            check_undefined_grad=False,
            eps=2e-3,
            atol=0.001,
        )

    def test_sample_farthest_points_naive(self):
        """Run the shared checks on `sample_farthest_points_naive` on a CUDA device."""
        device = get_random_cuda_device()
        self._test_simple(sample_farthest_points_naive, device)
        self._test_errors(sample_farthest_points_naive, device)
        self._test_random_start(sample_farthest_points_naive, device)
        self._test_gradcheck(sample_farthest_points_naive, device)

    def test_sample_farthest_points_cpu(self):
        """Run the shared checks on `sample_farthest_points` on the CPU."""
        self._test_simple(sample_farthest_points, "cpu")
        self._test_errors(sample_farthest_points, "cpu")
        self._test_compare_random_heterogeneous("cpu")
        self._test_random_start(sample_farthest_points, "cpu")
        self._test_gradcheck(sample_farthest_points, "cpu")

    def test_sample_farthest_points_cuda(self):
        """Run the shared checks on `sample_farthest_points` on a CUDA device."""
        device = get_random_cuda_device()
        self._test_simple(sample_farthest_points, device)
        self._test_errors(sample_farthest_points, device)
        self._test_compare_random_heterogeneous(device)
        self._test_random_start(sample_farthest_points, device)
        self._test_gradcheck(sample_farthest_points, device)

    def test_cuda_vs_cpu(self):
        """
        Compare cuda vs cpu on a complex object
        """
        obj_filename = TUTORIAL_DATA_DIR / "cow_mesh/cow.obj"
        K = 250

        # Run on CPU
        device = "cpu"
        points, _, _ = load_obj(obj_filename, device=device, load_textures=False)
        points = points[None, ...]
        out_points_cpu, out_idxs_cpu = sample_farthest_points(points, K=K)

        # Run on GPU
        device = get_random_cuda_device()
        points_cuda = points.to(device)
        out_points_cuda, out_idxs_cuda = sample_farthest_points(points_cuda, K=K)

        # Check that the indices from CUDA and CPU match
        self.assertClose(out_idxs_cpu, out_idxs_cuda.cpu())

        # Check there are no duplicate indices
        val_mask = out_idxs_cuda[0].ne(-1)
        _, counts = torch.unique(out_idxs_cuda[0][val_mask], return_counts=True)
        self.assertTrue(counts.le(1).all())

        # Plot all results
        if DEBUG:
            # mplot3d is required for 3d projection plots
            import matplotlib.pyplot as plt
            from mpl_toolkits import mplot3d  # noqa: F401

            # Move to cpu and convert to numpy for plotting
            points = points.squeeze()
            out_points_cpu = out_points_cpu.squeeze().numpy()
            out_points_cuda = out_points_cuda.squeeze().cpu().numpy()

            # Farthest point sampling CPU
            fig = plt.figure(figsize=plt.figaspect(1.0 / 3))
            ax1 = fig.add_subplot(1, 3, 1, projection="3d")
            ax1.scatter(*points.T, alpha=0.1)
            ax1.scatter(*out_points_cpu.T, color="black")
            ax1.set_title("FPS CPU")

            # Farthest point sampling CUDA
            ax2 = fig.add_subplot(1, 3, 2, projection="3d")
            ax2.scatter(*points.T, alpha=0.1)
            ax2.scatter(*out_points_cuda.T, color="red")
            ax2.set_title("FPS CUDA")

            # Random Sampling
            random_points = np.random.permutation(points)[:K]
            ax3 = fig.add_subplot(1, 3, 3, projection="3d")
            ax3.scatter(*points.T, alpha=0.1)
            ax3.scatter(*random_points.T, color="green")
            ax3.set_title("Random")

            # Save image
            filename = "DEBUG_fps.jpg"
            filepath = DATA_DIR / filename
            plt.savefig(filepath)
            # Release the figure so repeated debug runs do not accumulate figures
            plt.close(fig)

    @staticmethod
    def _synchronize(device: torch.device) -> None:
        """Wait for pending CUDA work on `device`; do nothing for other devices."""
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    @staticmethod
    def _make_benchmark(fps_func, N: int, P: int, D: int, K: int, device: str):
        """
        Build a zero-argument closure that runs `fps_func` forward and backward
        on a random (N, P, D) batch, sampling K points per cloud.
        """
        device = torch.device(device)
        pts = torch.randn(
            N, P, D, dtype=torch.float32, device=device, requires_grad=True
        )
        grad_pts = torch.randn(N, K, D, dtype=torch.float32, device=device)
        # Make sure the input creation has finished before the closure runs
        TestFPS._synchronize(device)

        def output():
            out_points, _ = fps_func(pts, K=K)
            loss = (out_points * grad_pts).sum()
            loss.backward()
            TestFPS._synchronize(device)

        return output

    @staticmethod
    def sample_farthest_points_naive(N: int, P: int, D: int, K: int, device: str):
        """
        Return a closure running forward and backward of
        `sample_farthest_points_naive` on random points.

        Inside this body the name `sample_farthest_points_naive` resolves to
        the module-level import, not to this static method.
        """
        return TestFPS._make_benchmark(sample_farthest_points_naive, N, P, D, K, device)

    @staticmethod
    def sample_farthest_points(N: int, P: int, D: int, K: int, device: str):
        """
        Return a closure running forward and backward of
        `sample_farthest_points` on random points.

        Inside this body the name `sample_farthest_points` resolves to the
        module-level import, not to this static method.
        """
        return TestFPS._make_benchmark(sample_farthest_points, N, P, D, K, device)