pytorch3d/tests/test_sample_points_from_meshes.py
Georgia Gkioxari 29cd181a83 CPU implem for face areas normals
Summary:
Added cpu implementation for face areas normals. Moved test and bm to separate functions.

```
Benchmark                                   Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
FACE_AREAS_NORMALS_2_100_300_False                196             268           2550
FACE_AREAS_NORMALS_2_100_300_True                 106             179           4733
FACE_AREAS_NORMALS_2_100_3000_False              1447            1630            346
FACE_AREAS_NORMALS_2_100_3000_True                107             178           4674
FACE_AREAS_NORMALS_2_1000_300_False               201             309           2486
FACE_AREAS_NORMALS_2_1000_300_True                107             186           4673
FACE_AREAS_NORMALS_2_1000_3000_False             1451            1636            345
FACE_AREAS_NORMALS_2_1000_3000_True               107             186           4655
FACE_AREAS_NORMALS_10_100_300_False               767             918            653
FACE_AREAS_NORMALS_10_100_300_True                106             167           4712
FACE_AREAS_NORMALS_10_100_3000_False             7036            7754             72
FACE_AREAS_NORMALS_10_100_3000_True               113             164           4445
FACE_AREAS_NORMALS_10_1000_300_False              748             947            669
FACE_AREAS_NORMALS_10_1000_300_True               108             169           4638
FACE_AREAS_NORMALS_10_1000_3000_False            7069            7783             71
FACE_AREAS_NORMALS_10_1000_3000_True              108             172           4646
FACE_AREAS_NORMALS_32_100_300_False              2286            2496            219
FACE_AREAS_NORMALS_32_100_300_True                108             180           4631
FACE_AREAS_NORMALS_32_100_3000_False            23184           24369             22
FACE_AREAS_NORMALS_32_100_3000_True               159             213           3147
FACE_AREAS_NORMALS_32_1000_300_False             2414            2645            208
FACE_AREAS_NORMALS_32_1000_300_True               112             197           4480
FACE_AREAS_NORMALS_32_1000_3000_False           21687           22964             24
FACE_AREAS_NORMALS_32_1000_3000_True              141             211           3540
--------------------------------------------------------------------------------

Benchmark                                         Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
FACE_AREAS_NORMALS_TORCH_2_100_300_False               5465            5782             92
FACE_AREAS_NORMALS_TORCH_2_100_300_True                1198            1351            418
FACE_AREAS_NORMALS_TORCH_2_100_3000_False             48228           48869             11
FACE_AREAS_NORMALS_TORCH_2_100_3000_True               1186            1304            422
FACE_AREAS_NORMALS_TORCH_2_1000_300_False              5556            6097             90
FACE_AREAS_NORMALS_TORCH_2_1000_300_True               1200            1328            417
FACE_AREAS_NORMALS_TORCH_2_1000_3000_False            48683           50016             11
FACE_AREAS_NORMALS_TORCH_2_1000_3000_True              1185            1306            422
FACE_AREAS_NORMALS_TORCH_10_100_300_False             24215           25097             21
FACE_AREAS_NORMALS_TORCH_10_100_300_True               1150            1314            435
FACE_AREAS_NORMALS_TORCH_10_100_3000_False           232605          234952              3
FACE_AREAS_NORMALS_TORCH_10_100_3000_True              1193            1314            420
FACE_AREAS_NORMALS_TORCH_10_1000_300_False            24912           25343             21
FACE_AREAS_NORMALS_TORCH_10_1000_300_True              1216            1330            412
FACE_AREAS_NORMALS_TORCH_10_1000_3000_False          239907          241253              3
FACE_AREAS_NORMALS_TORCH_10_1000_3000_True             1226            1333            408
FACE_AREAS_NORMALS_TORCH_32_100_300_False             73991           75776              7
FACE_AREAS_NORMALS_TORCH_32_100_300_True               1193            1339            420
FACE_AREAS_NORMALS_TORCH_32_100_3000_False           728932          728932              1
FACE_AREAS_NORMALS_TORCH_32_100_3000_True              1186            1359            422
FACE_AREAS_NORMALS_TORCH_32_1000_300_False            76385           79129              7
FACE_AREAS_NORMALS_TORCH_32_1000_300_True              1165            1310            430
FACE_AREAS_NORMALS_TORCH_32_1000_3000_False          753276          753276              1
FACE_AREAS_NORMALS_TORCH_32_1000_3000_True             1205            1340            415
--------------------------------------------------------------------------------
```

