Avoid to keep in memory lengths and bins for ImplicitronRayBundle

Summary:
Convert ImplicitronRayBundle to a "classic" class instead of a dataclass. This change is introduced as a way to preserve the ImplicitronRayBundle interface while allowing two outcomes:
- init lengths arguments is now a Optional[torch.Tensor] instead of torch.Tensor
- lengths is now a property which returns a `torch.Tensor`. The lengths property will either recompute lengths from bins or return the stored _lengths. `_lenghts` is None if bins is set. It saves us a bit of memory.

Reviewed By: shapovalov

Differential Revision: D46686094

fbshipit-source-id: 3c75c0947216476ebff542b6f552d311024a679b
This commit is contained in:
Emilien Garreau 2023-07-06 02:41:15 -07:00 committed by Facebook GitHub Bot
parent 3d011a9198
commit 9446d91fae
5 changed files with 103 additions and 61 deletions

View File

@ -6,8 +6,6 @@
from __future__ import annotations from __future__ import annotations
import dataclasses
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
@ -29,7 +27,6 @@ class RenderSamplingMode(Enum):
FULL_GRID = "full_grid" FULL_GRID = "full_grid"
@dataclasses.dataclass
class ImplicitronRayBundle: class ImplicitronRayBundle:
""" """
Parametrizes points along projection rays by storing ray `origins`, Parametrizes points along projection rays by storing ray `origins`,
@ -69,53 +66,58 @@ class ImplicitronRayBundle:
lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`. lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`.
pixel_radii_2d: An optional tensor of shape `(..., 1)` pixel_radii_2d: An optional tensor of shape `(..., 1)`
base radii of the conical frustums. base radii of the conical frustums.
Raises:
ValueError: If either bins or lengths are not provided.
ValueError: If bins is provided and the last dim is inferior or equal to 1.
""" """
origins: torch.Tensor def __init__(
directions: torch.Tensor self,
lengths: torch.Tensor
xys: torch.Tensor
camera_ids: Optional[torch.LongTensor] = None
camera_counts: Optional[torch.LongTensor] = None
bins: Optional[torch.Tensor] = None
pixel_radii_2d: Optional[torch.Tensor] = None
@classmethod
def from_bins(
cls,
origins: torch.Tensor, origins: torch.Tensor,
directions: torch.Tensor, directions: torch.Tensor,
bins: torch.Tensor, lengths: Optional[torch.Tensor],
xys: torch.Tensor, xys: torch.Tensor,
**kwargs, camera_ids: Optional[torch.LongTensor] = None,
) -> "ImplicitronRayBundle": camera_counts: Optional[torch.LongTensor] = None,
""" bins: Optional[torch.Tensor] = None,
Creates a new instance from bins instead of lengths. pixel_radii_2d: Optional[torch.Tensor] = None,
):
Attributes: if bins is not None and bins.shape[-1] <= 1:
origins: A tensor of shape `(..., 3)` denoting the
origins of the sampling rays in world coords.
directions: A tensor of shape `(..., 3)` containing the direction
vectors of sampling rays in world coords. They don't have to be normalized;
they define unit vectors in the respective 1D coordinate systems; see
documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
bins: A tensor of shape `(..., num_points_per_ray + 1)`
containing the bins at which the rays are sampled. In this case
lengths is equal to the midpoints of bins `(..., num_points_per_ray)`.
xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
kwargs: Additional arguments passed to the constructor of ImplicitronRayBundle
Returns:
An instance of ImplicitronRayBundle.
"""
if bins.shape[-1] <= 1:
raise ValueError( raise ValueError(
"The last dim of bins must be at least superior or equal to 2." "The last dim of bins must be at least superior or equal to 2."
) )
# equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
lengths = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5)
return cls(origins, directions, lengths, xys, bins=bins, **kwargs) if bins is None and lengths is None:
raise ValueError(
"Please set either bins or lengths to initialize an ImplicitronRayBundle."
)
self.origins = origins
self.directions = directions
self._lengths = lengths if bins is None else None
self.xys = xys
self.bins = bins
self.pixel_radii_2d = pixel_radii_2d
self.camera_ids = camera_ids
self.camera_counts = camera_counts
@property
def lengths(self) -> torch.Tensor:
if self.bins is not None:
# equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
# pyre-ignore
return torch.lerp(self.bins[..., :-1], self.bins[..., 1:], 0.5)
return self._lengths
@lengths.setter
def lengths(self, value):
if self.bins is not None:
raise ValueError(
"If the bins attribute is not None you cannot set the lengths attribute."
)
else:
self._lengths = value
def is_packed(self) -> bool: def is_packed(self) -> bool:
""" """

View File

@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import dataclasses import copy
import logging import logging
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@ -102,12 +102,11 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
) )
# jitter the initial depths # jitter the initial depths
ray_bundle_t = dataclasses.replace(
ray_bundle, ray_bundle_t = copy.copy(ray_bundle)
lengths=( ray_bundle_t.lengths = (
ray_bundle.lengths ray_bundle.lengths
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std + torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
),
) )
states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None] states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]

