mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Summary: Move testing targets from pytorch3d/tests/TARGETS to pytorch3d/TARGETS. Reviewed By: shapovalov Differential Revision: D36186940 fbshipit-source-id: a4c52c4d99351f885e2b0bf870532d530324039b
277 lines
11 KiB
Python
277 lines
11 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import torch
|
|
from pytorch3d.io import load_obj
|
|
from pytorch3d.ops.sample_farthest_points import (
|
|
sample_farthest_points,
|
|
sample_farthest_points_naive,
|
|
)
|
|
from pytorch3d.ops.utils import masked_gather
|
|
|
|
from .common_testing import (
|
|
get_pytorch3d_dir,
|
|
get_random_cuda_device,
|
|
get_tests_dir,
|
|
TestCaseMixin,
|
|
)
|
|
|
|
|
|
DATA_DIR = get_tests_dir() / "data"
|
|
TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
|
|
DEBUG = False
|
|
|
|
|
|
class TestFPS(TestCaseMixin, unittest.TestCase):
|
|
def _test_simple(self, fps_func, device="cpu"):
|
|
# fmt: off
|
|
points = torch.tensor(
|
|
[
|
|
[
|
|
[-1.0, -1.0], # noqa: E241, E201
|
|
[-1.3, 1.1], # noqa: E241, E201
|
|
[ 0.2, -1.1], # noqa: E241, E201
|
|
[ 0.0, 0.0], # noqa: E241, E201
|
|
[ 1.3, 1.3], # noqa: E241, E201
|
|
[ 1.0, 0.5], # noqa: E241, E201
|
|
[-1.3, 0.2], # noqa: E241, E201
|
|
[ 1.5, -0.5], # noqa: E241, E201
|
|
],
|
|
[
|
|
[-2.2, -2.4], # noqa: E241, E201
|
|
[-2.1, 2.0], # noqa: E241, E201
|
|
[ 2.2, 2.1], # noqa: E241, E201
|
|
[ 2.1, -2.4], # noqa: E241, E201
|
|
[ 0.4, -1.0], # noqa: E241, E201
|
|
[ 0.3, 0.3], # noqa: E241, E201
|
|
[ 1.2, 0.5], # noqa: E241, E201
|
|
[ 4.5, 4.5], # noqa: E241, E201
|
|
],
|
|
],
|
|
dtype=torch.float32,
|
|
device=device,
|
|
)
|
|
# fmt: on
|
|
expected_inds = torch.tensor([[0, 4], [0, 7]], dtype=torch.int64, device=device)
|
|
out_points, out_inds = fps_func(points, K=2)
|
|
self.assertClose(out_inds, expected_inds)
|
|
|
|
# Gather the points
|
|
expected_inds = expected_inds[..., None].expand(-1, -1, points.shape[-1])
|
|
self.assertClose(out_points, points.gather(dim=1, index=expected_inds))
|
|
|
|
# Different number of points sampled for each pointcloud in the batch
|
|
expected_inds = torch.tensor(
|
|
[[0, 4, 1], [0, 7, -1]], dtype=torch.int64, device=device
|
|
)
|
|
out_points, out_inds = fps_func(points, K=[3, 2])
|
|
self.assertClose(out_inds, expected_inds)
|
|
|
|
# Gather the points
|
|
expected_points = masked_gather(points, expected_inds)
|
|
self.assertClose(out_points, expected_points)
|
|
|
|
def _test_compare_random_heterogeneous(self, device="cpu"):
|
|
N, P, D, K = 5, 20, 5, 8
|
|
points = torch.randn((N, P, D), device=device, dtype=torch.float32)
|
|
out_points_naive, out_idxs_naive = sample_farthest_points_naive(points, K=K)
|
|
out_points, out_idxs = sample_farthest_points(points, K=K)
|
|
self.assertTrue(out_idxs.min() >= 0)
|
|
self.assertClose(out_idxs, out_idxs_naive)
|
|
self.assertClose(out_points, out_points_naive)
|
|
for n in range(N):
|
|
self.assertEqual(out_idxs[n].ne(-1).sum(), K)
|
|
|
|
# Test case where K > P
|
|
K = 30
|
|
points1 = torch.randn((N, P, D), dtype=torch.float32, device=device)
|
|
points2 = points1.clone()
|
|
points1.requires_grad = True
|
|
points2.requires_grad = True
|
|
lengths = torch.randint(low=1, high=P, size=(N,), device=device)
|
|
out_points_naive, out_idxs_naive = sample_farthest_points_naive(
|
|
points1, lengths, K=K
|
|
)
|
|
out_points, out_idxs = sample_farthest_points(points2, lengths, K=K)
|
|
self.assertClose(out_idxs, out_idxs_naive)
|
|
self.assertClose(out_points, out_points_naive)
|
|
|
|
for n in range(N):
|
|
# Check that for heterogeneous batches, the max number of
|
|
# selected points is less than the length
|
|
self.assertTrue(out_idxs[n].ne(-1).sum() <= lengths[n])
|
|
self.assertTrue(out_idxs[n].max() <= lengths[n])
|
|
|
|
# Check there are no duplicate indices
|
|
val_mask = out_idxs[n].ne(-1)
|
|
vals, counts = torch.unique(out_idxs[n][val_mask], return_counts=True)
|
|
self.assertTrue(counts.le(1).all())
|
|
|
|
# Check gradients
|
|
grad_sampled_points = torch.ones((N, K, D), dtype=torch.float32, device=device)
|
|
loss1 = (out_points_naive * grad_sampled_points).sum()
|
|
loss1.backward()
|
|
loss2 = (out_points * grad_sampled_points).sum()
|
|
loss2.backward()
|
|
self.assertClose(points1.grad, points2.grad, atol=5e-6)
|
|
|
|
def _test_errors(self, fps_func, device="cpu"):
|
|
N, P, D, K = 5, 40, 5, 8
|
|
points = torch.randn((N, P, D), device=device)
|
|
wrong_batch_dim = torch.randint(low=1, high=K, size=(K,), device=device)
|
|
|
|
# K has diferent batch dimension to points
|
|
with self.assertRaisesRegex(ValueError, "K and points must have"):
|
|
sample_farthest_points_naive(points, K=wrong_batch_dim)
|
|
|
|
# lengths has diferent batch dimension to points
|
|
with self.assertRaisesRegex(ValueError, "points and lengths must have"):
|
|
sample_farthest_points_naive(points, lengths=wrong_batch_dim, K=K)
|
|
|
|
def _test_random_start(self, fps_func, device="cpu"):
|
|
N, P, D, K = 5, 40, 5, 8
|
|
points = torch.randn((N, P, D), dtype=torch.float32, device=device)
|
|
out_points, out_idxs = fps_func(points, K=K, random_start_point=True)
|
|
# Check the first index is not 0 or the same number for all batch elements
|
|
# when random_start_point = True
|
|
self.assertTrue(out_idxs[:, 0].sum() > 0)
|
|
self.assertFalse(out_idxs[:, 0].eq(out_idxs[0, 0]).all())
|
|
|
|
def _test_gradcheck(self, fps_func, device="cpu"):
|
|
N, P, D, K = 2, 10, 3, 2
|
|
points = torch.randn(
|
|
(N, P, D), dtype=torch.float32, device=device, requires_grad=True
|
|
)
|
|
lengths = torch.randint(low=1, high=P, size=(N,), device=device)
|
|
torch.autograd.gradcheck(
|
|
fps_func,
|
|
(points, lengths, K),
|
|
check_undefined_grad=False,
|
|
eps=2e-3,
|
|
atol=0.001,
|
|
)
|
|
|
|
def test_sample_farthest_points_naive(self):
|
|
device = get_random_cuda_device()
|
|
self._test_simple(sample_farthest_points_naive, device)
|
|
self._test_errors(sample_farthest_points_naive, device)
|
|
self._test_random_start(sample_farthest_points_naive, device)
|
|
self._test_gradcheck(sample_farthest_points_naive, device)
|
|
|
|
def test_sample_farthest_points_cpu(self):
|
|
self._test_simple(sample_farthest_points, "cpu")
|
|
self._test_errors(sample_farthest_points, "cpu")
|
|
self._test_compare_random_heterogeneous("cpu")
|
|
self._test_random_start(sample_farthest_points, "cpu")
|
|
self._test_gradcheck(sample_farthest_points, "cpu")
|
|
|
|
def test_sample_farthest_points_cuda(self):
|
|
device = get_random_cuda_device()
|
|
self._test_simple(sample_farthest_points, device)
|
|
self._test_errors(sample_farthest_points, device)
|
|
self._test_compare_random_heterogeneous(device)
|
|
self._test_random_start(sample_farthest_points, device)
|
|
self._test_gradcheck(sample_farthest_points, device)
|
|
|
|
def test_cuda_vs_cpu(self):
|
|
"""
|
|
Compare cuda vs cpu on a complex object
|
|
"""
|
|
obj_filename = TUTORIAL_DATA_DIR / "cow_mesh/cow.obj"
|
|
K = 250
|
|
|
|
# Run on CPU
|
|
device = "cpu"
|
|
points, _, _ = load_obj(obj_filename, device=device, load_textures=False)
|
|
points = points[None, ...]
|
|
out_points_cpu, out_idxs_cpu = sample_farthest_points(points, K=K)
|
|
|
|
# Run on GPU
|
|
device = get_random_cuda_device()
|
|
points_cuda = points.to(device)
|
|
out_points_cuda, out_idxs_cuda = sample_farthest_points(points_cuda, K=K)
|
|
|
|
# Check that the indices from CUDA and CPU match
|
|
self.assertClose(out_idxs_cpu, out_idxs_cuda.cpu())
|
|
|
|
# Check there are no duplicate indices
|
|
val_mask = out_idxs_cuda[0].ne(-1)
|
|
vals, counts = torch.unique(out_idxs_cuda[0][val_mask], return_counts=True)
|
|
self.assertTrue(counts.le(1).all())
|
|
|
|
# Plot all results
|
|
if DEBUG:
|
|
# mplot3d is required for 3d projection plots
|
|
import matplotlib.pyplot as plt
|
|
from mpl_toolkits import mplot3d # noqa: F401
|
|
|
|
# Move to cpu and convert to numpy for plotting
|
|
points = points.squeeze()
|
|
out_points_cpu = out_points_cpu.squeeze().numpy()
|
|
out_points_cuda = out_points_cuda.squeeze().cpu().numpy()
|
|
|
|
# Farthest point sampling CPU
|
|
fig = plt.figure(figsize=plt.figaspect(1.0 / 3))
|
|
ax1 = fig.add_subplot(1, 3, 1, projection="3d")
|
|
ax1.scatter(*points.T, alpha=0.1)
|
|
ax1.scatter(*out_points_cpu.T, color="black")
|
|
ax1.set_title("FPS CPU")
|
|
|
|
# Farthest point sampling CUDA
|
|
ax2 = fig.add_subplot(1, 3, 2, projection="3d")
|
|
ax2.scatter(*points.T, alpha=0.1)
|
|
ax2.scatter(*out_points_cuda.T, color="red")
|
|
ax2.set_title("FPS CUDA")
|
|
|
|
# Random Sampling
|
|
random_points = np.random.permutation(points)[:K]
|
|
ax3 = fig.add_subplot(1, 3, 3, projection="3d")
|
|
ax3.scatter(*points.T, alpha=0.1)
|
|
ax3.scatter(*random_points.T, color="green")
|
|
ax3.set_title("Random")
|
|
|
|
# Save image
|
|
filename = "DEBUG_fps.jpg"
|
|
filepath = DATA_DIR / filename
|
|
plt.savefig(filepath)
|
|
|
|
@staticmethod
|
|
def sample_farthest_points_naive(N: int, P: int, D: int, K: int, device: str):
|
|
device = torch.device(device)
|
|
pts = torch.randn(
|
|
N, P, D, dtype=torch.float32, device=device, requires_grad=True
|
|
)
|
|
grad_pts = torch.randn(N, K, D, dtype=torch.float32, device=device)
|
|
torch.cuda.synchronize()
|
|
|
|
def output():
|
|
out_points, _ = sample_farthest_points_naive(pts, K=K)
|
|
loss = (out_points * grad_pts).sum()
|
|
loss.backward()
|
|
torch.cuda.synchronize()
|
|
|
|
return output
|
|
|
|
@staticmethod
|
|
def sample_farthest_points(N: int, P: int, D: int, K: int, device: str):
|
|
device = torch.device(device)
|
|
pts = torch.randn(
|
|
N, P, D, dtype=torch.float32, device=device, requires_grad=True
|
|
)
|
|
grad_pts = torch.randn(N, K, D, dtype=torch.float32, device=device)
|
|
torch.cuda.synchronize()
|
|
|
|
def output():
|
|
out_points, _ = sample_farthest_points(pts, K=K)
|
|
loss = (out_points * grad_pts).sum()
|
|
loss.backward()
|
|
torch.cuda.synchronize()
|
|
|
|
return output
|