use assertClose

Summary: use assertClose in some tests, which enforces shape equality. Fixes some small problems, including graph_conv on an empty graph.

Reviewed By: nikhilaravi

Differential Revision: D20556912

fbshipit-source-id: 60a61eafe3c03ce0f6c9c1a842685708fb10ac5b
This commit is contained in:
Jeremy Reizenstein
2020-03-23 11:33:10 -07:00
committed by Facebook GitHub Bot
parent 744ef0c2c8
commit 595aca27ea
13 changed files with 216 additions and 241 deletions

View File

@@ -64,14 +64,10 @@ class TestLights(TestCaseMixin, unittest.TestCase):
# Update element
color = (0.5, 0.5, 0.5)
light[1].ambient_color = color
self.assertTrue(
torch.allclose(light.ambient_color[1], torch.tensor(color))
)
self.assertClose(light.ambient_color[1], torch.tensor(color))
# Get item and get value
l0 = light[0]
self.assertTrue(
torch.allclose(l0.ambient_color, torch.tensor((0.0, 0.0, 0.0)))
)
self.assertClose(l0.ambient_color, torch.tensor((0.0, 0.0, 0.0)))
def test_initialize_lights_broadcast(self):
light = DirectionalLights(
@@ -127,7 +123,7 @@ class TestLights(TestCaseMixin, unittest.TestCase):
PointLights(location=torch.randn(10, 4))
class TestDiffuseLighting(unittest.TestCase):
class TestDiffuseLighting(TestCaseMixin, unittest.TestCase):
def test_diffuse_directional_lights(self):
"""
Test with a single point where:
@@ -145,17 +141,17 @@ class TestDiffuseLighting(unittest.TestCase):
[1 / np.sqrt(2), 1 / np.sqrt(2), 1 / np.sqrt(2)],
dtype=torch.float32,
)
expected_output = expected_output.view(-1, 1, 3)
expected_output = expected_output.view(1, 1, 3).repeat(3, 1, 1)
light = DirectionalLights(diffuse_color=color, direction=direction)
output_light = light.diffuse(normals=normals)
self.assertTrue(torch.allclose(output_light, expected_output))
self.assertClose(output_light, expected_output)
# Change light direction to be 90 degrees apart from normal direction.
direction = torch.tensor([0, 1, 0], dtype=torch.float32)
light.direction = direction
expected_output = torch.zeros_like(expected_output)
output_light = light.diffuse(normals=normals)
self.assertTrue(torch.allclose(output_light, expected_output))
self.assertClose(output_light, expected_output)
def test_diffuse_point_lights(self):
"""
@@ -183,7 +179,7 @@ class TestDiffuseLighting(unittest.TestCase):
output_light = light.diffuse(
points=points[None, None, :], normals=normals[None, None, :]
)
self.assertTrue(torch.allclose(output_light, expected_output))
self.assertClose(output_light, expected_output)
# Change light direction to be 90 degrees apart from normal direction.
location = torch.tensor([0, 1, 0], dtype=torch.float32)
@@ -194,7 +190,7 @@ class TestDiffuseLighting(unittest.TestCase):
output_light = light.diffuse(
points=points[None, None, :], normals=normals[None, None, :]
)
self.assertTrue(torch.allclose(output_light, expected_output))
self.assertClose(output_light, expected_output)
def test_diffuse_batched(self):
"""
@@ -220,7 +216,7 @@ class TestDiffuseLighting(unittest.TestCase):
lights = DirectionalLights(diffuse_color=color, direction=direction)
output_light = lights.diffuse(normals=normals)
self.assertTrue(torch.allclose(output_light, expected_out))
self.assertClose(output_light, expected_out)
def test_diffuse_batched_broadcast_inputs(self):
"""
@@ -250,7 +246,7 @@ class TestDiffuseLighting(unittest.TestCase):
lights = DirectionalLights(diffuse_color=color, direction=direction)
output_light = lights.diffuse(normals=normals)
self.assertTrue(torch.allclose(output_light, expected_out))
self.assertClose(output_light, expected_out)
def test_diffuse_batched_arbitrary_input_dims(self):
"""
@@ -280,7 +276,7 @@ class TestDiffuseLighting(unittest.TestCase):
lights = DirectionalLights(diffuse_color=color, direction=direction)
output_light = lights.diffuse(normals=normals)
self.assertTrue(torch.allclose(output_light, expected_output))
self.assertClose(output_light, expected_output)
def test_diffuse_batched_packed(self):
"""
@@ -311,10 +307,10 @@ class TestDiffuseLighting(unittest.TestCase):
direction=direction[mesh_to_vert_idx, :],
)
output_light = lights.diffuse(normals=normals[mesh_to_vert_idx, :])
self.assertTrue(torch.allclose(output_light, expected_output))
self.assertClose(output_light, expected_output)
class TestSpecularLighting(unittest.TestCase):
class TestSpecularLighting(TestCaseMixin, unittest.TestCase):
def test_specular_directional_lights(self):
"""
Specular highlights depend on the camera position as well as the light
@@ -337,7 +333,7 @@ class TestSpecularLighting(unittest.TestCase):
points = torch.tensor([0, 0, 0], dtype=torch.float32)
normals = torch.tensor([0, 1, 0], dtype=torch.float32)
expected_output = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float32)
expected_output = expected_output.view(-1, 1, 3)
expected_output = expected_output.view(1, 1, 3).repeat(3, 1, 1)
lights = DirectionalLights(specular_color=color, direction=direction)
output_light = lights.specular(
points=points[None, None, :],
@@ -345,7 +341,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position[None, :],
shininess=torch.tensor(10),
)
self.assertTrue(torch.allclose(output_light, expected_output))
self.assertClose(output_light, expected_output)
# Change camera position to be behind the point.
camera_position = torch.tensor(
@@ -358,7 +354,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position[None, :],
shininess=torch.tensor(10),
)
self.assertTrue(torch.allclose(output_light, expected_output))
self.assertClose(output_light, expected_output)
def test_specular_point_lights(self):
"""
@@ -386,7 +382,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position[None, :],
shininess=torch.tensor(10),
)
self.assertTrue(torch.allclose(output_light, expected_output))
self.assertClose(output_light, expected_output)
# Change camera position to be behind the point
camera_position = torch.tensor(
@@ -399,7 +395,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position[None, :],
shininess=torch.tensor(10),
)
self.assertTrue(torch.allclose(output_light, expected_output))
self.assertClose(output_light, expected_output)
# Change camera direction to be 30 degrees from the reflection direction
camera_position = torch.tensor(
@@ -418,7 +414,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position[None, :],
shininess=torch.tensor(10),
)
self.assertTrue(torch.allclose(output_light, expected_output ** 10))
self.assertClose(output_light, expected_output ** 10)
def test_specular_batched(self):
batch_size = 10
@@ -448,7 +444,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position,
shininess=torch.tensor(10),
)
self.assertTrue(torch.allclose(output_light, expected_out))
self.assertClose(output_light, expected_out)
def test_specular_batched_broadcast_inputs(self):
batch_size = 10
@@ -481,7 +477,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position,
shininess=torch.tensor(10),
)
self.assertTrue(torch.allclose(output_light, expected_out))
self.assertClose(output_light, expected_out)
def test_specular_batched_arbitrary_input_dims(self):
"""
@@ -520,7 +516,7 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position,
shininess=10.0,
)
self.assertTrue(torch.allclose(output_light, expected_output))
self.assertClose(output_light, expected_output)
def test_specular_batched_packed(self):
"""
@@ -557,4 +553,4 @@ class TestSpecularLighting(unittest.TestCase):
camera_position=camera_position[mesh_to_vert_idx, :],
shininess=10.0,
)
self.assertTrue(torch.allclose(output_light, expected_output))
self.assertClose(output_light, expected_output)