View File

@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import copy
import torch import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
@ -106,14 +108,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
z_vals = z_samples z_vals = z_samples
# Resort by depth. # Resort by depth.
z_vals, _ = torch.sort(z_vals, dim=-1) z_vals, _ = torch.sort(z_vals, dim=-1)
ray_bundle = copy.copy(input_ray_bundle)
kwargs_ray = dict(vars(input_ray_bundle))
if input_ray_bundle.bins is None: if input_ray_bundle.bins is None:
kwargs_ray["lengths"] = z_vals ray_bundle.lengths = z_vals
return ImplicitronRayBundle(**kwargs_ray) else:
kwargs_ray["bins"] = z_vals ray_bundle.bins = z_vals
del kwargs_ray["lengths"]
return ImplicitronRayBundle.from_bins(**kwargs_ray) return ray_bundle
def apply_blurpool_on_weights(weights) -> torch.Tensor: def apply_blurpool_on_weights(weights) -> torch.Tensor:

View File

@ -236,11 +236,12 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
elif self.cast_ray_bundle_as_cone: elif self.cast_ray_bundle_as_cone:
pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width) pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw) pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
return ImplicitronRayBundle.from_bins( return ImplicitronRayBundle(
directions=ray_bundle.directions, directions=ray_bundle.directions,
origins=ray_bundle.origins, origins=ray_bundle.origins,
bins=ray_bundle.lengths, lengths=None,
xys=ray_bundle.xys, xys=ray_bundle.xys,
bins=ray_bundle.lengths,
pixel_radii_2d=pixel_radii_2d, pixel_radii_2d=pixel_radii_2d,
) )

View File

@ -25,23 +25,62 @@ from tests.common_testing import TestCaseMixin
class TestRendererBase(TestCaseMixin, unittest.TestCase): class TestRendererBase(TestCaseMixin, unittest.TestCase):
def test_implicitron_from_bins(self) -> None: def test_implicitron_from_bins(self) -> None:
bins = torch.randn(2, 3, 4, 5) bins = torch.randn(2, 3, 4, 5)
ray_bundle = ImplicitronRayBundle.from_bins( ray_bundle = ImplicitronRayBundle(
origins=None, origins=None,
directions=None, directions=None,
lengths=None,
xys=None, xys=None,
bins=bins, bins=bins,
) )
self.assertClose(ray_bundle.lengths, 0.5 * (bins[..., 1:] + bins[..., :-1])) self.assertClose(ray_bundle.lengths, 0.5 * (bins[..., 1:] + bins[..., :-1]))
self.assertClose(ray_bundle.bins, bins) self.assertClose(ray_bundle.bins, bins)
def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None: def test_implicitron_raise_value_error_bins_is_set_and_try_to_set_lengths(
with self.assertRaises(ValueError): self,
ImplicitronRayBundle.from_bins( ) -> None:
with self.assertRaises(ValueError) as context:
ray_bundle = ImplicitronRayBundle(
origins=torch.rand(2, 3, 4, 3), origins=torch.rand(2, 3, 4, 3),
directions=torch.rand(2, 3, 4, 3), directions=torch.rand(2, 3, 4, 3),
lengths=None,
xys=torch.rand(2, 3, 4, 2), xys=torch.rand(2, 3, 4, 2),
bins=torch.rand(2, 3, 4, 1), bins=torch.rand(2, 3, 4, 1),
) )
ray_bundle.lengths = torch.empty(2)
self.assertEqual(
str(context.exception),
"If the bins attribute is not None you cannot set the lengths attribute.",
)
def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None:
with self.assertRaises(ValueError) as context:
ImplicitronRayBundle(
origins=torch.rand(2, 3, 4, 3),
directions=torch.rand(2, 3, 4, 3),
lengths=None,
xys=torch.rand(2, 3, 4, 2),
bins=torch.rand(2, 3, 4, 1),
)
self.assertEqual(
str(context.exception),
"The last dim of bins must be at least superior or equal to 2.",
)
def test_implicitron_raise_value_error_if_neither_bins_or_lengths_provided(
self,
) -> None:
with self.assertRaises(ValueError) as context:
ImplicitronRayBundle(
origins=torch.rand(2, 3, 4, 3),
directions=torch.rand(2, 3, 4, 3),
lengths=None,
xys=torch.rand(2, 3, 4, 2),
bins=None,
)
self.assertEqual(
str(context.exception),
"Please set either bins or lengths to initialize an ImplicitronRayBundle.",
)
def test_conical_frustum_to_gaussian(self) -> None: def test_conical_frustum_to_gaussian(self) -> None:
origins = torch.zeros(3, 3, 3) origins = torch.zeros(3, 3, 3)