mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
test fixes and lints
Summary: - followup recent pyre change D63415925 - make tests remove temporary files - weights_only=True in torch.load - lint fixes 3 test fixes from VRehnberg in https://github.com/facebookresearch/pytorch3d/issues/1914 - imageio channels fix - frozen decorator in test_config - load_blobs positional Reviewed By: MichaelRamamonjisoa Differential Revision: D66162167 fbshipit-source-id: 7737e174691b62f1708443a4fae07343cec5bfeb
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c17e6f947a
commit
e20cbe9b0e
@@ -76,7 +76,7 @@ class TestForward(unittest.TestCase):
|
||||
"test_out",
|
||||
"test_forward_TestForward_test_bg_weight_hits.png",
|
||||
),
|
||||
(hits * 255.0).cpu().to(torch.uint8).numpy(),
|
||||
(hits * 255.0).cpu().to(torch.uint8).squeeze(2).numpy(),
|
||||
)
|
||||
self.assertEqual(hits[500, 500, 0].item(), 1.0)
|
||||
self.assertTrue(
|
||||
@@ -139,7 +139,7 @@ class TestForward(unittest.TestCase):
|
||||
"test_out",
|
||||
"test_forward_TestForward_test_basic_3chan_hits.png",
|
||||
),
|
||||
(hits * 255.0).cpu().to(torch.uint8).numpy(),
|
||||
(hits * 255.0).cpu().to(torch.uint8).squeeze(2).numpy(),
|
||||
)
|
||||
self.assertEqual(hits[500, 500, 0].item(), 1.0)
|
||||
self.assertTrue(
|
||||
@@ -194,7 +194,7 @@ class TestForward(unittest.TestCase):
|
||||
"test_out",
|
||||
"test_forward_TestForward_test_basic_1chan.png",
|
||||
),
|
||||
(result * 255.0).cpu().to(torch.uint8).numpy(),
|
||||
(result * 255.0).cpu().to(torch.uint8).squeeze(2).numpy(),
|
||||
)
|
||||
imageio.imsave(
|
||||
path.join(
|
||||
@@ -202,7 +202,7 @@ class TestForward(unittest.TestCase):
|
||||
"test_out",
|
||||
"test_forward_TestForward_test_basic_1chan_hits.png",
|
||||
),
|
||||
(hits * 255.0).cpu().to(torch.uint8).numpy(),
|
||||
(hits * 255.0).cpu().to(torch.uint8).squeeze(2).numpy(),
|
||||
)
|
||||
self.assertEqual(hits[500, 500, 0].item(), 1.0)
|
||||
self.assertTrue(
|
||||
@@ -264,7 +264,7 @@ class TestForward(unittest.TestCase):
|
||||
"test_out",
|
||||
"test_forward_TestForward_test_basic_8chan_hits.png",
|
||||
),
|
||||
(hits * 255.0).cpu().to(torch.uint8).numpy(),
|
||||
(hits * 255.0).cpu().to(torch.uint8).squeeze(2).numpy(),
|
||||
)
|
||||
self.assertEqual(hits[500, 500, 0].item(), 1.0)
|
||||
self.assertTrue(
|
||||
|
||||
Reference in New Issue
Block a user