pytorch3d/tests/test_rasterize_points.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest

import numpy as np
import torch
from common_testing import get_random_cuda_device, TestCaseMixin
from pytorch3d import _C
from pytorch3d.renderer.points.rasterize_points import (
_format_radius,
rasterize_points,
rasterize_points_python,
)
from pytorch3d.structures.pointclouds import Pointclouds


class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
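    """Tests for point cloud rasterization.

    The naive and coarse-to-fine C++/CUDA kernels and the pure Python reference
    implementation are compared against each other and against hand-computed
    expected idx / zbuf / dists outputs, for homogeneous and heterogeneous
    point cloud batches.
    """
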
def test_python_simple_cpu(self):
self._simple_test_case(
rasterize_points_python, torch.device("cpu"), bin_size=-1
)

    def test_naive_simple_cpu(self):
device = torch.device("cpu")
self._simple_test_case(rasterize_points, device)

    def test_naive_simple_cuda(self):
device = get_random_cuda_device()
self._simple_test_case(rasterize_points, device, bin_size=0)

    def test_python_behind_camera(self):
self._test_behind_camera(
rasterize_points_python, torch.device("cpu"), bin_size=-1
)

    def test_cpu_behind_camera(self):
self._test_behind_camera(rasterize_points, torch.device("cpu"))

    def test_cuda_behind_camera(self):
device = get_random_cuda_device()
self._test_behind_camera(rasterize_points, device, bin_size=0)

    def test_python_variable_radius(self):
self._test_variable_size_radius(
rasterize_points_python, torch.device("cpu"), bin_size=-1
)

    def test_cpu_variable_radius(self):
self._test_variable_size_radius(rasterize_points, torch.device("cpu"))

    def test_cuda_variable_radius(self):
device = get_random_cuda_device()
# Naive
self._test_variable_size_radius(rasterize_points, device, bin_size=0)
# Coarse to fine
self._test_variable_size_radius(rasterize_points, device, bin_size=None)

    def test_cpp_vs_naive_vs_binned(self):
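        """Check that the CPU naive, CUDA naive and CUDA coarse-to-fine (binned)
        paths produce matching idx / zbuf / dists outputs and matching
        gradients, and that the backward pass runs for all three.
        """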
# Make sure that the backward pass runs for all pathways
N = 2
P = 1000
image_size = 32
radius = 0.1
points_per_pixel = 3
points1 = torch.randn(P, 3, requires_grad=True)
points2 = torch.randn(int(P / 2), 3, requires_grad=True)
pointclouds = Pointclouds(points=[points1, points2])
grad_zbuf = torch.randn(N, image_size, image_size, points_per_pixel)
grad_dists = torch.randn(N, image_size, image_size, points_per_pixel)
# Option I: CPU, naive
idx1, zbuf1, dists1 = rasterize_points(
pointclouds, image_size, radius, points_per_pixel, bin_size=0
)
loss = (zbuf1 * grad_zbuf).sum() + (dists1 * grad_dists).sum()
loss.backward()
grad1 = points1.grad.data.clone()
# Option II: CUDA, naive
points1_cuda = points1.cuda().detach().clone().requires_grad_(True)
points2_cuda = points2.cuda().detach().clone().requires_grad_(True)
pointclouds = Pointclouds(points=[points1_cuda, points2_cuda])
grad_zbuf = grad_zbuf.cuda()
grad_dists = grad_dists.cuda()
idx2, zbuf2, dists2 = rasterize_points(
pointclouds, image_size, radius, points_per_pixel, bin_size=0
)
loss = (zbuf2 * grad_zbuf).sum() + (dists2 * grad_dists).sum()
loss.backward()
idx2 = idx2.data.cpu().clone()
zbuf2 = zbuf2.data.cpu().clone()
dists2 = dists2.data.cpu().clone()
grad2 = points1_cuda.grad.data.cpu().clone()
# Option III: CUDA, binned
points1_cuda = points1.cuda().detach().clone().requires_grad_(True)
points2_cuda = points2.cuda().detach().clone().requires_grad_(True)
pointclouds = Pointclouds(points=[points1_cuda, points2_cuda])
idx3, zbuf3, dists3 = rasterize_points(
pointclouds, image_size, radius, points_per_pixel, bin_size=32
)
loss = (zbuf3 * grad_zbuf).sum() + (dists3 * grad_dists).sum()
points1.grad.data.zero_()
loss.backward()
idx3 = idx3.data.cpu().clone()
zbuf3 = zbuf3.data.cpu().clone()
dists3 = dists3.data.cpu().clone()
grad3 = points1_cuda.grad.data.cpu().clone()
# Make sure everything was the same
idx12_same = (idx1 == idx2).all().item()
idx13_same = (idx1 == idx3).all().item()
zbuf12_same = (zbuf1 == zbuf2).all().item()
zbuf13_same = (zbuf1 == zbuf3).all().item()
dists12_diff = (dists1 - dists2).abs().max().item()
dists13_diff = (dists1 - dists3).abs().max().item()
self.assertTrue(idx12_same)
self.assertTrue(idx13_same)
self.assertTrue(zbuf12_same)
self.assertTrue(zbuf13_same)
self.assertTrue(dists12_diff < 1e-6)
self.assertTrue(dists13_diff < 1e-6)
diff12 = (grad1 - grad2).abs().max().item()
diff13 = (grad1 - grad3).abs().max().item()
diff23 = (grad2 - grad3).abs().max().item()
self.assertTrue(diff12 < 5e-6)
self.assertTrue(diff13 < 5e-6)
self.assertTrue(diff23 < 5e-6)

    def test_python_vs_cpu_naive(self):
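        """Compare the pure Python reference implementation against the CPU
        naive kernel, including gradients, for a homogeneous and a
        heterogeneous batch of small point clouds.
        """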
torch.manual_seed(231)
image_size = 32
radius = 0.1
points_per_pixel = 3
# Test a batch of homogeneous point clouds.
N = 2
P = 17
points = torch.randn(N, P, 3, requires_grad=True)
pointclouds = Pointclouds(points=points)
args = (pointclouds, image_size, radius, points_per_pixel)
self._compare_impls(
rasterize_points_python,
rasterize_points,
args,
args,
points,
points,
compare_grads=True,
)
# Test a batch of heterogeneous point clouds.
P2 = 10
points1 = torch.randn(P, 3, requires_grad=True)
points2 = torch.randn(P2, 3)
pointclouds = Pointclouds(points=[points1, points2])
args = (pointclouds, image_size, radius, points_per_pixel)
self._compare_impls(
rasterize_points_python,
rasterize_points,
args,
args,
points1, # check gradients for first element in batch
points1,
compare_grads=True,
)

    def test_cpu_vs_cuda_naive(self):
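        """Compare the CPU and CUDA naive kernels (bin_size=0) on a homogeneous
        batch of random point clouds, including gradients.
        """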
torch.manual_seed(231)
image_size = 64
radius = 0.1
points_per_pixel = 5
# Test homogeneous point cloud batch.
N = 2
P = 1000
bin_size = 0
points_cpu = torch.rand(N, P, 3, requires_grad=True)
points_cuda = points_cpu.cuda().detach().requires_grad_(True)
pointclouds_cpu = Pointclouds(points=points_cpu)
pointclouds_cuda = Pointclouds(points=points_cuda)
args_cpu = (pointclouds_cpu, image_size, radius, points_per_pixel, bin_size)
args_cuda = (pointclouds_cuda, image_size, radius, points_per_pixel, bin_size)
self._compare_impls(
rasterize_points,
rasterize_points,
args_cpu,
args_cuda,
points_cpu,
points_cuda,
compare_grads=True,
)

    def _compare_impls(
self,
fn1,
fn2,
args1,
args2,
grad_var1=None,
grad_var2=None,
compare_grads=False,
):
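        """Run fn1(*args1) and fn2(*args2) and check that idx and zbuf match
        exactly and that dists match to within a small tolerance. If
        compare_grads is True, the same random upstream gradients are
        backpropagated through zbuf and dists for both implementations and the
        gradients of grad_var1 and grad_var2 are compared.
        """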
idx1, zbuf1, dist1 = fn1(*args1)
torch.manual_seed(231)
grad_zbuf = torch.randn_like(zbuf1)
grad_dist = torch.randn_like(dist1)
loss = (zbuf1 * grad_zbuf).sum() + (dist1 * grad_dist).sum()
if compare_grads:
loss.backward()
grad_points1 = grad_var1.grad.data.clone().cpu()
idx2, zbuf2, dist2 = fn2(*args2)
grad_zbuf = grad_zbuf.to(zbuf2)
grad_dist = grad_dist.to(dist2)
loss = (zbuf2 * grad_zbuf).sum() + (dist2 * grad_dist).sum()
if compare_grads:
# clear points1.grad in case args1 and args2 reused the same tensor
grad_var1.grad.data.zero_()
loss.backward()
grad_points2 = grad_var2.grad.data.clone().cpu()
self.assertEqual((idx1.cpu() == idx2.cpu()).all().item(), 1)
self.assertEqual((zbuf1.cpu() == zbuf2.cpu()).all().item(), 1)
self.assertClose(dist1.cpu(), dist2.cpu())
if compare_grads:
self.assertClose(grad_points1, grad_points2, atol=2e-6)

    def test_bin_size_error(self):
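        """bin_size=16 is too small for a 1024 x 1024 image, so
        rasterize_points should raise a ValueError matching "bin_size too
        small".
        """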
points = Pointclouds(points=torch.rand(5, 100, 3))
image_size = 1024
bin_size = 16
with self.assertRaisesRegex(ValueError, "bin_size too small"):
rasterize_points(points, image_size, 0.0, 2, bin_size=bin_size)

    def _test_behind_camera(self, rasterize_points_fn, device, bin_size=None):
# Test case where all points are behind the camera -- nothing should
# get rasterized
N = 2
P = 32
xy = torch.randn(N, P, 2)
z = torch.randn(N, P, 1).abs().mul(-1) # Make them all negative
points = torch.cat([xy, z], dim=2).to(device)
image_size = 16
points_per_pixel = 3
radius = 0.2
idx_expected = torch.full(
(N, 16, 16, 3), fill_value=-1, dtype=torch.int32, device=device
)
zbuf_expected = torch.full(
(N, 16, 16, 3), fill_value=-1, dtype=torch.float32, device=device
)
dists_expected = zbuf_expected.clone()
pointclouds = Pointclouds(points=points)
if bin_size == -1:
# simple python case with no binning
idx, zbuf, dists = rasterize_points_fn(
pointclouds, image_size, radius, points_per_pixel
)
else:
idx, zbuf, dists = rasterize_points_fn(
pointclouds, image_size, radius, points_per_pixel, bin_size
)
idx_same = (idx == idx_expected).all().item() == 1
zbuf_same = (zbuf == zbuf_expected).all().item() == 1
self.assertTrue(idx_same)
self.assertTrue(zbuf_same)
self.assertClose(dists, dists_expected)

    def _simple_test_case(self, rasterize_points_fn, device, bin_size=0):
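        """Rasterize two small hand-constructed point clouds (containing some
        points behind the camera) onto a 5 x 5 image and compare idx, zbuf and
        dists against hand-computed expected values for both clouds in the
        batch.
        """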
# Create two pointclouds with different numbers of points.
# fmt: off
points1 = torch.tensor(
[
[0.0, 0.0, 0.0], # noqa: E241
[0.4, 0.0, 0.1], # noqa: E241
[0.0, 0.4, 0.2], # noqa: E241
                [0.0, 0.0, -0.1],  # noqa: E241 Points with negative z should be skipped
],
device=device,
)
points2 = torch.tensor(
[
[0.0, 0.0, 0.0], # noqa: E241
[0.4, 0.0, 0.1], # noqa: E241
[0.0, 0.4, 0.2], # noqa: E241
                [0.0, 0.0, -0.1],  # noqa: E241 Points with negative z should be skipped
                [0.0, 0.0, -0.7],  # noqa: E241 Points with negative z should be skipped
],
device=device,
)
# fmt: on
pointclouds = Pointclouds(points=[points1, points2])
image_size = 5
points_per_pixel = 2
radius = 0.5
# The expected output values. Note that in the outputs, the world space
# +Y is up, and the world space +X is left.
idx1_expected = torch.full(
(1, 5, 5, 2), fill_value=-1, dtype=torch.int32, device=device
)
# fmt: off
idx1_expected[0, :, :, 0] = torch.tensor([
[-1, -1, 2, -1, -1], # noqa: E241
[-1, 1, 0, 2, -1], # noqa: E241
[ 1, 0, 0, 0, -1], # noqa: E241 E201
[-1, 1, 0, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1], # noqa: E241
], device=device)
idx1_expected[0, :, :, 1] = torch.tensor([
[-1, -1, -1, -1, -1], # noqa: E241
[-1, 2, 2, -1, -1], # noqa: E241
[-1, 1, 1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1], # noqa: E241
], device=device)
# fmt: on
zbuf1_expected = torch.full(
(1, 5, 5, 2), fill_value=100, dtype=torch.float32, device=device
)
# fmt: off
zbuf1_expected[0, :, :, 0] = torch.tensor([
[-1.0, -1.0, 0.2, -1.0, -1.0], # noqa: E241
[-1.0, 0.1, 0.0, 0.2, -1.0], # noqa: E241
[ 0.1, 0.0, 0.0, 0.0, -1.0], # noqa: E241 E201
[-1.0, 0.1, 0.0, -1.0, -1.0], # noqa: E241
[-1.0, -1.0, -1.0, -1.0, -1.0] # noqa: E241
], device=device)
zbuf1_expected[0, :, :, 1] = torch.tensor([
[-1.0, -1.0, -1.0, -1.0, -1.0], # noqa: E241
[-1.0, 0.2, 0.2, -1.0, -1.0], # noqa: E241
[-1.0, 0.1, 0.1, -1.0, -1.0], # noqa: E241
[-1.0, -1.0, -1.0, -1.0, -1.0], # noqa: E241
[-1.0, -1.0, -1.0, -1.0, -1.0], # noqa: E241
], device=device)
# fmt: on
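        # With image_size = 5 the NDC pixel centers are spaced 2 / 5 = 0.4
        # apart and the visible points above lie exactly on pixel centers, so
        # the nonzero entries below are 0.4 ** 2 = 0.16: the squared NDC
        # distance from a point to a neighbouring pixel center, which is within
        # radius ** 2 = 0.25.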
dists1_expected = torch.zeros((5, 5, 2), dtype=torch.float32, device=device)
# fmt: off
dists1_expected[:, :, 0] = torch.tensor([
[-1.00, -1.00, 0.16, -1.00, -1.00], # noqa: E241
[-1.00, 0.16, 0.16, 0.16, -1.00], # noqa: E241
[ 0.16, 0.16, 0.00, 0.16, -1.00], # noqa: E241 E201
[-1.00, 0.16, 0.16, -1.00, -1.00], # noqa: E241
[-1.00, -1.00, -1.00, -1.00, -1.00], # noqa: E241
], device=device)
dists1_expected[:, :, 1] = torch.tensor([
[-1.00, -1.00, -1.00, -1.00, -1.00], # noqa: E241
[-1.00, 0.16, 0.00, -1.00, -1.00], # noqa: E241
[-1.00, 0.00, 0.16, -1.00, -1.00], # noqa: E241
[-1.00, -1.00, -1.00, -1.00, -1.00], # noqa: E241
[-1.00, -1.00, -1.00, -1.00, -1.00], # noqa: E241
], device=device)
# fmt: on
if bin_size == -1:
# simple python case with no binning
idx, zbuf, dists = rasterize_points_fn(
pointclouds, image_size, radius, points_per_pixel
)
else:
idx, zbuf, dists = rasterize_points_fn(
pointclouds, image_size, radius, points_per_pixel, bin_size
)
# check first point cloud
idx_same = (idx[0, ...] == idx1_expected).all().item() == 1
if idx_same == 0:
print(idx[0, :, :, 0])
print(idx[0, :, :, 1])
zbuf_same = (zbuf[0, ...] == zbuf1_expected).all().item() == 1
self.assertClose(dists[0, ...], dists1_expected)
self.assertTrue(idx_same)
self.assertTrue(zbuf_same)
        # Check the second point cloud - the indices in idx refer to points in
        # the pointclouds.points_packed() tensor. In the second point cloud,
        # two points are behind the camera, so the expected indices are the
        # same as for the first point cloud but offset by the number of points
        # in the first point cloud.
num_points_per_cloud = pointclouds.num_points_per_cloud()
idx1_expected[idx1_expected >= 0] += num_points_per_cloud[0]
idx_same = (idx[1, ...] == idx1_expected).all().item() == 1
zbuf_same = (zbuf[1, ...] == zbuf1_expected).all().item() == 1
self.assertTrue(idx_same)
self.assertTrue(zbuf_same)
self.assertClose(dists[1, ...], dists1_expected)

    def test_coarse_cpu(self):
return self._test_coarse_rasterize(torch.device("cpu"))

    def test_coarse_cuda(self):
device = get_random_cuda_device()
return self._test_coarse_rasterize(device)

    def test_compare_coarse_cpu_vs_cuda(self):
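        """Run the coarse (binning) step on CPU and CUDA for a heterogeneous
        batch of random clouds. The CUDA kernel may write point indices into a
        bin in any order, so the per-bin indices are compared after sorting the
        valid (non -1) entries of the CUDA output.
        """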
torch.manual_seed(231)
N = 3
max_P = 1000
image_size = (64, 64)
radius = 0.1
bin_size = 16
max_points_per_bin = 500
# create heterogeneous point clouds
points = []
for _ in range(N):
p = np.random.choice(max_P)
points.append(torch.randn(p, 3))
pointclouds = Pointclouds(points=points)
points_packed = pointclouds.points_packed()
cloud_to_packed_first_idx = pointclouds.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds.num_points_per_cloud()
radius = torch.full((points_packed.shape[0],), fill_value=radius)
args = (
points_packed,
cloud_to_packed_first_idx,
num_points_per_cloud,
image_size,
radius,
bin_size,
max_points_per_bin,
)
bp_cpu = _C._rasterize_points_coarse(*args)
device = get_random_cuda_device()
pointclouds_cuda = pointclouds.to(device)
points_packed = pointclouds_cuda.points_packed()
cloud_to_packed_first_idx = pointclouds_cuda.cloud_to_packed_first_idx()
num_points_per_cloud = pointclouds_cuda.num_points_per_cloud()
radius = radius.to(device)
args = (
points_packed,
cloud_to_packed_first_idx,
num_points_per_cloud,
image_size,
radius,
bin_size,
max_points_per_bin,
)
bp_cuda = _C._rasterize_points_coarse(*args)
# Bin points might not be the same: CUDA version might write them in
# any order. But if we sort the non-(-1) elements of the CUDA output
# then they should be the same.
for n in range(N):
for by in range(bp_cpu.shape[1]):
for bx in range(bp_cpu.shape[2]):
K = (bp_cpu[n, by, bx] != -1).sum().item()
idxs_cpu = bp_cpu[n, by, bx].tolist()
idxs_cuda = bp_cuda[n, by, bx].tolist()
idxs_cuda[:K] = sorted(idxs_cuda[:K])
self.assertEqual(idxs_cpu, idxs_cuda)

    def _test_coarse_rasterize(self, device):
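        """Run the coarse (binning) step on six hand-placed points with a
        16 x 16 image and bin_size = 8 (i.e. a 2 x 2 grid of bins) and compare
        against the expected per-bin point indices for the layout sketched
        below.
        """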
#
#
# |2 (4)
# |
# |
# |
# |1
# |
# | (1)
# (2)|
# _________(5)___(0)_______________
# -1 | 1 2
# |
# | (3)
# |
# |-1
#
        # Locations of the points are shown by their index in parentheses.
        # The screen bounding box is between [-1, 1] in both the x and y
        # directions.
#
# These points are interesting because:
# (0) Falls into two bins;
# (1) and (2) fall into one bin;
# (3) is out-of-bounds, but its disk is in-bounds;
        # (4) is out-of-bounds, and its entire disk is also out-of-bounds;
        # (5) has a negative z-value, so it should be skipped.
# fmt: off
points = torch.tensor(
[
[ 0.5, 0.0, 0.0], # noqa: E241, E201
[ 0.5, 0.5, 0.1], # noqa: E241, E201
[-0.3, 0.4, 0.0], # noqa: E241
[ 1.1, -0.5, 0.2], # noqa: E241, E201
[ 2.0, 2.0, 0.3], # noqa: E241, E201
[ 0.0, 0.0, -0.1], # noqa: E241, E201
],
device=device
)
# fmt: on
image_size = (16, 16)
radius = 0.2
bin_size = 8
max_points_per_bin = 5
bin_points_expected = -1 * torch.ones(
1, 2, 2, 5, dtype=torch.int32, device=device
)
# Note that the order is only deterministic here for CUDA if all points
        # fit in one chunk. This will be the case for this small example, but
# to properly exercise coordinated writes among multiple chunks we need
# to use a bigger test case.
bin_points_expected[0, 0, 1, :2] = torch.tensor([0, 3])
bin_points_expected[0, 1, 0, 0] = torch.tensor([2])
bin_points_expected[0, 1, 1, :2] = torch.tensor([0, 1])
pointclouds = Pointclouds(points=[points])
radius = torch.full((points.shape[0],), fill_value=radius, device=device)
args = (
pointclouds.points_packed(),
pointclouds.cloud_to_packed_first_idx(),
pointclouds.num_points_per_cloud(),
image_size,
radius,
bin_size,
max_points_per_bin,
)
bin_points = _C._rasterize_points_coarse(*args)
bin_points_same = (bin_points == bin_points_expected).all()
self.assertTrue(bin_points_same.item() == 1)

    def _test_variable_size_radius(self, rasterize_points_fn, device, bin_size=0):
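        """Rasterize three points with per-point radii onto a 16 x 16 image.
        The second point has radius 0 and negative z, so only points 0 and 2
        should appear in the output, and each covered pixel's squared distance
        must be smaller than the corresponding point's squared radius.
        """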
        # Three points; the radius of each is given separately below.
points = torch.tensor(
[[0.5, 0.5, 0.3], [0.5, -0.5, -0.1], [0.0, 0.0, 0.3]],
dtype=torch.float32,
device=device,
)
image_size = 16
points_per_pixel = 1
radius = torch.tensor([0.1, 0.0, 0.2], dtype=torch.float32, device=device)
pointclouds = Pointclouds(points=[points])
if bin_size == -1:
# simple python case with no binning
idx, zbuf, dists = rasterize_points_fn(
pointclouds, image_size, radius, points_per_pixel
)
else:
idx, zbuf, dists = rasterize_points_fn(
pointclouds, image_size, radius, points_per_pixel, bin_size
)
idx_expected = torch.zeros(
(1, image_size, image_size, 1), dtype=torch.int64, device=device
)
# fmt: off
idx_expected[0, ..., 0] = torch.tensor(
[
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1, -1, -1, 2, 2, -1, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1, -1, -1, 2, 2, -1, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241
],
dtype=torch.int64,
device=device
)
# fmt: on
zbuf_expected = torch.full(
idx_expected.shape, fill_value=-1, dtype=torch.float32, device=device
)
zbuf_expected[idx_expected == 0] = 0.3
zbuf_expected[idx_expected == 2] = 0.3
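        # With image_size = 16 the NDC pixel centers are spaced 2 / 16 = 0.125
        # apart, so the points at 0.5 and 0.0 sit half a pixel (0.0625) away
        # from the nearest centers. The expected squared distances below are
        # 2 * 0.0625 ** 2 ~= 0.0078 and 0.1875 ** 2 + 0.0625 ** 2 ~= 0.0391.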
dists_expected = torch.full(
idx_expected.shape, fill_value=-1, dtype=torch.float32, device=device
)
# fmt: off
dists_expected[0, ..., 0] = torch.Tensor(
[
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., 0.0078, 0.0078, -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., 0.0078, 0.0078, -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., -1., -1., -1., -1., 0.0391, 0.0391, -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., -1., -1., -1., 0.0391, 0.0078, 0.0078, 0.0391, -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., -1., -1., -1., 0.0391, 0.0078, 0.0078, 0.0391, -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., -1., -1., -1., -1., 0.0391, 0.0391, -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, B950
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.] # noqa: E241, B950
]
)
# fmt: on
# Check the distances for a point are less than the squared radius
# for that point.
self.assertTrue((dists[idx == 0] < radius[0] ** 2).all())
self.assertTrue((dists[idx == 2] < radius[2] ** 2).all())
# Check all values are correct.
idx_same = (idx == idx_expected).all().item() == 1
zbuf_same = (zbuf == zbuf_expected).all().item() == 1
self.assertTrue(idx_same)
self.assertTrue(zbuf_same)
self.assertClose(dists, dists_expected, atol=4e-5)

    def test_radius_format_failure(self):
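        """_format_radius should raise a ValueError for a radius of the wrong
        shape (e.g. a list whose length does not match the point clouds) or of
        an unsupported type (not a float, list, tuple or tensor).
        """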
N = 20
P_max = 15
points_list = []
for _ in range(N):
p = torch.randint(low=1, high=P_max, size=(1,))[0]
points_list.append(torch.randn((p, 3)))
points = Pointclouds(points=points_list)
# Incorrect shape
with self.assertRaisesRegex(ValueError, "radius must be of shape"):
_format_radius([0, 1, 2], points)
# Incorrect type
with self.assertRaisesRegex(ValueError, "float, list, tuple or tensor"):
_format_radius({0: [0, 1, 2]}, points)