use assertClose

Summary: use assertClose in some tests, which enforces shape equality. Fixes some small problems, including graph_conv on an empty graph.

Reviewed By: nikhilaravi

Differential Revision: D20556912

fbshipit-source-id: 60a61eafe3c03ce0f6c9c1a842685708fb10ac5b
This commit is contained in:
Jeremy Reizenstein 2020-03-23 11:33:10 -07:00 committed by Facebook GitHub Bot
parent 744ef0c2c8
commit 595aca27ea
13 changed files with 216 additions and 241 deletions

View File

@ -65,7 +65,7 @@ class GraphConv(nn.Module):
) )
if verts.shape[0] == 0: if verts.shape[0] == 0:
# empty graph. # empty graph.
return verts.sum() * 0.0 return verts.new_zeros((0, self.output_dim)) * verts.sum()
verts_w0 = self.w0(verts) # (V, output_dim) verts_w0 = self.w0(verts) # (V, output_dim)
verts_w1 = self.w1(verts) # (V, output_dim) verts_w1 = self.w1(verts) # (V, output_dim)

View File

@ -111,7 +111,7 @@ def orthographic_project_naive(points, scale_xyz=(1.0, 1.0, 1.0)):
return points return points
class TestCameraHelpers(unittest.TestCase): class TestCameraHelpers(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
torch.manual_seed(42) torch.manual_seed(42)
@ -154,7 +154,7 @@ class TestCameraHelpers(unittest.TestCase):
[0.0, 2.7, 0.0], dtype=torch.float32 [0.0, 2.7, 0.0], dtype=torch.float32
).view(1, 3) ).view(1, 3)
position = camera_position_from_spherical_angles(dist, elev, azim) position = camera_position_from_spherical_angles(dist, elev, azim)
self.assertTrue(torch.allclose(position, expected_position, atol=2e-7)) self.assertClose(position, expected_position, atol=2e-7)
def test_camera_position_from_angles_python_scalar_radians(self): def test_camera_position_from_angles_python_scalar_radians(self):
dist = 2.7 dist = 2.7
@ -165,7 +165,7 @@ class TestCameraHelpers(unittest.TestCase):
position = camera_position_from_spherical_angles( position = camera_position_from_spherical_angles(
dist, elev, azim, degrees=False dist, elev, azim, degrees=False
) )
self.assertTrue(torch.allclose(position, expected_position, atol=2e-7)) self.assertClose(position, expected_position, atol=2e-7)
def test_camera_position_from_angles_torch_scalars(self): def test_camera_position_from_angles_torch_scalars(self):
dist = torch.tensor(2.7) dist = torch.tensor(2.7)
@ -175,7 +175,7 @@ class TestCameraHelpers(unittest.TestCase):
[2.7, 0.0, 0.0], dtype=torch.float32 [2.7, 0.0, 0.0], dtype=torch.float32
).view(1, 3) ).view(1, 3)
position = camera_position_from_spherical_angles(dist, elev, azim) position = camera_position_from_spherical_angles(dist, elev, azim)
self.assertTrue(torch.allclose(position, expected_position, atol=2e-7)) self.assertClose(position, expected_position, atol=2e-7)
def test_camera_position_from_angles_mixed_scalars(self): def test_camera_position_from_angles_mixed_scalars(self):
dist = 2.7 dist = 2.7
@ -185,7 +185,7 @@ class TestCameraHelpers(unittest.TestCase):
[2.7, 0.0, 0.0], dtype=torch.float32 [2.7, 0.0, 0.0], dtype=torch.float32
).view(1, 3) ).view(1, 3)
position = camera_position_from_spherical_angles(dist, elev, azim) position = camera_position_from_spherical_angles(dist, elev, azim)
self.assertTrue(torch.allclose(position, expected_position, atol=2e-7)) self.assertClose(position, expected_position, atol=2e-7)
def test_camera_position_from_angles_torch_scalar_grads(self): def test_camera_position_from_angles_torch_scalar_grads(self):
dist = torch.tensor(2.7, requires_grad=True) dist = torch.tensor(2.7, requires_grad=True)
@ -210,8 +210,8 @@ class TestCameraHelpers(unittest.TestCase):
- torch.sin(elev) * torch.cos(azim) - torch.sin(elev) * torch.cos(azim)
) )
grad_elev = dist * (math.pi / 180.0) * grad_elev grad_elev = dist * (math.pi / 180.0) * grad_elev
self.assertTrue(torch.allclose(elev_grad, grad_elev)) self.assertClose(elev_grad, grad_elev)
self.assertTrue(torch.allclose(dist_grad, grad_dist)) self.assertClose(dist_grad, grad_dist)
def test_camera_position_from_angles_vectors(self): def test_camera_position_from_angles_vectors(self):
dist = torch.tensor([2.0, 2.0]) dist = torch.tensor([2.0, 2.0])
@ -221,7 +221,7 @@ class TestCameraHelpers(unittest.TestCase):
[[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=torch.float32 [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=torch.float32
) )
position = camera_position_from_spherical_angles(dist, elev, azim) position = camera_position_from_spherical_angles(dist, elev, azim)
self.assertTrue(torch.allclose(position, expected_position, atol=2e-7)) self.assertClose(position, expected_position, atol=2e-7)
def test_camera_position_from_angles_vectors_broadcast(self): def test_camera_position_from_angles_vectors_broadcast(self):
dist = torch.tensor([2.0, 3.0, 5.0]) dist = torch.tensor([2.0, 3.0, 5.0])
@ -232,7 +232,7 @@ class TestCameraHelpers(unittest.TestCase):
dtype=torch.float32, dtype=torch.float32,
) )
position = camera_position_from_spherical_angles(dist, elev, azim) position = camera_position_from_spherical_angles(dist, elev, azim)
self.assertTrue(torch.allclose(position, expected_position, atol=3e-7)) self.assertClose(position, expected_position, atol=3e-7)
def test_camera_position_from_angles_vectors_mixed_broadcast(self): def test_camera_position_from_angles_vectors_mixed_broadcast(self):
dist = torch.tensor([2.0, 3.0, 5.0]) dist = torch.tensor([2.0, 3.0, 5.0])
@ -243,7 +243,7 @@ class TestCameraHelpers(unittest.TestCase):
dtype=torch.float32, dtype=torch.float32,
) )
position = camera_position_from_spherical_angles(dist, elev, azim) position = camera_position_from_spherical_angles(dist, elev, azim)
self.assertTrue(torch.allclose(position, expected_position, atol=3e-7)) self.assertClose(position, expected_position, atol=3e-7)
def test_camera_position_from_angles_vectors_mixed_broadcast_grads(self): def test_camera_position_from_angles_vectors_mixed_broadcast_grads(self):
dist = torch.tensor([2.0, 3.0, 5.0], requires_grad=True) dist = torch.tensor([2.0, 3.0, 5.0], requires_grad=True)
@ -269,8 +269,8 @@ class TestCameraHelpers(unittest.TestCase):
- torch.sin(elev) * torch.cos(azim) - torch.sin(elev) * torch.cos(azim)
) )
grad_elev = (dist * (math.pi / 180.0) * grad_elev).sum() grad_elev = (dist * (math.pi / 180.0) * grad_elev).sum()
self.assertTrue(torch.allclose(elev_grad, grad_elev)) self.assertClose(elev_grad, grad_elev)
self.assertTrue(torch.allclose(dist_grad, grad_dist)) self.assertClose(dist_grad, torch.full([3], grad_dist))
def test_camera_position_from_angles_vectors_bad_broadcast(self): def test_camera_position_from_angles_vectors_bad_broadcast(self):
# Batch dim for broadcast must be N or 1 # Batch dim for broadcast must be N or 1
@ -283,7 +283,7 @@ class TestCameraHelpers(unittest.TestCase):
def test_look_at_rotation_python_list(self): def test_look_at_rotation_python_list(self):
camera_position = [[0.0, 0.0, -1.0]] # camera pointing along negative z camera_position = [[0.0, 0.0, -1.0]] # camera pointing along negative z
rot_mat = look_at_rotation(camera_position) rot_mat = look_at_rotation(camera_position)
self.assertTrue(torch.allclose(rot_mat, torch.eye(3)[None], atol=2e-7)) self.assertClose(rot_mat, torch.eye(3)[None], atol=2e-7)
def test_look_at_rotation_input_fail(self): def test_look_at_rotation_input_fail(self):
camera_position = [-1.0] # expected to have xyz positions camera_position = [-1.0] # expected to have xyz positions
@ -310,7 +310,7 @@ class TestCameraHelpers(unittest.TestCase):
) )
# fmt: on # fmt: on
rot_mats = look_at_rotation(camera_positions) rot_mats = look_at_rotation(camera_positions)
self.assertTrue(torch.allclose(rot_mats, rot_mats_expected, atol=2e-7)) self.assertClose(rot_mats, rot_mats_expected, atol=2e-7)
def test_look_at_rotation_tensor_broadcast(self): def test_look_at_rotation_tensor_broadcast(self):
# fmt: off # fmt: off
@ -335,19 +335,15 @@ class TestCameraHelpers(unittest.TestCase):
) )
# fmt: on # fmt: on
rot_mats = look_at_rotation(camera_positions) rot_mats = look_at_rotation(camera_positions)
self.assertTrue(torch.allclose(rot_mats, rot_mats_expected, atol=2e-7)) self.assertClose(rot_mats, rot_mats_expected, atol=2e-7)
def test_look_at_rotation_tensor_grad(self): def test_look_at_rotation_tensor_grad(self):
camera_position = torch.tensor([[0.0, 0.0, -1.0]], requires_grad=True) camera_position = torch.tensor([[0.0, 0.0, -1.0]], requires_grad=True)
rot_mat = look_at_rotation(camera_position) rot_mat = look_at_rotation(camera_position)
rot_mat.sum().backward() rot_mat.sum().backward()
self.assertTrue(hasattr(camera_position, "grad")) self.assertTrue(hasattr(camera_position, "grad"))
self.assertTrue( self.assertClose(
torch.allclose( camera_position.grad, torch.zeros_like(camera_position), atol=2e-7
camera_position.grad,
torch.zeros_like(camera_position),
atol=2e-7,
)
) )
def test_view_transform(self): def test_view_transform(self):
@ -403,9 +399,9 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
vertices = vertices[None, None, :] vertices = vertices[None, None, :]
v1 = P.transform_points(vertices) v1 = P.transform_points(vertices)
v2 = perspective_project_naive(vertices, fov=60.0) v2 = perspective_project_naive(vertices, fov=60.0)
self.assertTrue(torch.allclose(v1[..., :2], v2[..., :2])) self.assertClose(v1[..., :2], v2[..., :2])
self.assertTrue(torch.allclose(far * v1[..., 2], v2[..., 2])) self.assertClose(far * v1[..., 2], v2[..., 2])
self.assertTrue(torch.allclose(v1.squeeze(), projected_verts)) self.assertClose(v1.squeeze(), projected_verts)
# vertices are at the near clipping plane so z gets mapped to 0.0. # vertices are at the near clipping plane so z gets mapped to 0.0.
vertices[..., 2] = near vertices[..., 2] = near
@ -414,8 +410,8 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
) )
v1 = P.transform_points(vertices) v1 = P.transform_points(vertices)
v2 = perspective_project_naive(vertices, fov=60.0) v2 = perspective_project_naive(vertices, fov=60.0)
self.assertTrue(torch.allclose(v1[..., :2], v2[..., :2])) self.assertClose(v1[..., :2], v2[..., :2])
self.assertTrue(torch.allclose(v1.squeeze(), projected_verts)) self.assertClose(v1.squeeze(), projected_verts)
def test_perspective_kwargs(self): def test_perspective_kwargs(self):
cameras = OpenGLPerspectiveCameras(znear=5.0, zfar=100.0, fov=0.0) cameras = OpenGLPerspectiveCameras(znear=5.0, zfar=100.0, fov=0.0)
@ -428,7 +424,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
) )
vertices = vertices[None, None, :] vertices = vertices[None, None, :]
v1 = P.transform_points(vertices) v1 = P.transform_points(vertices)
self.assertTrue(torch.allclose(v1.squeeze(), projected_verts)) self.assertClose(v1.squeeze(), projected_verts)
def test_perspective_mixed_inputs_broadcast(self): def test_perspective_mixed_inputs_broadcast(self):
far = torch.tensor([10.0, 20.0], dtype=torch.float32) far = torch.tensor([10.0, 20.0], dtype=torch.float32)
@ -449,8 +445,8 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
vertices = vertices[None, None, :] vertices = vertices[None, None, :]
v1 = P.transform_points(vertices) v1 = P.transform_points(vertices)
v2 = perspective_project_naive(vertices, fov=60.0) v2 = perspective_project_naive(vertices, fov=60.0)
self.assertTrue(torch.allclose(v1[..., :2], v2[..., :2])) self.assertClose(v1[..., :2], torch.cat([v2, v2])[..., :2])
self.assertTrue(torch.allclose(v1.squeeze(), projected_verts)) self.assertClose(v1.squeeze(), projected_verts)
def test_perspective_mixed_inputs_grad(self): def test_perspective_mixed_inputs_grad(self):
far = torch.tensor([10.0]) far = torch.tensor([10.0])
@ -468,7 +464,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
grad_cotan = -(1.0 / (torch.sin(half_fov_rad) ** 2.0) * 1 / 2.0) grad_cotan = -(1.0 / (torch.sin(half_fov_rad) ** 2.0) * 1 / 2.0)
grad_fov = (math.pi / 180.0) * grad_cotan grad_fov = (math.pi / 180.0) * grad_cotan
grad_fov = (vertices[0] + vertices[1]) * grad_fov / 10.0 grad_fov = (vertices[0] + vertices[1]) * grad_fov / 10.0
self.assertTrue(torch.allclose(fov_grad, grad_fov)) self.assertClose(fov_grad, grad_fov)
def test_camera_class_init(self): def test_camera_class_init(self):
device = torch.device("cuda:0") device = torch.device("cuda:0")
@ -496,8 +492,8 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
R = look_at_rotation(T) R = look_at_rotation(T)
P = cam.get_full_projection_transform(R=R, T=T) P = cam.get_full_projection_transform(R=R, T=T)
self.assertTrue(isinstance(P, Transform3d)) self.assertTrue(isinstance(P, Transform3d))
self.assertTrue(torch.allclose(cam.R, R)) self.assertClose(cam.R, R)
self.assertTrue(torch.allclose(cam.T, T)) self.assertClose(cam.T, T)
def test_transform_points(self): def test_transform_points(self):
# Check transform_points methods works with default settings for # Check transform_points methods works with default settings for
@ -511,7 +507,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
) )
projected_points = projected_points.view(1, 1, 3).expand(5, 10, -1) projected_points = projected_points.view(1, 1, 3).expand(5, 10, -1)
new_points = cam.transform_points(points) new_points = cam.transform_points(points)
self.assertTrue(torch.allclose(new_points, projected_points)) self.assertClose(new_points, projected_points)
class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase): class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
@ -526,15 +522,15 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
vertices = vertices[None, None, :] vertices = vertices[None, None, :]
v1 = P.transform_points(vertices) v1 = P.transform_points(vertices)
v2 = orthographic_project_naive(vertices) v2 = orthographic_project_naive(vertices)
self.assertTrue(torch.allclose(v1[..., :2], v2[..., :2])) self.assertClose(v1[..., :2], v2[..., :2])
self.assertTrue(torch.allclose(v1.squeeze(), projected_verts)) self.assertClose(v1.squeeze(), projected_verts)
vertices[..., 2] = near vertices[..., 2] = near
projected_verts[2] = 0.0 projected_verts[2] = 0.0
v1 = P.transform_points(vertices) v1 = P.transform_points(vertices)
v2 = orthographic_project_naive(vertices) v2 = orthographic_project_naive(vertices)
self.assertTrue(torch.allclose(v1[..., :2], v2[..., :2])) self.assertClose(v1[..., :2], v2[..., :2])
self.assertTrue(torch.allclose(v1.squeeze(), projected_verts)) self.assertClose(v1.squeeze(), projected_verts)
def test_orthographic_scaled(self): def test_orthographic_scaled(self):
vertices = torch.tensor([1, 2, 0.5], dtype=torch.float32) vertices = torch.tensor([1, 2, 0.5], dtype=torch.float32)
@ -549,8 +545,8 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
P = cameras.get_projection_transform() P = cameras.get_projection_transform()
v1 = P.transform_points(vertices) v1 = P.transform_points(vertices)
v2 = orthographic_project_naive(vertices, scale) v2 = orthographic_project_naive(vertices, scale)
self.assertTrue(torch.allclose(v1[..., :2], v2[..., :2])) self.assertClose(v1[..., :2], v2[..., :2])
self.assertTrue(torch.allclose(v1, projected_verts)) self.assertClose(v1, projected_verts[None, None])
def test_orthographic_kwargs(self): def test_orthographic_kwargs(self):
cameras = OpenGLOrthographicCameras(znear=5.0, zfar=100.0) cameras = OpenGLOrthographicCameras(znear=5.0, zfar=100.0)
@ -560,14 +556,13 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
projected_verts = torch.tensor([1, 2, 1], dtype=torch.float32) projected_verts = torch.tensor([1, 2, 1], dtype=torch.float32)
vertices = vertices[None, None, :] vertices = vertices[None, None, :]
v1 = P.transform_points(vertices) v1 = P.transform_points(vertices)
self.assertTrue(torch.allclose(v1.squeeze(), projected_verts)) self.assertClose(v1.squeeze(), projected_verts)
def test_orthographic_mixed_inputs_broadcast(self): def test_orthographic_mixed_inputs_broadcast(self):
far = torch.tensor([10.0, 20.0]) far = torch.tensor([10.0, 20.0])
near = 1.0 near = 1.0
cameras = OpenGLOrthographicCameras(znear=near, zfar=far) cameras = OpenGLOrthographicCameras(znear=near, zfar=far)
P = cameras.get_projection_transform() P = cameras.get_projection_transform()
vertices = torch.tensor([1.0, 2.0, 10.0], dtype=torch.float32) vertices = torch.tensor([1.0, 2.0, 10.0], dtype=torch.float32)
z2 = 1.0 / (20.0 - 1.0) * 10.0 + -(1.0) / (20.0 - 1.0) z2 = 1.0 / (20.0 - 1.0) * 10.0 + -(1.0) / (20.0 - 1.0)
projected_verts = torch.tensor( projected_verts = torch.tensor(
@ -576,8 +571,8 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
vertices = vertices[None, None, :] vertices = vertices[None, None, :]
v1 = P.transform_points(vertices) v1 = P.transform_points(vertices)
v2 = orthographic_project_naive(vertices) v2 = orthographic_project_naive(vertices)
self.assertTrue(torch.allclose(v1[..., :2], v2[..., :2])) self.assertClose(v1[..., :2], torch.cat([v2, v2])[..., :2])
self.assertTrue(torch.allclose(v1.squeeze(), projected_verts)) self.assertClose(v1.squeeze(), projected_verts)
def test_orthographic_mixed_inputs_grad(self): def test_orthographic_mixed_inputs_grad(self):
far = torch.tensor([10.0]) far = torch.tensor([10.0])
@ -602,7 +597,7 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
] ]
] ]
) )
self.assertTrue(torch.allclose(scale_grad, grad_scale)) self.assertClose(scale_grad, grad_scale)
class TestSfMOrthographicProjection(TestCaseMixin, unittest.TestCase): class TestSfMOrthographicProjection(TestCaseMixin, unittest.TestCase):
@ -615,8 +610,8 @@ class TestSfMOrthographicProjection(TestCaseMixin, unittest.TestCase):
v1 = P.transform_points(vertices) v1 = P.transform_points(vertices)
v2 = orthographic_project_naive(vertices) v2 = orthographic_project_naive(vertices)
self.assertTrue(torch.allclose(v1[..., :2], v2[..., :2])) self.assertClose(v1[..., :2], v2[..., :2])
self.assertTrue(torch.allclose(v1, projected_verts)) self.assertClose(v1, projected_verts)
def test_orthographic_scaled(self): def test_orthographic_scaled(self):
focal_length_x = 10.0 focal_length_x = 10.0
@ -636,9 +631,9 @@ class TestSfMOrthographicProjection(TestCaseMixin, unittest.TestCase):
vertices, scale_xyz=(focal_length_x, focal_length_y, 1.0) vertices, scale_xyz=(focal_length_x, focal_length_y, 1.0)
) )
v3 = cameras.transform_points(vertices) v3 = cameras.transform_points(vertices)
self.assertTrue(torch.allclose(v1[..., :2], v2[..., :2])) self.assertClose(v1[..., :2], v2[..., :2])
self.assertTrue(torch.allclose(v3[..., :2], v2[..., :2])) self.assertClose(v3[..., :2], v2[..., :2])
self.assertTrue(torch.allclose(v1, projected_verts)) self.assertClose(v1, projected_verts)
def test_orthographic_kwargs(self): def test_orthographic_kwargs(self):
cameras = SfMOrthographicCameras( cameras = SfMOrthographicCameras(
@ -653,7 +648,7 @@ class TestSfMOrthographicProjection(TestCaseMixin, unittest.TestCase):
projected_verts[:, :, 0] += 2.5 projected_verts[:, :, 0] += 2.5
projected_verts[:, :, 1] += 3.5 projected_verts[:, :, 1] += 3.5
v1 = P.transform_points(vertices) v1 = P.transform_points(vertices)
self.assertTrue(torch.allclose(v1, projected_verts)) self.assertClose(v1, projected_verts)
class TestSfMPerspectiveProjection(TestCaseMixin, unittest.TestCase): class TestSfMPerspectiveProjection(TestCaseMixin, unittest.TestCase):
@ -664,7 +659,7 @@ class TestSfMPerspectiveProjection(TestCaseMixin, unittest.TestCase):
vertices = torch.randn([3, 4, 3], dtype=torch.float32) vertices = torch.randn([3, 4, 3], dtype=torch.float32)
v1 = P.transform_points(vertices) v1 = P.transform_points(vertices)
v2 = sfm_perspective_project_naive(vertices) v2 = sfm_perspective_project_naive(vertices)
self.assertTrue(torch.allclose(v1, v2)) self.assertClose(v1, v2)
def test_perspective_scaled(self): def test_perspective_scaled(self):
focal_length_x = 10.0 focal_length_x = 10.0
@ -684,8 +679,8 @@ class TestSfMPerspectiveProjection(TestCaseMixin, unittest.TestCase):
vertices, fx=focal_length_x, fy=focal_length_y, p0x=p0x, p0y=p0y vertices, fx=focal_length_x, fy=focal_length_y, p0x=p0x, p0y=p0y
) )
v3 = cameras.transform_points(vertices) v3 = cameras.transform_points(vertices)
self.assertTrue(torch.allclose(v1, v2)) self.assertClose(v1, v2)
self.assertTrue(torch.allclose(v3[..., :2], v2[..., :2])) self.assertClose(v3[..., :2], v2[..., :2])
def test_perspective_kwargs(self): def test_perspective_kwargs(self):
cameras = SfMPerspectiveCameras( cameras = SfMPerspectiveCameras(
@ -699,4 +694,4 @@ class TestSfMPerspectiveProjection(TestCaseMixin, unittest.TestCase):
v2 = sfm_perspective_project_naive( v2 = sfm_perspective_project_naive(
vertices, fx=2.0, fy=2.0, p0x=2.5, p0y=3.5 vertices, fx=2.0, fy=2.0, p0x=2.5, p0y=3.5
) )
self.assertTrue(torch.allclose(v1, v2)) self.assertClose(v1, v2)

View File

@ -6,8 +6,10 @@ import torch.nn.functional as F
from pytorch3d.loss import chamfer_distance from pytorch3d.loss import chamfer_distance
from common_testing import TestCaseMixin
class TestChamfer(unittest.TestCase):
class TestChamfer(TestCaseMixin, unittest.TestCase):
@staticmethod @staticmethod
def init_pointclouds(batch_size: int = 10, P1: int = 32, P2: int = 64): def init_pointclouds(batch_size: int = 10, P1: int = 32, P2: int = 64):
""" """
@ -85,7 +87,7 @@ class TestChamfer(unittest.TestCase):
pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2 pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
pred_loss *= weights pred_loss *= weights
pred_loss = pred_loss.sum() / weights.sum() pred_loss = pred_loss.sum() / weights.sum()
self.assertTrue(torch.allclose(loss, pred_loss)) self.assertClose(loss, pred_loss)
self.assertTrue(loss_norm is None) self.assertTrue(loss_norm is None)
def test_chamfer_point_reduction(self): def test_chamfer_point_reduction(self):
@ -115,13 +117,13 @@ class TestChamfer(unittest.TestCase):
) )
pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2 pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
pred_loss_mean *= weights pred_loss_mean *= weights
self.assertTrue(torch.allclose(loss, pred_loss_mean)) self.assertClose(loss, pred_loss_mean)
pred_loss_norm_mean = ( pred_loss_norm_mean = (
pred_loss_norm[0].sum(1) / P1 + pred_loss_norm[1].sum(1) / P2 pred_loss_norm[0].sum(1) / P1 + pred_loss_norm[1].sum(1) / P2
) )
pred_loss_norm_mean *= weights pred_loss_norm_mean *= weights
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_mean)) self.assertClose(loss_norm, pred_loss_norm_mean)
# point_reduction = "sum". # point_reduction = "sum".
loss, loss_norm = chamfer_distance( loss, loss_norm = chamfer_distance(
@ -135,11 +137,11 @@ class TestChamfer(unittest.TestCase):
) )
pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1) pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1)
pred_loss_sum *= weights pred_loss_sum *= weights
self.assertTrue(torch.allclose(loss, pred_loss_sum)) self.assertClose(loss, pred_loss_sum)
pred_loss_norm_sum = pred_loss_norm[0].sum(1) + pred_loss_norm[1].sum(1) pred_loss_norm_sum = pred_loss_norm[0].sum(1) + pred_loss_norm[1].sum(1)
pred_loss_norm_sum *= weights pred_loss_norm_sum *= weights
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_sum)) self.assertClose(loss_norm, pred_loss_norm_sum)
# Error when point_reduction = "none" and batch_reduction = "none". # Error when point_reduction = "none" and batch_reduction = "none".
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -182,12 +184,12 @@ class TestChamfer(unittest.TestCase):
pred_loss[0] *= weights.view(N, 1) pred_loss[0] *= weights.view(N, 1)
pred_loss[1] *= weights.view(N, 1) pred_loss[1] *= weights.view(N, 1)
pred_loss = pred_loss[0].sum() + pred_loss[1].sum() pred_loss = pred_loss[0].sum() + pred_loss[1].sum()
self.assertTrue(torch.allclose(loss, pred_loss)) self.assertClose(loss, pred_loss)
pred_loss_norm[0] *= weights.view(N, 1) pred_loss_norm[0] *= weights.view(N, 1)
pred_loss_norm[1] *= weights.view(N, 1) pred_loss_norm[1] *= weights.view(N, 1)
pred_loss_norm = pred_loss_norm[0].sum() + pred_loss_norm[1].sum() pred_loss_norm = pred_loss_norm[0].sum() + pred_loss_norm[1].sum()
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm)) self.assertClose(loss_norm, pred_loss_norm)
# batch_reduction = "mean". # batch_reduction = "mean".
loss, loss_norm = chamfer_distance( loss, loss_norm = chamfer_distance(
@ -201,10 +203,10 @@ class TestChamfer(unittest.TestCase):
) )
pred_loss /= weights.sum() pred_loss /= weights.sum()
self.assertTrue(torch.allclose(loss, pred_loss)) self.assertClose(loss, pred_loss)
pred_loss_norm /= weights.sum() pred_loss_norm /= weights.sum()
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm)) self.assertClose(loss_norm, pred_loss_norm)
# Error when point_reduction is not in ["none", "mean", "sum"]. # Error when point_reduction is not in ["none", "mean", "sum"].
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -239,7 +241,7 @@ class TestChamfer(unittest.TestCase):
pred_loss[1] *= weights.view(N, 1) pred_loss[1] *= weights.view(N, 1)
pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1) # point sum pred_loss_sum = pred_loss[0].sum(1) + pred_loss[1].sum(1) # point sum
pred_loss_sum = pred_loss_sum.sum() # batch sum pred_loss_sum = pred_loss_sum.sum() # batch sum
self.assertTrue(torch.allclose(loss, pred_loss_sum)) self.assertClose(loss, pred_loss_sum)
pred_loss_norm[0] *= weights.view(N, 1) pred_loss_norm[0] *= weights.view(N, 1)
pred_loss_norm[1] *= weights.view(N, 1) pred_loss_norm[1] *= weights.view(N, 1)
@ -247,7 +249,7 @@ class TestChamfer(unittest.TestCase):
1 1
) # point sum. ) # point sum.
pred_loss_norm_sum = pred_loss_norm_sum.sum() # batch sum pred_loss_norm_sum = pred_loss_norm_sum.sum() # batch sum
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_sum)) self.assertClose(loss_norm, pred_loss_norm_sum)
# batch_reduction = "mean", point_reduction = "sum". # batch_reduction = "mean", point_reduction = "sum".
loss, loss_norm = chamfer_distance( loss, loss_norm = chamfer_distance(
@ -260,10 +262,10 @@ class TestChamfer(unittest.TestCase):
point_reduction="sum", point_reduction="sum",
) )
pred_loss_sum /= weights.sum() pred_loss_sum /= weights.sum()
self.assertTrue(torch.allclose(loss, pred_loss_sum)) self.assertClose(loss, pred_loss_sum)
pred_loss_norm_sum /= weights.sum() pred_loss_norm_sum /= weights.sum()
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_sum)) self.assertClose(loss_norm, pred_loss_norm_sum)
# batch_reduction = "sum", point_reduction = "mean". # batch_reduction = "sum", point_reduction = "mean".
loss, loss_norm = chamfer_distance( loss, loss_norm = chamfer_distance(
@ -277,13 +279,13 @@ class TestChamfer(unittest.TestCase):
) )
pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2 pred_loss_mean = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
pred_loss_mean = pred_loss_mean.sum() pred_loss_mean = pred_loss_mean.sum()
self.assertTrue(torch.allclose(loss, pred_loss_mean)) self.assertClose(loss, pred_loss_mean)
pred_loss_norm_mean = ( pred_loss_norm_mean = (
pred_loss_norm[0].sum(1) / P1 + pred_loss_norm[1].sum(1) / P2 pred_loss_norm[0].sum(1) / P1 + pred_loss_norm[1].sum(1) / P2
) )
pred_loss_norm_mean = pred_loss_norm_mean.sum() pred_loss_norm_mean = pred_loss_norm_mean.sum()
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_mean)) self.assertClose(loss_norm, pred_loss_norm_mean)
# batch_reduction = "mean", point_reduction = "mean". This is the default. # batch_reduction = "mean", point_reduction = "mean". This is the default.
loss, loss_norm = chamfer_distance( loss, loss_norm = chamfer_distance(
@ -296,10 +298,10 @@ class TestChamfer(unittest.TestCase):
point_reduction="mean", point_reduction="mean",
) )
pred_loss_mean /= weights.sum() pred_loss_mean /= weights.sum()
self.assertTrue(torch.allclose(loss, pred_loss_mean)) self.assertClose(loss, pred_loss_mean)
pred_loss_norm_mean /= weights.sum() pred_loss_norm_mean /= weights.sum()
self.assertTrue(torch.allclose(loss_norm, pred_loss_norm_mean)) self.assertClose(loss_norm, pred_loss_norm_mean)
def test_incorrect_weights(self): def test_incorrect_weights(self):
N, P1, P2 = 16, 64, 128 N, P1, P2 = 16, 64, 128
@ -315,17 +317,17 @@ class TestChamfer(unittest.TestCase):
loss, loss_norm = chamfer_distance( loss, loss_norm = chamfer_distance(
p1, p2, weights=weights, batch_reduction="mean" p1, p2, weights=weights, batch_reduction="mean"
) )
self.assertTrue(torch.allclose(loss.cpu(), torch.zeros((1,)))) self.assertClose(loss.cpu(), torch.zeros(()))
self.assertTrue(loss.requires_grad) self.assertTrue(loss.requires_grad)
self.assertTrue(torch.allclose(loss_norm.cpu(), torch.zeros((1,)))) self.assertClose(loss_norm.cpu(), torch.zeros(()))
self.assertTrue(loss_norm.requires_grad) self.assertTrue(loss_norm.requires_grad)
loss, loss_norm = chamfer_distance( loss, loss_norm = chamfer_distance(
p1, p2, weights=weights, batch_reduction="none" p1, p2, weights=weights, batch_reduction="none"
) )
self.assertTrue(torch.allclose(loss.cpu(), torch.zeros((N,)))) self.assertClose(loss.cpu(), torch.zeros((N, N)))
self.assertTrue(loss.requires_grad) self.assertTrue(loss.requires_grad)
self.assertTrue(torch.allclose(loss_norm.cpu(), torch.zeros((N,)))) self.assertClose(loss_norm.cpu(), torch.zeros((N, N)))
self.assertTrue(loss_norm.requires_grad) self.assertTrue(loss_norm.requires_grad)
weights = torch.ones((N,), dtype=torch.float32, device=device) * -1 weights = torch.ones((N,), dtype=torch.float32, device=device) * -1

