mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-02 03:42:50 +08:00
Summary: Applies new import merging and sorting from µsort v1.0. When merging imports, µsort will make a best-effort attempt to move associated comments to match the merged elements, but there are known limitations due to the dynamic nature of Python and developer tooling. These changes should not produce any dangerous runtime changes, but may require touch-ups to satisfy linters and other tooling. Note that µsort uses case-insensitive, lexicographical sorting, which results in a different ordering compared to isort. This provides a more consistent sorting order, matching the case-insensitive order used when sorting import statements by module name, and ensures that "frog", "FROG", and "Frog" always sort next to each other. For details on µsort's sorting and merging semantics, see the user guide: https://usort.readthedocs.io/en/stable/guide.html#sorting

Reviewed By: bottler
Differential Revision: D35553814
fbshipit-source-id: be49bdb6a4c25264ff8d4db3a601f18736d17be1
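As an illustration of that ordering rule (a minimal sketch of the semantics, not µsort's actual implementation), a case-insensitive lexicographical sort is just a sort under a lowercased key:

    names = ["frog", "FROG", "Frog", "ant", "Zebra"]
    print(sorted(names, key=str.lower))
    # -> ['ant', 'frog', 'FROG', 'Frog', 'Zebra']: the three "frog" variants
    # stay next to each other, whereas the case-sensitive sorted(names) would
    # give ['FROG', 'Frog', 'Zebra', 'ant', 'frog'].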
276 lines · 11 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import numpy as np
import torch
from common_testing import (
    get_pytorch3d_dir,
    get_random_cuda_device,
    get_tests_dir,
    TestCaseMixin,
)
from pytorch3d.io import load_obj
from pytorch3d.ops.sample_farthest_points import (
    sample_farthest_points,
    sample_farthest_points_naive,
)
from pytorch3d.ops.utils import masked_gather


DATA_DIR = get_tests_dir() / "data"
TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
DEBUG = False


class TestFPS(TestCaseMixin, unittest.TestCase):
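    """
    Tests for farthest point sampling (FPS): simple fixed inputs, agreement
    between sample_farthest_points and sample_farthest_points_naive on random
    heterogeneous batches, input validation errors, random start points,
    gradcheck, and CPU vs CUDA consistency.
    """
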
    def _test_simple(self, fps_func, device="cpu"):
        # fmt: off
        points = torch.tensor(
            [
                [
                    [-1.0, -1.0],  # noqa: E241, E201
                    [-1.3,  1.1],  # noqa: E241, E201
                    [ 0.2, -1.1],  # noqa: E241, E201
                    [ 0.0,  0.0],  # noqa: E241, E201
                    [ 1.3,  1.3],  # noqa: E241, E201
                    [ 1.0,  0.5],  # noqa: E241, E201
                    [-1.3,  0.2],  # noqa: E241, E201
                    [ 1.5, -0.5],  # noqa: E241, E201
                ],
                [
                    [-2.2, -2.4],  # noqa: E241, E201
                    [-2.1,  2.0],  # noqa: E241, E201
                    [ 2.2,  2.1],  # noqa: E241, E201
                    [ 2.1, -2.4],  # noqa: E241, E201
                    [ 0.4, -1.0],  # noqa: E241, E201
                    [ 0.3,  0.3],  # noqa: E241, E201
                    [ 1.2,  0.5],  # noqa: E241, E201
                    [ 4.5,  4.5],  # noqa: E241, E201
                ],
            ],
            dtype=torch.float32,
            device=device,
        )
        # fmt: on
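        # With the default start point, FPS selects index 0 first and then
        # greedily adds the point farthest from those already chosen,
        # giving indices [0, 4] and [0, 7] for the two clouds above.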
        expected_inds = torch.tensor([[0, 4], [0, 7]], dtype=torch.int64, device=device)
        out_points, out_inds = fps_func(points, K=2)
        self.assertClose(out_inds, expected_inds)

        # Gather the points
        expected_inds = expected_inds[..., None].expand(-1, -1, points.shape[-1])
        self.assertClose(out_points, points.gather(dim=1, index=expected_inds))

        # Different number of points sampled for each pointcloud in the batch
        expected_inds = torch.tensor(
            [[0, 4, 1], [0, 7, -1]], dtype=torch.int64, device=device
        )
        out_points, out_inds = fps_func(points, K=[3, 2])
        self.assertClose(out_inds, expected_inds)

        # Gather the points
        expected_points = masked_gather(points, expected_inds)
        self.assertClose(out_points, expected_points)

    def _test_compare_random_heterogeneous(self, device="cpu"):
        N, P, D, K = 5, 20, 5, 8
        points = torch.randn((N, P, D), device=device, dtype=torch.float32)
        out_points_naive, out_idxs_naive = sample_farthest_points_naive(points, K=K)
        out_points, out_idxs = sample_farthest_points(points, K=K)
        self.assertTrue(out_idxs.min() >= 0)
        self.assertClose(out_idxs, out_idxs_naive)
        self.assertClose(out_points, out_points_naive)
        for n in range(N):
            self.assertEqual(out_idxs[n].ne(-1).sum(), K)

        # Test case where K > P
        K = 30
        points1 = torch.randn((N, P, D), dtype=torch.float32, device=device)
        points2 = points1.clone()
        points1.requires_grad = True
        points2.requires_grad = True
        lengths = torch.randint(low=1, high=P, size=(N,), device=device)
        out_points_naive, out_idxs_naive = sample_farthest_points_naive(
            points1, lengths, K=K
        )
        out_points, out_idxs = sample_farthest_points(points2, lengths, K=K)
        self.assertClose(out_idxs, out_idxs_naive)
        self.assertClose(out_points, out_points_naive)

        for n in range(N):
            # Check that for heterogeneous batches, the number of
            # selected points is at most the length
            self.assertTrue(out_idxs[n].ne(-1).sum() <= lengths[n])
            self.assertTrue(out_idxs[n].max() <= lengths[n])

            # Check there are no duplicate indices
            val_mask = out_idxs[n].ne(-1)
            vals, counts = torch.unique(out_idxs[n][val_mask], return_counts=True)
            self.assertTrue(counts.le(1).all())

        # Check gradients
        grad_sampled_points = torch.ones((N, K, D), dtype=torch.float32, device=device)
        loss1 = (out_points_naive * grad_sampled_points).sum()
        loss1.backward()
        loss2 = (out_points * grad_sampled_points).sum()
        loss2.backward()
        self.assertClose(points1.grad, points2.grad, atol=5e-6)

    def _test_errors(self, fps_func, device="cpu"):
        N, P, D, K = 5, 40, 5, 8
        points = torch.randn((N, P, D), device=device)
        wrong_batch_dim = torch.randint(low=1, high=K, size=(K,), device=device)

        # K has a different batch dimension to points
        with self.assertRaisesRegex(ValueError, "K and points must have"):
            fps_func(points, K=wrong_batch_dim)

        # lengths has a different batch dimension to points
        with self.assertRaisesRegex(ValueError, "points and lengths must have"):
            fps_func(points, lengths=wrong_batch_dim, K=K)

    def _test_random_start(self, fps_func, device="cpu"):
        N, P, D, K = 5, 40, 5, 8
        points = torch.randn((N, P, D), dtype=torch.float32, device=device)
        out_points, out_idxs = fps_func(points, K=K, random_start_point=True)
        # Check the first index is not 0 or the same number for all batch elements
        # when random_start_point = True
        self.assertTrue(out_idxs[:, 0].sum() > 0)
        self.assertFalse(out_idxs[:, 0].eq(out_idxs[0, 0]).all())

    def _test_gradcheck(self, fps_func, device="cpu"):
        N, P, D, K = 2, 10, 3, 2
        points = torch.randn(
            (N, P, D), dtype=torch.float32, device=device, requires_grad=True
        )
        lengths = torch.randint(low=1, high=P, size=(N,), device=device)
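        # gradcheck compares analytical gradients against finite-difference
        # estimates; eps and atol are loose because the inputs are float32.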
        torch.autograd.gradcheck(
            fps_func,
            (points, lengths, K),
            check_undefined_grad=False,
            eps=2e-3,
            atol=0.001,
        )

    def test_sample_farthest_points_naive(self):
        device = get_random_cuda_device()
        self._test_simple(sample_farthest_points_naive, device)
        self._test_errors(sample_farthest_points_naive, device)
        self._test_random_start(sample_farthest_points_naive, device)
        self._test_gradcheck(sample_farthest_points_naive, device)

    def test_sample_farthest_points_cpu(self):
        self._test_simple(sample_farthest_points, "cpu")
        self._test_errors(sample_farthest_points, "cpu")
        self._test_compare_random_heterogeneous("cpu")
        self._test_random_start(sample_farthest_points, "cpu")
        self._test_gradcheck(sample_farthest_points, "cpu")

    def test_sample_farthest_points_cuda(self):
        device = get_random_cuda_device()
        self._test_simple(sample_farthest_points, device)
        self._test_errors(sample_farthest_points, device)
        self._test_compare_random_heterogeneous(device)
        self._test_random_start(sample_farthest_points, device)
        self._test_gradcheck(sample_farthest_points, device)

    def test_cuda_vs_cpu(self):
        """
        Compare CUDA vs CPU on a complex object
        """
        obj_filename = TUTORIAL_DATA_DIR / "cow_mesh/cow.obj"
        K = 250

        # Run on CPU
        device = "cpu"
        points, _, _ = load_obj(obj_filename, device=device, load_textures=False)
        points = points[None, ...]
        out_points_cpu, out_idxs_cpu = sample_farthest_points(points, K=K)

        # Run on GPU
        device = get_random_cuda_device()
        points_cuda = points.to(device)
        out_points_cuda, out_idxs_cuda = sample_farthest_points(points_cuda, K=K)

        # Check that the indices from CUDA and CPU match
        self.assertClose(out_idxs_cpu, out_idxs_cuda.cpu())

        # Check there are no duplicate indices
        val_mask = out_idxs_cuda[0].ne(-1)
        vals, counts = torch.unique(out_idxs_cuda[0][val_mask], return_counts=True)
        self.assertTrue(counts.le(1).all())

        # Plot all results
        if DEBUG:
            # mplot3d is required for 3d projection plots
            import matplotlib.pyplot as plt
            from mpl_toolkits import mplot3d  # noqa: F401

            # Move to cpu and convert to numpy for plotting
            points = points.squeeze()
            out_points_cpu = out_points_cpu.squeeze().numpy()
            out_points_cuda = out_points_cuda.squeeze().cpu().numpy()

            # Farthest point sampling CPU
            fig = plt.figure(figsize=plt.figaspect(1.0 / 3))
            ax1 = fig.add_subplot(1, 3, 1, projection="3d")
            ax1.scatter(*points.T, alpha=0.1)
            ax1.scatter(*out_points_cpu.T, color="black")
            ax1.set_title("FPS CPU")

            # Farthest point sampling CUDA
            ax2 = fig.add_subplot(1, 3, 2, projection="3d")
            ax2.scatter(*points.T, alpha=0.1)
            ax2.scatter(*out_points_cuda.T, color="red")
            ax2.set_title("FPS CUDA")

            # Random Sampling
            random_points = np.random.permutation(points)[:K]
            ax3 = fig.add_subplot(1, 3, 3, projection="3d")
            ax3.scatter(*points.T, alpha=0.1)
            ax3.scatter(*random_points.T, color="green")
            ax3.set_title("Random")

            # Save image
            filename = "DEBUG_fps.jpg"
            filepath = DATA_DIR / filename
            plt.savefig(filepath)

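    # The two staticmethods below are benchmark helpers rather than tests:
    # each builds random inputs and returns a closure that runs one forward
    # and backward pass, so an external timer can invoke it repeatedly.
    # Inside the closures, the names resolve to the module-level ops imported
    # above, not to these identically named staticmethods, because class
    # scope is not visible from nested functions.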
    @staticmethod
    def sample_farthest_points_naive(N: int, P: int, D: int, K: int, device: str):
        device = torch.device(device)
        pts = torch.randn(
            N, P, D, dtype=torch.float32, device=device, requires_grad=True
        )
        grad_pts = torch.randn(N, K, D, dtype=torch.float32, device=device)
        torch.cuda.synchronize()

        def output():
            out_points, _ = sample_farthest_points_naive(pts, K=K)
            loss = (out_points * grad_pts).sum()
            loss.backward()
            torch.cuda.synchronize()

        return output

    @staticmethod
    def sample_farthest_points(N: int, P: int, D: int, K: int, device: str):
        device = torch.device(device)
        pts = torch.randn(
            N, P, D, dtype=torch.float32, device=device, requires_grad=True
        )
        grad_pts = torch.randn(N, K, D, dtype=torch.float32, device=device)
        torch.cuda.synchronize()

        def output():
            out_points, _ = sample_farthest_points(pts, K=K)
            loss = (out_points * grad_pts).sum()
            loss.backward()
            torch.cuda.synchronize()

        return output