Avoid to keep in memory lengths and bins for ImplicitronRayBundle

Summary:
Convert ImplicitronRayBundle to a "classic" class instead of a dataclass. This change is introduced as a way to preserve the ImplicitronRayBundle interface while allowing two outcomes:
- init lengths arguments is now a Optional[torch.Tensor] instead of torch.Tensor
- lengths is now a property which returns a `torch.Tensor`. The lengths property will either recompute lengths from bins or return the stored _lengths. `_lenghts` is None if bins is set. It saves us a bit of memory.

Reviewed By: shapovalov

Differential Revision: D46686094

fbshipit-source-id: 3c75c0947216476ebff542b6f552d311024a679b
This commit is contained in:
Emilien Garreau
2023-07-06 02:41:15 -07:00
committed by Facebook GitHub Bot
parent 3d011a9198
commit 9446d91fae
5 changed files with 103 additions and 61 deletions

View File

@@ -25,23 +25,62 @@ from tests.common_testing import TestCaseMixin
class TestRendererBase(TestCaseMixin, unittest.TestCase):
def test_implicitron_from_bins(self) -> None:
bins = torch.randn(2, 3, 4, 5)
ray_bundle = ImplicitronRayBundle.from_bins(
ray_bundle = ImplicitronRayBundle(
origins=None,
directions=None,
lengths=None,
xys=None,
bins=bins,
)
self.assertClose(ray_bundle.lengths, 0.5 * (bins[..., 1:] + bins[..., :-1]))
self.assertClose(ray_bundle.bins, bins)
def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None:
with self.assertRaises(ValueError):
ImplicitronRayBundle.from_bins(
def test_implicitron_raise_value_error_bins_is_set_and_try_to_set_lengths(
self,
) -> None:
with self.assertRaises(ValueError) as context:
ray_bundle = ImplicitronRayBundle(
origins=torch.rand(2, 3, 4, 3),
directions=torch.rand(2, 3, 4, 3),
lengths=None,
xys=torch.rand(2, 3, 4, 2),
bins=torch.rand(2, 3, 4, 1),
)
ray_bundle.lengths = torch.empty(2)
self.assertEqual(
str(context.exception),
"If the bins attribute is not None you cannot set the lengths attribute.",
)
def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None:
with self.assertRaises(ValueError) as context:
ImplicitronRayBundle(
origins=torch.rand(2, 3, 4, 3),
directions=torch.rand(2, 3, 4, 3),
lengths=None,
xys=torch.rand(2, 3, 4, 2),
bins=torch.rand(2, 3, 4, 1),
)
self.assertEqual(
str(context.exception),
"The last dim of bins must be at least superior or equal to 2.",
)
def test_implicitron_raise_value_error_if_neither_bins_or_lengths_provided(
self,
) -> None:
with self.assertRaises(ValueError) as context:
ImplicitronRayBundle(
origins=torch.rand(2, 3, 4, 3),
directions=torch.rand(2, 3, 4, 3),
lengths=None,
xys=torch.rand(2, 3, 4, 2),
bins=None,
)
self.assertEqual(
str(context.exception),
"Please set either bins or lengths to initialize an ImplicitronRayBundle.",
)
def test_conical_frustum_to_gaussian(self) -> None:
origins = torch.zeros(3, 3, 3)