mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 06:40:35 +08:00
use assertClose
Summary: use assertClose in some tests, which enforces shape equality. Fixes some small problems, including graph_conv on an empty graph. Reviewed By: nikhilaravi Differential Revision: D20556912 fbshipit-source-id: 60a61eafe3c03ce0f6c9c1a842685708fb10ac5b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
744ef0c2c8
commit
595aca27ea
@@ -9,8 +9,10 @@ from pytorch3d.ops import sample_points_from_meshes
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
from pytorch3d.utils.ico_sphere import ico_sphere
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
class TestSamplePoints(unittest.TestCase):
|
||||
|
||||
class TestSamplePoints(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(1)
|
||||
@@ -98,32 +100,24 @@ class TestSamplePoints(unittest.TestCase):
|
||||
self.assertEqual(normals.shape, (3, num_samples, 3))
|
||||
|
||||
# Empty meshes: should have all zeros for samples and normals.
|
||||
self.assertTrue(
|
||||
torch.allclose(samples[0, :], torch.zeros((1, num_samples, 3)))
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(normals[0, :], torch.zeros((1, num_samples, 3)))
|
||||
)
|
||||
self.assertClose(samples[0, :], torch.zeros((num_samples, 3)))
|
||||
self.assertClose(normals[0, :], torch.zeros((num_samples, 3)))
|
||||
|
||||
# Sphere: points should have radius 1.
|
||||
x, y, z = samples[1, :].unbind(1)
|
||||
radius = torch.sqrt(x ** 2 + y ** 2 + z ** 2)
|
||||
|
||||
self.assertTrue(torch.allclose(radius, torch.ones((num_samples))))
|
||||
self.assertClose(radius, torch.ones((num_samples)))
|
||||
|
||||
# Pyramid: points shoudl lie on one of the faces.
|
||||
pyramid_verts = samples[2, :]
|
||||
pyramid_normals = normals[2, :]
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
pyramid_verts.lt(1).float(), torch.ones_like(pyramid_verts)
|
||||
)
|
||||
self.assertClose(
|
||||
pyramid_verts.lt(1).float(), torch.ones_like(pyramid_verts)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
(pyramid_verts >= 0).float(), torch.ones_like(pyramid_verts)
|
||||
)
|
||||
self.assertClose(
|
||||
(pyramid_verts >= 0).float(), torch.ones_like(pyramid_verts)
|
||||
)
|
||||
|
||||
# Face 1: z = 0, x + y <= 1, normals = (0, 0, 1).
|
||||
@@ -135,13 +129,11 @@ class TestSamplePoints(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
torch.all((face_1_verts[:, 0] + face_1_verts[:, 1]) <= 1)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
face_1_normals,
|
||||
torch.tensor([0, 0, 1], dtype=torch.float32).expand(
|
||||
face_1_normals.size()
|
||||
),
|
||||
)
|
||||
self.assertClose(
|
||||
face_1_normals,
|
||||
torch.tensor([0, 0, 1], dtype=torch.float32).expand(
|
||||
face_1_normals.size()
|
||||
),
|
||||
)
|
||||
|
||||
# Face 2: x = 0, z + y <= 1, normals = (1, 0, 0).
|
||||
@@ -153,13 +145,11 @@ class TestSamplePoints(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
torch.all((face_2_verts[:, 1] + face_2_verts[:, 2]) <= 1)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
face_2_normals,
|
||||
torch.tensor([1, 0, 0], dtype=torch.float32).expand(
|
||||
face_2_normals.size()
|
||||
),
|
||||
)
|
||||
self.assertClose(
|
||||
face_2_normals,
|
||||
torch.tensor([1, 0, 0], dtype=torch.float32).expand(
|
||||
face_2_normals.size()
|
||||
),
|
||||
)
|
||||
|
||||
# Face 3: y = 0, x + z <= 1, normals = (0, -1, 0).
|
||||
@@ -171,13 +161,11 @@ class TestSamplePoints(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
torch.all((face_3_verts[:, 0] + face_3_verts[:, 2]) <= 1)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
face_3_normals,
|
||||
torch.tensor([0, -1, 0], dtype=torch.float32).expand(
|
||||
face_3_normals.size()
|
||||
),
|
||||
)
|
||||
self.assertClose(
|
||||
face_3_normals,
|
||||
torch.tensor([0, -1, 0], dtype=torch.float32).expand(
|
||||
face_3_normals.size()
|
||||
),
|
||||
)
|
||||
|
||||
# Face 4: x + y + z = 1, normals = (1, 1, 1)/sqrt(3).
|
||||
@@ -186,22 +174,16 @@ class TestSamplePoints(unittest.TestCase):
|
||||
pyramid_verts[face_4_idxs, :],
|
||||
pyramid_normals[face_4_idxs, :],
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
face_4_verts.sum(1), torch.ones(face_4_verts.size(0))
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
face_4_normals,
|
||||
(
|
||||
torch.tensor([1, 1, 1], dtype=torch.float32)
|
||||
/ torch.sqrt(torch.tensor(3, dtype=torch.float32))
|
||||
).expand(face_4_normals.size()),
|
||||
)
|
||||
self.assertClose(face_4_verts.sum(1), torch.ones(face_4_verts.size(0)))
|
||||
self.assertClose(
|
||||
face_4_normals,
|
||||
(
|
||||
torch.tensor([1, 1, 1], dtype=torch.float32)
|
||||
/ torch.sqrt(torch.tensor(3, dtype=torch.float32))
|
||||
).expand(face_4_normals.size()),
|
||||
)
|
||||
|
||||
def test_mutinomial(self):
|
||||
def test_multinomial(self):
|
||||
"""
|
||||
Confirm that torch.multinomial does not sample elements which have
|
||||
zero probability.
|
||||
|
||||
Reference in New Issue
Block a user