Reviewed By: bottler, jcjohnson

Differential Revision: D19864385

fbshipit-source-id: 3a87ae41a8e3ab5560febcb94961798f2e09dfb8
2020-02-13 11:42:48 -08:00

398 lines
13 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from pathlib import Path
import torch
from pytorch3d import _C
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere
class TestSamplePoints(unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(1)
@staticmethod
def init_meshes(
num_meshes: int = 10,
num_verts: int = 1000,
num_faces: int = 3000,
device: str = "cpu",
):
device = torch.device(device)
verts_list = []
faces_list = []
for _ in range(num_meshes):
verts = torch.rand(
(num_verts, 3), dtype=torch.float32, device=device
)
faces = torch.randint(
num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
)
verts_list.append(verts)
faces_list.append(faces)
meshes = Meshes(verts_list, faces_list)
return meshes
def test_all_empty_meshes(self):
"""
Check sample_points_from_meshes raises an exception if all meshes are
invalid.
"""
device = torch.device("cuda:0")
verts1 = torch.tensor([], dtype=torch.float32, device=device)
faces1 = torch.tensor([], dtype=torch.int64, device=device)
meshes = Meshes(
verts=[verts1, verts1, verts1], faces=[faces1, faces1, faces1]
)
with self.assertRaises(ValueError) as err:
sample_points_from_meshes(
meshes, num_samples=100, return_normals=True
)
self.assertTrue("Meshes are empty." in str(err.exception))
def test_sampling_output(self):
"""
Check outputs of sampling are correct for different meshes.
For an ico_sphere, the sampled vertices should lie on a unit sphere.
For an empty mesh, the samples and normals should be 0.
"""
device = torch.device("cuda:0")
# Unit simplex.
verts_pyramid = torch.tensor(
[
[0.0, 0.0, 0.0],
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0],
],
dtype=torch.float32,
device=device,
)
faces_pyramid = torch.tensor(
[[0, 1, 2], [0, 2, 3], [0, 1, 3], [1, 2, 3]],
dtype=torch.int64,
device=device,
)
sphere_mesh = ico_sphere(9, device)
verts_sphere, faces_sphere = sphere_mesh.get_mesh_verts_faces(0)
verts_empty = torch.tensor([], dtype=torch.float32, device=device)
faces_empty = torch.tensor([], dtype=torch.int64, device=device)
num_samples = 10
meshes = Meshes(
verts=[verts_empty, verts_sphere, verts_pyramid],
faces=[faces_empty, faces_sphere, faces_pyramid],
)
samples, normals = sample_points_from_meshes(
meshes, num_samples=num_samples, return_normals=True
)
samples = samples.cpu()
normals = normals.cpu()
self.assertEqual(samples.shape, (3, num_samples, 3))
self.assertEqual(normals.shape, (3, num_samples, 3))
# Empty meshes: should have all zeros for samples and normals.
self.assertTrue(
torch.allclose(samples[0, :], torch.zeros((1, num_samples, 3)))
)
self.assertTrue(
torch.allclose(normals[0, :], torch.zeros((1, num_samples, 3)))
)
# Sphere: points should have radius 1.
x, y, z = samples[1, :].unbind(1)
radius = torch.sqrt(x ** 2 + y ** 2 + z ** 2)
self.assertTrue(torch.allclose(radius, torch.ones((num_samples))))
# Pyramid: points shoudl lie on one of the faces.
pyramid_verts = samples[2, :]
pyramid_normals = normals[2, :]
self.assertTrue(
torch.allclose(
pyramid_verts.lt(1).float(), torch.ones_like(pyramid_verts)
)
)
self.assertTrue(
torch.allclose(
(pyramid_verts >= 0).float(), torch.ones_like(pyramid_verts)
)
)
# Face 1: z = 0, x + y <= 1, normals = (0, 0, 1).
face_1_idxs = pyramid_verts[:, 2] == 0
face_1_verts, face_1_normals = (
pyramid_verts[face_1_idxs, :],
pyramid_normals[face_1_idxs, :],
)
self.assertTrue(
torch.all((face_1_verts[:, 0] + face_1_verts[:, 1]) <= 1)
)
self.assertTrue(
torch.allclose(
face_1_normals,
torch.tensor([0, 0, 1], dtype=torch.float32).expand(
face_1_normals.size()
),
)
)
# Face 2: x = 0, z + y <= 1, normals = (1, 0, 0).
face_2_idxs = pyramid_verts[:, 0] == 0
face_2_verts, face_2_normals = (
pyramid_verts[face_2_idxs, :],
pyramid_normals[face_2_idxs, :],
)
self.assertTrue(
torch.all((face_2_verts[:, 1] + face_2_verts[:, 2]) <= 1)
)
self.assertTrue(
torch.allclose(
face_2_normals,
torch.tensor([1, 0, 0], dtype=torch.float32).expand(
face_2_normals.size()
),
)
)
# Face 3: y = 0, x + z <= 1, normals = (0, -1, 0).
face_3_idxs = pyramid_verts[:, 1] == 0
face_3_verts, face_3_normals = (
pyramid_verts[face_3_idxs, :],
pyramid_normals[face_3_idxs, :],
)
self.assertTrue(
torch.all((face_3_verts[:, 0] + face_3_verts[:, 2]) <= 1)
)
self.assertTrue(
torch.allclose(
face_3_normals,
torch.tensor([0, -1, 0], dtype=torch.float32).expand(
face_3_normals.size()
),
)
)
# Face 4: x + y + z = 1, normals = (1, 1, 1)/sqrt(3).
face_4_idxs = pyramid_verts.gt(0).all(1)
face_4_verts, face_4_normals = (
pyramid_verts[face_4_idxs, :],
pyramid_normals[face_4_idxs, :],
)
self.assertTrue(
torch.allclose(
face_4_verts.sum(1), torch.ones(face_4_verts.size(0))
)
)
self.assertTrue(
torch.allclose(
face_4_normals,
(
torch.tensor([1, 1, 1], dtype=torch.float32)
/ torch.sqrt(torch.tensor(3, dtype=torch.float32))
).expand(face_4_normals.size()),
)
)
def test_mutinomial(self):
"""
Confirm that torch.multinomial does not sample elements which have
zero probability.
"""
freqs = torch.cuda.FloatTensor(
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.03178183361887932,
0.027680952101945877,
0.033176131546497345,
0.046052902936935425,
0.07742464542388916,
0.11543981730937958,
0.14148041605949402,
0.15784293413162231,
0.13180233538150787,
0.08271478116512299,
0.049702685326337814,
0.027557924389839172,
0.018125897273421288,
0.011851548217236996,
0.010252203792333603,
0.007422595750540495,
0.005372154992073774,
0.0045109698548913,
0.0036087757907807827,
0.0035267581697553396,
0.0018864056328311563,
0.0024605290964245796,
0.0022964938543736935,
0.0018453967059031129,
0.0010662291897460818,
0.0009842115687206388,
0.00045109697384759784,
0.0007791675161570311,
0.00020504408166743815,
0.00020504408166743815,
0.00020504408166743815,
0.00012302644609007984,
0.0,
0.00012302644609007984,
4.100881778867915e-05,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
]
)
sample = []
for _ in range(1000):
torch.cuda.get_rng_state()
sample = torch.multinomial(freqs, 1000, True)
if freqs[sample].min() == 0:
sample_idx = (freqs[sample] == 0).nonzero()[0][0]
sampled = sample[sample_idx]
print(
"%s th element of last sample was %s, which has probability %s"
% (sample_idx, sampled, freqs[sampled])
)
return False
return True
def test_multinomial_weights(self):
"""
Confirm that torch.multinomial does not sample elements which have
zero probability using a real example of input from a training run.
"""
weights = torch.load(Path(__file__).resolve().parent / "weights.pt")
S = 4096
num_trials = 100
for _ in range(0, num_trials):
weights[weights < 0] = 0.0
samples = weights.multinomial(S, replacement=True)
sampled_weights = weights[samples]
assert sampled_weights.min() > 0
if sampled_weights.min() <= 0:
return False
return True
@staticmethod
def packed_to_padded_tensor(inputs, first_idxs, max_size):
"""
PyTorch implementation of cuda packed_to_padded_tensor function.
"""
num_meshes = first_idxs.size(0)
inputs_padded = torch.zeros((num_meshes, max_size))
for m in range(num_meshes):
s = first_idxs[m]
if m == num_meshes - 1:
f = inputs.size(0)
else:
f = first_idxs[m + 1]
inputs_padded[m, :f] = inputs[s:f]
return inputs_padded
def test_packed_to_padded_tensor(self):
"""
Check the results from packed_to_padded cuda and PyTorch implementions
are the same.
"""
meshes = self.init_meshes(1, 3, 5, device="cuda:0")
verts = meshes.verts_packed()
faces = meshes.faces_packed()
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
max_faces = meshes.num_faces_per_mesh().max().item()
areas, _ = _C.face_areas_normals(verts, faces)
areas_padded = _C.packed_to_padded_tensor(
areas, mesh_to_faces_packed_first_idx, max_faces
).cpu()
areas_padded_cpu = TestSamplePoints.packed_to_padded_tensor(
areas, mesh_to_faces_packed_first_idx, max_faces
)
self.assertTrue(torch.allclose(areas_padded, areas_padded_cpu))
with self.assertRaises(Exception) as err:
_C.packed_to_padded_tensor(
areas.cpu(), mesh_to_faces_packed_first_idx, max_faces
)
self.assertTrue("Not implemented on the CPU" in str(err.exception))
@staticmethod
def sample_points_with_init(
num_meshes: int,
num_verts: int,
num_faces: int,
num_samples: int,
device: str = "cpu",
):
device = torch.device(device)
verts_list = []
faces_list = []
for _ in range(num_meshes):
verts = torch.rand(
(num_verts, 3), dtype=torch.float32, device=device
)
faces = torch.randint(
num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
)
verts_list.append(verts)
faces_list.append(faces)
meshes = Meshes(verts_list, faces_list)
torch.cuda.synchronize()
def sample_points():
sample_points_from_meshes(
meshes, num_samples=num_samples, return_normals=True
)
torch.cuda.synchronize()
return sample_points
@staticmethod
def packed_to_padded_with_init(
num_meshes: int, num_verts: int, num_faces: int, cuda: str = True
):
device = "cuda" if cuda else "cpu"
meshes = TestSamplePoints.init_meshes(
num_meshes, num_verts, num_faces, device
)
verts = meshes.verts_packed()
faces = meshes.faces_packed()
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
max_faces = meshes.num_faces_per_mesh().max().item()
areas, _ = _C.face_areas_normals(verts, faces)
torch.cuda.synchronize()
def packed_to_padded():
if cuda:
_C.packed_to_padded_tensor(
areas, mesh_to_faces_packed_first_idx, max_faces
)
else:
TestSamplePoints.packed_to_padded_tensor(
areas, mesh_to_faces_packed_first_idx, max_faces
)
torch.cuda.synchronize()
return packed_to_padded