View File

@ -13,8 +13,10 @@ from pytorch3d.ops.graph_conv import (
from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils import ico_sphere from pytorch3d.utils import ico_sphere
from common_testing import TestCaseMixin
class TestGraphConv(unittest.TestCase):
class TestGraphConv(TestCaseMixin, unittest.TestCase):
def test_undirected(self): def test_undirected(self):
dtype = torch.float32 dtype = torch.float32
device = torch.device("cuda:0") device = torch.device("cuda:0")
@ -42,7 +44,7 @@ class TestGraphConv(unittest.TestCase):
conv.w1.bias.data.zero_() conv.w1.bias.data.zero_()
y = conv(verts, edges) y = conv(verts, edges)
self.assertTrue(torch.allclose(y, expected_y)) self.assertClose(y, expected_y)
def test_no_edges(self): def test_no_edges(self):
dtype = torch.float32 dtype = torch.float32
@ -57,19 +59,26 @@ class TestGraphConv(unittest.TestCase):
conv.w0.bias.data.zero_() conv.w0.bias.data.zero_()
y = conv(verts, edges) y = conv(verts, edges)
self.assertTrue(torch.allclose(y, expected_y)) self.assertClose(y, expected_y)
def test_no_verts_and_edges(self): def test_no_verts_and_edges(self):
dtype = torch.float32 dtype = torch.float32
verts = torch.tensor([], dtype=dtype, requires_grad=True) verts = torch.tensor([], dtype=dtype, requires_grad=True)
edges = torch.tensor([], dtype=dtype) edges = torch.tensor([], dtype=dtype)
w0 = torch.tensor([[1, -1, -2]], dtype=dtype) w0 = torch.tensor([[1, -1, -2]], dtype=dtype)
conv = GraphConv(3, 1).to(dtype) conv = GraphConv(3, 1).to(dtype)
conv.w0.weight.data.copy_(w0) conv.w0.weight.data.copy_(w0)
conv.w0.bias.data.zero_() conv.w0.bias.data.zero_()
y = conv(verts, edges) y = conv(verts, edges)
self.assertTrue(torch.allclose(y, torch.tensor([]))) self.assertClose(y, torch.zeros((0, 1)))
self.assertTrue(y.requires_grad)
conv2 = GraphConv(3, 2).to(dtype)
conv2.w0.weight.data.copy_(w0.repeat(2, 1))
conv2.w0.bias.data.zero_()
y = conv2(verts, edges)
self.assertClose(y, torch.zeros((0, 2)))
self.assertTrue(y.requires_grad) self.assertTrue(y.requires_grad)
def test_directed(self): def test_directed(self):
@ -91,7 +100,7 @@ class TestGraphConv(unittest.TestCase):
conv.w1.bias.data.zero_() conv.w1.bias.data.zero_()
y = conv(verts, edges) y = conv(verts, edges)
self.assertTrue(torch.allclose(y, expected_y)) self.assertClose(y, expected_y)
def test_backward(self): def test_backward(self):
device = torch.device("cuda:0") device = torch.device("cuda:0")
@ -108,7 +117,7 @@ class TestGraphConv(unittest.TestCase):
neighbor_sums_cuda.sum().backward() neighbor_sums_cuda.sum().backward()
neighbor_sums.sum().backward() neighbor_sums.sum().backward()
self.assertTrue(torch.allclose(verts.grad.cpu(), verts_cuda.grad.cpu())) self.assertClose(verts.grad.cpu(), verts_cuda.grad.cpu())
def test_repr(self): def test_repr(self):
conv = GraphConv(32, 64, directed=True) conv = GraphConv(32, 64, directed=True)
@ -147,7 +156,7 @@ class TestGraphConv(unittest.TestCase):
output_cuda = _C.gather_scatter( output_cuda = _C.gather_scatter(
input.to(device=device), edges.to(device=device), False, False input.to(device=device), edges.to(device=device), False, False
) )
self.assertTrue(torch.allclose(output_cuda.cpu(), output_cpu)) self.assertClose(output_cuda.cpu(), output_cpu)
with self.assertRaises(Exception) as err: with self.assertRaises(Exception) as err:
_C.gather_scatter(input.cpu(), edges.cpu(), False, False) _C.gather_scatter(input.cpu(), edges.cpu(), False, False)
self.assertTrue("Not implemented on the CPU" in str(err.exception)) self.assertTrue("Not implemented on the CPU" in str(err.exception))
@ -157,7 +166,7 @@ class TestGraphConv(unittest.TestCase):
output_cuda = _C.gather_scatter( output_cuda = _C.gather_scatter(
input.to(device=device), edges.to(device=device), True, False input.to(device=device), edges.to(device=device), True, False
) )
self.assertTrue(torch.allclose(output_cuda.cpu(), output_cpu)) self.assertClose(output_cuda.cpu(), output_cpu)
@staticmethod @staticmethod
def graph_conv_forward_backward( def graph_conv_forward_backward(

View File

@ -64,14 +64,10 @@ class TestLights(TestCaseMixin, unittest.TestCase):
# Update element # Update element
color = (0.5, 0.5, 0.5) color = (0.5, 0.5, 0.5)
light[1].ambient_color = color light[1].ambient_color = color
self.assertTrue( self.assertClose(light.ambient_color[1], torch.tensor(color))
torch.allclose(light.ambient_color[1], torch.tensor(color))
)
# Get item and get value # Get item and get value
l0 = light[0] l0 = light[0]
self.assertTrue( self.assertClose(l0.ambient_color, torch.tensor((0.0, 0.0, 0.0)))
torch.allclose(l0.ambient_color, torch.tensor((0.0, 0.0, 0.0)))
)
def test_initialize_lights_broadcast(self): def test_initialize_lights_broadcast(self):
light = DirectionalLights( light = DirectionalLights(
@ -127,7 +123,7 @@ class TestLights(TestCaseMixin, unittest.TestCase):
PointLights(location=torch.randn(10, 4)) PointLights(location=torch.randn(10, 4))
class TestDiffuseLighting(unittest.TestCase): class TestDiffuseLighting(TestCaseMixin, unittest.TestCase):
def test_diffuse_directional_lights(self): def test_diffuse_directional_lights(self):
""" """
Test with a single point where: Test with a single point where:
@ -145,17 +141,17 @@ class TestDiffuseLighting(unittest.TestCase):
[1 / np.sqrt(2), 1 / np.sqrt(2), 1 / np.sqrt(2)], [1 / np.sqrt(2), 1 / np.sqrt(2), 1 / np.sqrt(2)],
dtype=torch.float32, dtype=torch.float32,
) )
expected_output = expected_output.view(-1, 1, 3) expected_output = expected_output.view(1, 1, 3).repeat(3, 1, 1)
light = DirectionalLights(diffuse_color=color, direction=direction) light = DirectionalLights(diffuse_color=color, direction=direction)
output_light = light.diffuse(normals=normals) output_light = light.diffuse(normals=normals)
self.assertTrue(torch.allclose(output_light, expected_output)) self.assertClose(output_light, expected_output)
# Change light direction to be 90 degrees apart from normal direction. # Change light direction to be 90 degrees apart from normal direction.
direction = torch.tensor([0, 1, 0], dtype=torch.float32) direction = torch.tensor([0, 1, 0], dtype=torch.float32)
light.direction = direction light.direction = direction
expected_output = torch.zeros_like(expected_output) expected_output = torch.zeros_like(expected_output)
output_light = light.diffuse(normals=normals) output_light = light.diffuse(normals=normals)
self.assertTrue(torch.allclose(output_light, expected_output)) self.assertClose(output_light, expected_output)
def test_diffuse_point_lights(self): def test_diffuse_point_lights(self):
""" """
@ -183,7 +179,7 @@ class TestDiffuseLighting(unittest.TestCase):
output_light = light.diffuse( output_light = light.diffuse(
points=points[None, None, :], normals=normals[None, None, :] points=points[None, None, :], normals=normals[None, None, :]
) )
self.assertTrue(torch.allclose(output_light, expected_output)) self.assertClose(output_light, expected_output)
# Change light direction to be 90 degrees apart from normal direction. # Change light direction to be 90 degrees apart from normal direction.
location = torch.tensor([0, 1, 0], dtype=torch.float32) location = torch.tensor([0, 1, 0], dtype=torch.float32)
@ -194,7 +190,7 @@ class TestDiffuseLighting(unittest.TestCase):
output_light = light.diffuse( output_light = light.diffuse(
points=points[None, None, :], normals=normals[None, None, :] points=points[None, None, :], normals=normals[None, None, :]
) )
self.assertTrue(torch.allclose(output_light, expected_output)) self.assertClose(output_light, expected_output)
def test_diffuse_batched(self): def test_diffuse_batched(self):
""" """
@ -220,7 +216,7 @@ class TestDiffuseLighting(unittest.TestCase):
lights = DirectionalLights(diffuse_color=color, direction=direction) lights = DirectionalLights(diffuse_color=color, direction=direction)
output_light = lights.diffuse(normals=normals) output_light = lights.diffuse(normals=normals)
self.assertTrue(torch.allclose(output_light, expected_out)) self.assertClose(output_light, expected_out)
def test_diffuse_batched_broadcast_inputs(self): def test_diffuse_batched_broadcast_inputs(self):
""" """
@ -250,7 +246,7 @@ class TestDiffuseLighting(unittest.TestCase):
lights = DirectionalLights(diffuse_color=color, direction=direction) lights = DirectionalLights(diffuse_color=color, direction=direction)
output_light = lights.diffuse(normals=normals) output_light = lights.diffuse(normals=normals)
self.assertTrue(torch.allclose(output_light, expected_out)) self.assertClose(output_light, expected_out)
def test_diffuse_batched_arbitrary_input_dims(self): def test_diffuse_batched_arbitrary_input_dims(self):
""" """
@ -280,7 +276,7 @@ class TestDiffuseLighting(unittest.TestCase):
lights = DirectionalLights(diffuse_color=color, direction=direction) lights = DirectionalLights(diffuse_color=color, direction=direction)
output_light = lights.diffuse(normals=normals) output_light = lights.diffuse(normals=normals)
self.assertTrue(torch.allclose(output_light, expected_output)) self.assertClose(output_light, expected_output)
def test_diffuse_batched_packed(self): def test_diffuse_batched_packed(self):
""" """
@ -311,10 +307,10 @@ class TestDiffuseLighting(unittest.TestCase):
direction=direction[mesh_to_vert_idx, :], direction=direction[mesh_to_vert_idx, :],
) )
output_light = lights.diffuse(normals=normals[mesh_to_vert_idx, :]) output_light = lights.diffuse(normals=normals[mesh_to_vert_idx, :])
self.assertTrue(torch.allclose(output_light, expected_output)) self.assertClose(output_light, expected_output)
class TestSpecularLighting(unittest.TestCase): class TestSpecularLighting(TestCaseMixin, unittest.TestCase):
def test_specular_directional_lights(self): def test_specular_directional_lights(self):
""" """
Specular highlights depend on the camera position as well as the light Specular highlights depend on the camera position as well as the light
@ -337,7 +333,7 @@ class TestSpecularLighting(unittest.TestCase):
points = torch.tensor([0, 0, 0], dtype=torch.float32) points = torch.tensor([0, 0, 0], dtype=torch.float32)
normals = torch.tensor([0, 1, 0], dtype=torch.float32) normals = torch.tensor([0, 1, 0], dtype=torch.float32)
expected_output = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float32) expected_output = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float32)
expected_output = expected_output.view(-1, 1, 3) expected_output = expected_output.view(1, 1, 3).repeat(3, 1, 1)
lights = DirectionalLights(specular_color=color, direction=direction) lights = DirectionalLights(specular_color=color, direction=direction)
output_light = lights.specular( output_light = lights.specular(
points=points[None, None, :], points=points[None, None, :],
@ -345,7 +341,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position[None, :], camera_position=camera_position[None, :],
shininess=torch.tensor(10), shininess=torch.tensor(10),
) )
self.assertTrue(torch.allclose(output_light, expected_output)) self.assertClose(output_light, expected_output)
# Change camera position to be behind the point. # Change camera position to be behind the point.
camera_position = torch.tensor( camera_position = torch.tensor(
@ -358,7 +354,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position[None, :], camera_position=camera_position[None, :],
shininess=torch.tensor(10), shininess=torch.tensor(10),
) )
self.assertTrue(torch.allclose(output_light, expected_output)) self.assertClose(output_light, expected_output)
def test_specular_point_lights(self): def test_specular_point_lights(self):
""" """
@ -386,7 +382,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position[None, :], camera_position=camera_position[None, :],
shininess=torch.tensor(10), shininess=torch.tensor(10),
) )
self.assertTrue(torch.allclose(output_light, expected_output)) self.assertClose(output_light, expected_output)
# Change camera position to be behind the point # Change camera position to be behind the point
camera_position = torch.tensor( camera_position = torch.tensor(
@ -399,7 +395,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position[None, :], camera_position=camera_position[None, :],
shininess=torch.tensor(10), shininess=torch.tensor(10),
) )
self.assertTrue(torch.allclose(output_light, expected_output)) self.assertClose(output_light, expected_output)
# Change camera direction to be 30 degrees from the reflection direction # Change camera direction to be 30 degrees from the reflection direction
camera_position = torch.tensor( camera_position = torch.tensor(
@ -418,7 +414,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position[None, :], camera_position=camera_position[None, :],
shininess=torch.tensor(10), shininess=torch.tensor(10),
) )
self.assertTrue(torch.allclose(output_light, expected_output ** 10)) self.assertClose(output_light, expected_output ** 10)
def test_specular_batched(self): def test_specular_batched(self):
batch_size = 10 batch_size = 10
@ -448,7 +444,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position, camera_position=camera_position,
shininess=torch.tensor(10), shininess=torch.tensor(10),
) )
self.assertTrue(torch.allclose(output_light, expected_out)) self.assertClose(output_light, expected_out)
def test_specular_batched_broadcast_inputs(self): def test_specular_batched_broadcast_inputs(self):
batch_size = 10 batch_size = 10
@ -481,7 +477,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position, camera_position=camera_position,
shininess=torch.tensor(10), shininess=torch.tensor(10),
) )
self.assertTrue(torch.allclose(output_light, expected_out)) self.assertClose(output_light, expected_out)
def test_specular_batched_arbitrary_input_dims(self): def test_specular_batched_arbitrary_input_dims(self):
""" """
@ -520,7 +516,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position, camera_position=camera_position,
shininess=10.0, shininess=10.0,
) )
self.assertTrue(torch.allclose(output_light, expected_output)) self.assertClose(output_light, expected_output)
def test_specular_batched_packed(self): def test_specular_batched_packed(self):
""" """
@ -557,4 +553,4 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position[mesh_to_vert_idx, :], camera_position=camera_position[mesh_to_vert_idx, :],
shininess=10.0, shininess=10.0,
) )
self.assertTrue(torch.allclose(output_light, expected_output)) self.assertClose(output_light, expected_output)

