test fixes and lints

Summary:
- followup recent pyre change D63415925
- make tests remove temporary files
- weights_only=True in torch.load
- lint fixes

3 test fixes from VRehnberg in https://github.com/facebookresearch/pytorch3d/issues/1914
- imageio channels fix
- frozen decorator in test_config
- load_blobs positional

Reviewed By: MichaelRamamonjisoa

Differential Revision: D66162167

fbshipit-source-id: 7737e174691b62f1708443a4fae07343cec5bfeb
This commit is contained in:
Jeremy Reizenstein
2024-11-20 09:15:51 -08:00
committed by Facebook GitHub Bot
parent c17e6f947a
commit e20cbe9b0e
21 changed files with 48 additions and 45 deletions

View File

@@ -76,7 +76,7 @@ class TestForward(unittest.TestCase):
"test_out",
"test_forward_TestForward_test_bg_weight_hits.png",
),
(hits * 255.0).cpu().to(torch.uint8).numpy(),
(hits * 255.0).cpu().to(torch.uint8).squeeze(2).numpy(),
)
self.assertEqual(hits[500, 500, 0].item(), 1.0)
self.assertTrue(
@@ -139,7 +139,7 @@ class TestForward(unittest.TestCase):
"test_out",
"test_forward_TestForward_test_basic_3chan_hits.png",
),
(hits * 255.0).cpu().to(torch.uint8).numpy(),
(hits * 255.0).cpu().to(torch.uint8).squeeze(2).numpy(),
)
self.assertEqual(hits[500, 500, 0].item(), 1.0)
self.assertTrue(
@@ -194,7 +194,7 @@ class TestForward(unittest.TestCase):
"test_out",
"test_forward_TestForward_test_basic_1chan.png",
),
(result * 255.0).cpu().to(torch.uint8).numpy(),
(result * 255.0).cpu().to(torch.uint8).squeeze(2).numpy(),
)
imageio.imsave(
path.join(
@@ -202,7 +202,7 @@ class TestForward(unittest.TestCase):
"test_out",
"test_forward_TestForward_test_basic_1chan_hits.png",
),
(hits * 255.0).cpu().to(torch.uint8).numpy(),
(hits * 255.0).cpu().to(torch.uint8).squeeze(2).numpy(),
)
self.assertEqual(hits[500, 500, 0].item(), 1.0)
self.assertTrue(
@@ -264,7 +264,7 @@ class TestForward(unittest.TestCase):
"test_out",
"test_forward_TestForward_test_basic_8chan_hits.png",
),
(hits * 255.0).cpu().to(torch.uint8).numpy(),
(hits * 255.0).cpu().to(torch.uint8).squeeze(2).numpy(),
)
self.assertEqual(hits[500, 500, 0].item(), 1.0)
self.assertTrue(