View File

@ -6,10 +6,11 @@ import torch
from pytorch3d.loss import mesh_edge_loss from pytorch3d.loss import mesh_edge_loss
from pytorch3d.structures import Meshes from pytorch3d.structures import Meshes
from common_testing import TestCaseMixin
from test_sample_points_from_meshes import TestSamplePoints from test_sample_points_from_meshes import TestSamplePoints
class TestMeshEdgeLoss(unittest.TestCase): class TestMeshEdgeLoss(TestCaseMixin, unittest.TestCase):
def test_empty_meshes(self): def test_empty_meshes(self):
device = torch.device("cuda:0") device = torch.device("cuda:0")
target_length = 0 target_length = 0
@ -26,11 +27,9 @@ class TestMeshEdgeLoss(unittest.TestCase):
mesh = Meshes(verts=verts_list, faces=faces_list) mesh = Meshes(verts=verts_list, faces=faces_list)
loss = mesh_edge_loss(mesh, target_length=target_length) loss = mesh_edge_loss(mesh, target_length=target_length)
self.assertTrue( self.assertClose(
torch.allclose(
loss, torch.tensor([0.0], dtype=torch.float32, device=device) loss, torch.tensor([0.0], dtype=torch.float32, device=device)
) )
)
self.assertTrue(loss.requires_grad) self.assertTrue(loss.requires_grad)
@staticmethod @staticmethod
@ -94,7 +93,7 @@ class TestMeshEdgeLoss(unittest.TestCase):
loss = mesh_edge_loss(meshes, target_length=target_length) loss = mesh_edge_loss(meshes, target_length=target_length)
predloss = TestMeshEdgeLoss.mesh_edge_loss_naive(meshes, target_length) predloss = TestMeshEdgeLoss.mesh_edge_loss_naive(meshes, target_length)
self.assertTrue(torch.allclose(loss, predloss)) self.assertClose(loss, predloss)
@staticmethod @staticmethod
def mesh_edge_loss( def mesh_edge_loss(

View File

@ -898,7 +898,7 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
face_areas = mesh.faces_areas_packed() face_areas = mesh.faces_areas_packed()
expected_areas = torch.tensor([0.125, 0.2]) expected_areas = torch.tensor([0.125, 0.2])
self.assertTrue(torch.allclose(face_areas, expected_areas)) self.assertClose(face_areas, expected_areas)
def test_compute_normals(self): def test_compute_normals(self):
@ -959,9 +959,9 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
# Multiple meshes in the batch with equal sized meshes # Multiple meshes in the batch with equal sized meshes
meshes_extended = mesh.extend(3) meshes_extended = mesh.extend(3)
for m in meshes_extended.verts_normals_list(): for m in meshes_extended.verts_normals_list():
self.assertTrue(torch.allclose(m, verts_normals_expected)) self.assertClose(m, verts_normals_expected)
for f in meshes_extended.faces_normals_list(): for f in meshes_extended.faces_normals_list():
self.assertTrue(torch.allclose(f, faces_normals_expected)) self.assertClose(f, faces_normals_expected)
# Multiple meshes in the batch with different sized meshes # Multiple meshes in the batch with different sized meshes
# Check padded and packed normals are the correct sizes. # Check padded and packed normals are the correct sizes.

View File

@ -133,14 +133,10 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
self.assertTrue(torch.all(verts == expected_verts)) self.assertTrue(torch.all(verts == expected_verts))
self.assertTrue(torch.all(faces.verts_idx == expected_faces)) self.assertTrue(torch.all(faces.verts_idx == expected_faces))
self.assertTrue(torch.allclose(normals, expected_normals)) self.assertClose(normals, expected_normals)
self.assertTrue(torch.allclose(textures, expected_textures)) self.assertClose(textures, expected_textures)
self.assertTrue( self.assertClose(faces.normals_idx, expected_faces_normals_idx)
torch.allclose(faces.normals_idx, expected_faces_normals_idx) self.assertClose(faces.textures_idx, expected_faces_textures_idx)
)
self.assertTrue(
torch.allclose(faces.textures_idx, expected_faces_textures_idx)
)
self.assertTrue(materials is None) self.assertTrue(materials is None)
self.assertTrue(tex_maps is None) self.assertTrue(tex_maps is None)
@ -181,11 +177,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
textures = aux.verts_uvs textures = aux.verts_uvs
materials = aux.material_colors materials = aux.material_colors
tex_maps = aux.texture_images tex_maps = aux.texture_images
self.assertTrue( self.assertClose(faces.normals_idx, expected_faces_normals_idx)
torch.allclose(faces.normals_idx, expected_faces_normals_idx) self.assertClose(normals, expected_normals)
) self.assertClose(verts, expected_verts)
self.assertTrue(torch.allclose(normals, expected_normals))
self.assertTrue(torch.allclose(verts, expected_verts))
self.assertTrue(faces.textures_idx == []) self.assertTrue(faces.textures_idx == [])
self.assertTrue(textures is None) self.assertTrue(textures is None)
self.assertTrue(materials is None) self.assertTrue(materials is None)
@ -225,11 +219,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
materials = aux.material_colors materials = aux.material_colors
tex_maps = aux.texture_images tex_maps = aux.texture_images
self.assertTrue( self.assertClose(faces.textures_idx, expected_faces_textures_idx)
torch.allclose(faces.textures_idx, expected_faces_textures_idx) self.assertClose(expected_textures, textures)
) self.assertClose(expected_verts, verts)
self.assertTrue(torch.allclose(expected_textures, textures))
self.assertTrue(torch.allclose(expected_verts, verts))
self.assertTrue(faces.normals_idx == []) self.assertTrue(faces.normals_idx == [])
self.assertTrue(normals is None) self.assertTrue(normals is None)
self.assertTrue(materials is None) self.assertTrue(materials is None)

View File

@ -12,8 +12,10 @@ from pytorch3d.renderer.mesh.rasterize_meshes import (
from pytorch3d.structures import Meshes from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere from pytorch3d.utils import ico_sphere
from common_testing import TestCaseMixin
class TestRasterizeMeshes(unittest.TestCase):
class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
def test_simple_python(self): def test_simple_python(self):
device = torch.device("cpu") device = torch.device("cpu")
self._simple_triangle_raster( self._simple_triangle_raster(
@ -266,14 +268,14 @@ class TestRasterizeMeshes(unittest.TestCase):
# Make sure everything was the same # Make sure everything was the same
self.assertTrue((idx1 == idx2).all().item()) self.assertTrue((idx1 == idx2).all().item())
self.assertTrue((idx1 == idx3).all().item()) self.assertTrue((idx1 == idx3).all().item())
self.assertTrue(torch.allclose(zbuf1, zbuf2, atol=1e-6)) self.assertClose(zbuf1, zbuf2, atol=1e-6)
self.assertTrue(torch.allclose(zbuf1, zbuf3, atol=1e-6)) self.assertClose(zbuf1, zbuf3, atol=1e-6)
self.assertTrue(torch.allclose(dist1, dist2, atol=1e-6)) self.assertClose(dist1, dist2, atol=1e-6)
self.assertTrue(torch.allclose(dist1, dist3, atol=1e-6)) self.assertClose(dist1, dist3, atol=1e-6)
self.assertTrue(torch.allclose(grad1, grad2, rtol=5e-3)) # flaky test self.assertClose(grad1, grad2, rtol=5e-3) # flaky test
self.assertTrue(torch.allclose(grad1, grad3, rtol=5e-3)) self.assertClose(grad1, grad3, rtol=5e-3)
self.assertTrue(torch.allclose(grad2, grad3, rtol=5e-3)) self.assertClose(grad2, grad3, rtol=5e-3)
def test_compare_coarse_cpu_vs_cuda(self): def test_compare_coarse_cpu_vs_cuda(self):
torch.manual_seed(231) torch.manual_seed(231)
@ -399,9 +401,9 @@ class TestRasterizeMeshes(unittest.TestCase):
idx1, zbuf1, bary1, dist1 = fn1(*args1) idx1, zbuf1, bary1, dist1 = fn1(*args1)
idx2, zbuf2, bary2, dist2 = fn2(*args2) idx2, zbuf2, bary2, dist2 = fn2(*args2)
self.assertTrue((idx1.cpu() == idx2.cpu()).all().item()) self.assertTrue((idx1.cpu() == idx2.cpu()).all().item())
self.assertTrue(torch.allclose(zbuf1.cpu(), zbuf2.cpu(), rtol=1e-4)) self.assertClose(zbuf1.cpu(), zbuf2.cpu(), rtol=1e-4)
self.assertTrue(torch.allclose(dist1.cpu(), dist2.cpu(), rtol=6e-3)) self.assertClose(dist1.cpu(), dist2.cpu(), rtol=6e-3)
self.assertTrue(torch.allclose(bary1.cpu(), bary2.cpu(), rtol=1e-3)) self.assertClose(bary1.cpu(), bary2.cpu(), rtol=1e-3)
if not compare_grads: if not compare_grads:
return return
@ -429,7 +431,7 @@ class TestRasterizeMeshes(unittest.TestCase):
grad_var1.grad.data.zero_() grad_var1.grad.data.zero_()
loss2.backward() loss2.backward()
grad_verts2 = grad_var2.grad.data.clone().cpu() grad_verts2 = grad_var2.grad.data.clone().cpu()
self.assertTrue(torch.allclose(grad_verts1, grad_verts2, rtol=1e-3)) self.assertClose(grad_verts1, grad_verts2, rtol=1e-3)
def _test_perspective_correct( def _test_perspective_correct(
self, rasterize_meshes_fn, device, bin_size=None self, rasterize_meshes_fn, device, bin_size=None
@ -615,8 +617,8 @@ class TestRasterizeMeshes(unittest.TestCase):
zbuf_same = (zbuf == zbuf_expected).all().item() zbuf_same = (zbuf == zbuf_expected).all().item()
self.assertTrue(idx_same) self.assertTrue(idx_same)
self.assertTrue(zbuf_same) self.assertTrue(zbuf_same)
self.assertTrue(torch.allclose(bary, bary_expected)) self.assertClose(bary, bary_expected)
self.assertTrue(torch.allclose(dists, dists_expected)) self.assertClose(dists, dists_expected)
def _simple_triangle_raster(self, raster_fn, device, bin_size=None): def _simple_triangle_raster(self, raster_fn, device, bin_size=None):
image_size = 10 image_size = 10
@ -769,10 +771,10 @@ class TestRasterizeMeshes(unittest.TestCase):
meshes, image_size, 0.0, 2, bin_size meshes, image_size, 0.0, 2, bin_size
) )
self.assertTrue(torch.allclose(p2face[..., 0], expected_p2face_k0)) self.assertClose(p2face[..., 0], expected_p2face_k0)
self.assertTrue(torch.allclose(zbuf[..., 0], expected_zbuf_k0)) self.assertClose(zbuf[..., 0], expected_zbuf_k0)
self.assertTrue(torch.allclose(p2face[..., 1], expected_p2face_k1)) self.assertClose(p2face[..., 1], expected_p2face_k1)
self.assertTrue(torch.allclose(zbuf[..., 1], expected_zbuf_k1)) self.assertClose(zbuf[..., 1], expected_zbuf_k1)
def _simple_blurry_raster(self, raster_fn, device, bin_size=None): def _simple_blurry_raster(self, raster_fn, device, bin_size=None):
""" """
@ -861,12 +863,9 @@ class TestRasterizeMeshes(unittest.TestCase):
p2f[expected_p2f == 0] = order.index(0) p2f[expected_p2f == 0] = order.index(0)
p2f[expected_p2f == 1] = order.index(1) p2f[expected_p2f == 1] = order.index(1)
p2f[expected_p2f == 2] = order.index(2) p2f[expected_p2f == 2] = order.index(2)
self.assertClose(pix_to_face.squeeze(), p2f)
self.assertTrue(torch.allclose(pix_to_face.squeeze(), p2f)) self.assertClose(zbuf.squeeze(), expected_zbuf, rtol=1e-5)
self.assertTrue( self.assertClose(dists, expected_dists)
torch.allclose(zbuf.squeeze(), expected_zbuf, rtol=1e-5)
)
self.assertTrue(torch.allclose(dists, expected_dists))
def _test_coarse_rasterize(self, device): def _test_coarse_rasterize(self, device):
image_size = 16 image_size = 16

View File

@ -224,9 +224,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
self.assertEqual((zbuf1.cpu() == zbuf2.cpu()).all().item(), 1) self.assertEqual((zbuf1.cpu() == zbuf2.cpu()).all().item(), 1)
self.assertClose(dist1.cpu(), dist2.cpu()) self.assertClose(dist1.cpu(), dist2.cpu())
if compare_grads: if compare_grads:
self.assertTrue( self.assertClose(grad_points1, grad_points2, atol=2e-6)
torch.allclose(grad_points1, grad_points2, atol=2e-6)
)
def _test_behind_camera(self, rasterize_points_fn, device, bin_size=None): def _test_behind_camera(self, rasterize_points_fn, device, bin_size=None):
# Test case where all points are behind the camera -- nothing should # Test case where all points are behind the camera -- nothing should
@ -261,7 +259,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
self.assertTrue(idx_same) self.assertTrue(idx_same)
self.assertTrue(zbuf_same) self.assertTrue(zbuf_same)
self.assertTrue(torch.allclose(dists, dists_expected)) self.assertClose(dists, dists_expected)
def _simple_test_case(self, rasterize_points_fn, device, bin_size=0): def _simple_test_case(self, rasterize_points_fn, device, bin_size=0):
# Create two pointclouds with different numbers of points. # Create two pointclouds with different numbers of points.
@ -334,18 +332,18 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
], device=device) ], device=device)
# fmt: on # fmt: on
dists1_expected = torch.full( dists1_expected = torch.zeros(
(1, 5, 5, 2), fill_value=0.0, dtype=torch.float32, device=device (5, 5, 2), dtype=torch.float32, device=device
) )
# fmt: off # fmt: off
dists1_expected[0, :, :, 0] = torch.tensor([ dists1_expected[:, :, 0] = torch.tensor([
[-1.00, -1.00, 0.16, -1.00, -1.00], # noqa: E241 [-1.00, -1.00, 0.16, -1.00, -1.00], # noqa: E241
[-1.00, 0.16, 0.16, 0.16, -1.00], # noqa: E241 [-1.00, 0.16, 0.16, 0.16, -1.00], # noqa: E241
[ 0.16, 0.16, 0.00, 0.16, -1.00], # noqa: E241 E201 [ 0.16, 0.16, 0.00, 0.16, -1.00], # noqa: E241 E201
[-1.00, 0.16, 0.16, -1.00, -1.00], # noqa: E241 [-1.00, 0.16, 0.16, -1.00, -1.00], # noqa: E241
[-1.00, -1.00, -1.00, -1.00, -1.00], # noqa: E241 [-1.00, -1.00, -1.00, -1.00, -1.00], # noqa: E241
], device=device) ], device=device)
dists1_expected[0, :, :, 1] = torch.tensor([ dists1_expected[:, :, 1] = torch.tensor([
[-1.00, -1.00, -1.00, -1.00, -1.00], # noqa: E241 [-1.00, -1.00, -1.00, -1.00, -1.00], # noqa: E241
[-1.00, 0.16, 0.00, -1.00, -1.00], # noqa: E241 [-1.00, 0.16, 0.00, -1.00, -1.00], # noqa: E241
[-1.00, 0.00, 0.16, -1.00, -1.00], # noqa: E241 [-1.00, 0.00, 0.16, -1.00, -1.00], # noqa: E241
@ -370,10 +368,9 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
print(idx[0, :, :, 0]) print(idx[0, :, :, 0])
print(idx[0, :, :, 1]) print(idx[0, :, :, 1])
zbuf_same = (zbuf[0, ...] == zbuf1_expected).all().item() == 1 zbuf_same = (zbuf[0, ...] == zbuf1_expected).all().item() == 1
dist_same = torch.allclose(dists[0, ...], dists1_expected) self.assertClose(dists[0, ...], dists1_expected)
self.assertTrue(idx_same) self.assertTrue(idx_same)
self.assertTrue(zbuf_same) self.assertTrue(zbuf_same)
self.assertTrue(dist_same)
# Check second point cloud - the indices in idx refer to points in the # Check second point cloud - the indices in idx refer to points in the
# pointclouds.points_packed() tensor. In the second point cloud, # pointclouds.points_packed() tensor. In the second point cloud,
@ -387,7 +384,7 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
zbuf_same = (zbuf[1, ...] == zbuf1_expected).all().item() == 1 zbuf_same = (zbuf[1, ...] == zbuf1_expected).all().item() == 1
self.assertTrue(idx_same) self.assertTrue(idx_same)
self.assertTrue(zbuf_same) self.assertTrue(zbuf_same)
self.assertTrue(torch.allclose(dists[1, ...], dists1_expected)) self.assertClose(dists[1, ...], dists1_expected)
def test_coarse_cpu(self): def test_coarse_cpu(self):
return self._test_coarse_rasterize(torch.device("cpu")) return self._test_coarse_rasterize(torch.device("cpu"))

View File

@ -9,8 +9,10 @@ from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere from pytorch3d.utils.ico_sphere import ico_sphere
from common_testing import TestCaseMixin
class TestSamplePoints(unittest.TestCase):
class TestSamplePoints(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None: def setUp(self) -> None:
super().setUp() super().setUp()
torch.manual_seed(1) torch.manual_seed(1)
@ -98,33 +100,25 @@ class TestSamplePoints(unittest.TestCase):
self.assertEqual(normals.shape, (3, num_samples, 3)) self.assertEqual(normals.shape, (3, num_samples, 3))
# Empty meshes: should have all zeros for samples and normals. # Empty meshes: should have all zeros for samples and normals.
self.assertTrue( self.assertClose(samples[0, :], torch.zeros((num_samples, 3)))
torch.allclose(samples[0, :], torch.zeros((1, num_samples, 3))) self.assertClose(normals[0, :], torch.zeros((num_samples, 3)))
)
self.assertTrue(
torch.allclose(normals[0, :], torch.zeros((1, num_samples, 3)))
)
# Sphere: points should have radius 1. # Sphere: points should have radius 1.
x, y, z = samples[1, :].unbind(1) x, y, z = samples[1, :].unbind(1)
radius = torch.sqrt(x ** 2 + y ** 2 + z ** 2) radius = torch.sqrt(x ** 2 + y ** 2 + z ** 2)
self.assertTrue(torch.allclose(radius, torch.ones((num_samples)))) self.assertClose(radius, torch.ones((num_samples)))
# Pyramid: points shoudl lie on one of the faces. # Pyramid: points shoudl lie on one of the faces.
pyramid_verts = samples[2, :] pyramid_verts = samples[2, :]
pyramid_normals = normals[2, :] pyramid_normals = normals[2, :]
self.assertTrue( self.assertClose(
torch.allclose(
pyramid_verts.lt(1).float(), torch.ones_like(pyramid_verts) pyramid_verts.lt(1).float(), torch.ones_like(pyramid_verts)
) )
) self.assertClose(
self.assertTrue(
torch.allclose(
(pyramid_verts >= 0).float(), torch.ones_like(pyramid_verts) (pyramid_verts >= 0).float(), torch.ones_like(pyramid_verts)
) )
)
# Face 1: z = 0, x + y <= 1, normals = (0, 0, 1). # Face 1: z = 0, x + y <= 1, normals = (0, 0, 1).
face_1_idxs = pyramid_verts[:, 2] == 0 face_1_idxs = pyramid_verts[:, 2] == 0
@ -135,14 +129,12 @@ class TestSamplePoints(unittest.TestCase):
self.assertTrue( self.assertTrue(
torch.all((face_1_verts[:, 0] + face_1_verts[:, 1]) <= 1) torch.all((face_1_verts[:, 0] + face_1_verts[:, 1]) <= 1)
) )
self.assertTrue( self.assertClose(
torch.allclose(
face_1_normals, face_1_normals,
torch.tensor([0, 0, 1], dtype=torch.float32).expand( torch.tensor([0, 0, 1], dtype=torch.float32).expand(
face_1_normals.size() face_1_normals.size()
), ),
) )
)
# Face 2: x = 0, z + y <= 1, normals = (1, 0, 0). # Face 2: x = 0, z + y <= 1, normals = (1, 0, 0).
face_2_idxs = pyramid_verts[:, 0] == 0 face_2_idxs = pyramid_verts[:, 0] == 0
@ -153,14 +145,12 @@ class TestSamplePoints(unittest.TestCase):
self.assertTrue( self.assertTrue(
torch.all((face_2_verts[:, 1] + face_2_verts[:, 2]) <= 1) torch.all((face_2_verts[:, 1] + face_2_verts[:, 2]) <= 1)
) )
self.assertTrue( self.assertClose(
torch.allclose(
face_2_normals, face_2_normals,
torch.tensor([1, 0, 0], dtype=torch.float32).expand( torch.tensor([1, 0, 0], dtype=torch.float32).expand(
face_2_normals.size() face_2_normals.size()
), ),
) )
)
# Face 3: y = 0, x + z <= 1, normals = (0, -1, 0). # Face 3: y = 0, x + z <= 1, normals = (0, -1, 0).
face_3_idxs = pyramid_verts[:, 1] == 0 face_3_idxs = pyramid_verts[:, 1] == 0
@ -171,14 +161,12 @@ class TestSamplePoints(unittest.TestCase):
self.assertTrue( self.assertTrue(
torch.all((face_3_verts[:, 0] + face_3_verts[:, 2]) <= 1) torch.all((face_3_verts[:, 0] + face_3_verts[:, 2]) <= 1)
) )
self.assertTrue( self.assertClose(
torch.allclose(
face_3_normals, face_3_normals,
torch.tensor([0, -1, 0], dtype=torch.float32).expand( torch.tensor([0, -1, 0], dtype=torch.float32).expand(
face_3_normals.size() face_3_normals.size()
), ),
) )
)
# Face 4: x + y + z = 1, normals = (1, 1, 1)/sqrt(3). # Face 4: x + y + z = 1, normals = (1, 1, 1)/sqrt(3).
face_4_idxs = pyramid_verts.gt(0).all(1) face_4_idxs = pyramid_verts.gt(0).all(1)
@ -186,22 +174,16 @@ class TestSamplePoints(unittest.TestCase):
pyramid_verts[face_4_idxs, :], pyramid_verts[face_4_idxs, :],
pyramid_normals[face_4_idxs, :], pyramid_normals[face_4_idxs, :],
) )
self.assertTrue( self.assertClose(face_4_verts.sum(1), torch.ones(face_4_verts.size(0)))
torch.allclose( self.assertClose(
face_4_verts.sum(1), torch.ones(face_4_verts.size(0))
)
)
self.assertTrue(
torch.allclose(
face_4_normals, face_4_normals,
( (
torch.tensor([1, 1, 1], dtype=torch.float32) torch.tensor([1, 1, 1], dtype=torch.float32)
/ torch.sqrt(torch.tensor(3, dtype=torch.float32)) / torch.sqrt(torch.tensor(3, dtype=torch.float32))
).expand(face_4_normals.size()), ).expand(face_4_normals.size()),
) )
)
def test_mutinomial(self): def test_multinomial(self):
""" """
Confirm that torch.multinomial does not sample elements which have Confirm that torch.multinomial does not sample elements which have
zero probability. zero probability.

View File

@ -8,8 +8,10 @@ from pytorch3d.ops.subdivide_meshes import SubdivideMeshes
from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere from pytorch3d.utils.ico_sphere import ico_sphere
from common_testing import TestCaseMixin
class TestSubdivideMeshes(unittest.TestCase):
class TestSubdivideMeshes(TestCaseMixin, unittest.TestCase):
def test_simple_subdivide(self): def test_simple_subdivide(self):
# Create a mesh with one face and check the subdivided mesh has # Create a mesh with one face and check the subdivided mesh has
# 4 faces with the correct vertex coordinates. # 4 faces with the correct vertex coordinates.
@ -56,8 +58,8 @@ class TestSubdivideMeshes(unittest.TestCase):
device=device, device=device,
) )
new_verts, new_faces = new_mesh.get_mesh_verts_faces(0) new_verts, new_faces = new_mesh.get_mesh_verts_faces(0)
self.assertTrue(torch.allclose(new_verts, gt_subdivide_verts)) self.assertClose(new_verts, gt_subdivide_verts)
self.assertTrue(torch.allclose(new_faces, gt_subdivide_faces)) self.assertClose(new_faces, gt_subdivide_faces)
self.assertTrue(new_verts.requires_grad == verts.requires_grad) self.assertTrue(new_verts.requires_grad == verts.requires_grad)
def test_heterogeneous_meshes(self): def test_heterogeneous_meshes(self):
@ -185,12 +187,12 @@ class TestSubdivideMeshes(unittest.TestCase):
new_mesh_verts1, new_mesh_faces1 = new_mesh.get_mesh_verts_faces(0) new_mesh_verts1, new_mesh_faces1 = new_mesh.get_mesh_verts_faces(0)
new_mesh_verts2, new_mesh_faces2 = new_mesh.get_mesh_verts_faces(1) new_mesh_verts2, new_mesh_faces2 = new_mesh.get_mesh_verts_faces(1)
new_mesh_verts3, new_mesh_faces3 = new_mesh.get_mesh_verts_faces(2) new_mesh_verts3, new_mesh_faces3 = new_mesh.get_mesh_verts_faces(2)
self.assertTrue(torch.allclose(new_mesh_verts1, gt_subdivided_verts1)) self.assertClose(new_mesh_verts1, gt_subdivided_verts1)
self.assertTrue(torch.allclose(new_mesh_faces1, gt_subdivided_faces1)) self.assertClose(new_mesh_faces1, gt_subdivided_faces1)
self.assertTrue(torch.allclose(new_mesh_verts2, gt_subdivided_verts2)) self.assertClose(new_mesh_verts2, gt_subdivided_verts2)
self.assertTrue(torch.allclose(new_mesh_faces2, gt_subdivided_faces2)) self.assertClose(new_mesh_faces2, gt_subdivided_faces2)
self.assertTrue(torch.allclose(new_mesh_verts3, gt_subdivided_verts3)) self.assertClose(new_mesh_verts3, gt_subdivided_verts3)
self.assertTrue(torch.allclose(new_mesh_faces3, gt_subdivided_faces3)) self.assertClose(new_mesh_faces3, gt_subdivided_faces3)
self.assertTrue(new_mesh_verts1.requires_grad == verts1.requires_grad) self.assertTrue(new_mesh_verts1.requires_grad == verts1.requires_grad)
self.assertTrue(new_mesh_verts2.requires_grad == verts2.requires_grad) self.assertTrue(new_mesh_verts2.requires_grad == verts2.requires_grad)
self.assertTrue(new_mesh_verts3.requires_grad == verts2.requires_grad) self.assertTrue(new_mesh_verts3.requires_grad == verts2.requires_grad)
@ -212,7 +214,7 @@ class TestSubdivideMeshes(unittest.TestCase):
gt_feats = torch.cat( gt_feats = torch.cat(
(feats.view(N, V, D), app_feats.view(N, -1, D)), dim=1 (feats.view(N, V, D), app_feats.view(N, -1, D)), dim=1
).view(-1, D) ).view(-1, D)
self.assertTrue(torch.allclose(new_feats, gt_feats)) self.assertClose(new_feats, gt_feats)
self.assertTrue(new_feats.requires_grad == gt_feats.requires_grad) self.assertTrue(new_feats.requires_grad == gt_feats.requires_grad)
@staticmethod @staticmethod

View File

@ -8,8 +8,10 @@ import torch.nn.functional as F
from pytorch3d.ops.vert_align import vert_align from pytorch3d.ops.vert_align import vert_align
from pytorch3d.structures.meshes import Meshes from pytorch3d.structures.meshes import Meshes
from common_testing import TestCaseMixin
class TestVertAlign(unittest.TestCase):
class TestVertAlign(TestCaseMixin, unittest.TestCase):
@staticmethod @staticmethod
def vert_align_naive( def vert_align_naive(
feats, feats,
@ -103,14 +105,14 @@ class TestVertAlign(unittest.TestCase):
naive_out = TestVertAlign.vert_align_naive( naive_out = TestVertAlign.vert_align_naive(
feats, meshes, return_packed=True feats, meshes, return_packed=True
) )
self.assertTrue(torch.allclose(out, naive_out)) self.assertClose(out, naive_out)
# feats as tensor # feats as tensor
out = vert_align(feats[0], meshes, return_packed=True) out = vert_align(feats[0], meshes, return_packed=True)
naive_out = TestVertAlign.vert_align_naive( naive_out = TestVertAlign.vert_align_naive(
feats[0], meshes, return_packed=True feats[0], meshes, return_packed=True
) )
self.assertTrue(torch.allclose(out, naive_out)) self.assertClose(out, naive_out)
def test_vert_align_with_verts(self): def test_vert_align_with_verts(self):
""" """
@ -130,14 +132,14 @@ class TestVertAlign(unittest.TestCase):
naive_out = TestVertAlign.vert_align_naive( naive_out = TestVertAlign.vert_align_naive(
feats, verts, return_packed=True feats, verts, return_packed=True
) )
self.assertTrue(torch.allclose(out, naive_out)) self.assertClose(out, naive_out)
# feats as tensor # feats as tensor
out = vert_align(feats[0], verts, return_packed=True) out = vert_align(feats[0], verts, return_packed=True)
naive_out = TestVertAlign.vert_align_naive( naive_out = TestVertAlign.vert_align_naive(
feats[0], verts, return_packed=True feats[0], verts, return_packed=True
) )
self.assertTrue(torch.allclose(out, naive_out)) self.assertClose(out, naive_out)
out2 = vert_align( out2 = vert_align(
feats[0], verts, return_packed=True, align_corners=False feats[0], verts, return_packed=True, align_corners=False