mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Avoid to keep in memory lengths and bins for ImplicitronRayBundle
Summary: Convert ImplicitronRayBundle to a "classic" class instead of a dataclass. This change is introduced as a way to preserve the ImplicitronRayBundle interface while allowing two outcomes: - init lengths arguments is now a Optional[torch.Tensor] instead of torch.Tensor - lengths is now a property which returns a `torch.Tensor`. The lengths property will either recompute lengths from bins or return the stored _lengths. `_lenghts` is None if bins is set. It saves us a bit of memory. Reviewed By: shapovalov Differential Revision: D46686094 fbshipit-source-id: 3c75c0947216476ebff542b6f552d311024a679b
This commit is contained in:
parent
3d011a9198
commit
9446d91fae
@ -6,8 +6,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -29,7 +27,6 @@ class RenderSamplingMode(Enum):
|
|||||||
FULL_GRID = "full_grid"
|
FULL_GRID = "full_grid"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class ImplicitronRayBundle:
|
class ImplicitronRayBundle:
|
||||||
"""
|
"""
|
||||||
Parametrizes points along projection rays by storing ray `origins`,
|
Parametrizes points along projection rays by storing ray `origins`,
|
||||||
@ -69,53 +66,58 @@ class ImplicitronRayBundle:
|
|||||||
lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`.
|
lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`.
|
||||||
pixel_radii_2d: An optional tensor of shape `(..., 1)`
|
pixel_radii_2d: An optional tensor of shape `(..., 1)`
|
||||||
base radii of the conical frustums.
|
base radii of the conical frustums.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If either bins or lengths are not provided.
|
||||||
|
ValueError: If bins is provided and the last dim is inferior or equal to 1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
origins: torch.Tensor
|
def __init__(
|
||||||
directions: torch.Tensor
|
self,
|
||||||
lengths: torch.Tensor
|
|
||||||
xys: torch.Tensor
|
|
||||||
camera_ids: Optional[torch.LongTensor] = None
|
|
||||||
camera_counts: Optional[torch.LongTensor] = None
|
|
||||||
bins: Optional[torch.Tensor] = None
|
|
||||||
pixel_radii_2d: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_bins(
|
|
||||||
cls,
|
|
||||||
origins: torch.Tensor,
|
origins: torch.Tensor,
|
||||||
directions: torch.Tensor,
|
directions: torch.Tensor,
|
||||||
bins: torch.Tensor,
|
lengths: Optional[torch.Tensor],
|
||||||
xys: torch.Tensor,
|
xys: torch.Tensor,
|
||||||
**kwargs,
|
camera_ids: Optional[torch.LongTensor] = None,
|
||||||
) -> "ImplicitronRayBundle":
|
camera_counts: Optional[torch.LongTensor] = None,
|
||||||
"""
|
bins: Optional[torch.Tensor] = None,
|
||||||
Creates a new instance from bins instead of lengths.
|
pixel_radii_2d: Optional[torch.Tensor] = None,
|
||||||
|
):
|
||||||
Attributes:
|
if bins is not None and bins.shape[-1] <= 1:
|
||||||
origins: A tensor of shape `(..., 3)` denoting the
|
|
||||||
origins of the sampling rays in world coords.
|
|
||||||
directions: A tensor of shape `(..., 3)` containing the direction
|
|
||||||
vectors of sampling rays in world coords. They don't have to be normalized;
|
|
||||||
they define unit vectors in the respective 1D coordinate systems; see
|
|
||||||
documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
|
|
||||||
bins: A tensor of shape `(..., num_points_per_ray + 1)`
|
|
||||||
containing the bins at which the rays are sampled. In this case
|
|
||||||
lengths is equal to the midpoints of bins `(..., num_points_per_ray)`.
|
|
||||||
xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
|
|
||||||
kwargs: Additional arguments passed to the constructor of ImplicitronRayBundle
|
|
||||||
Returns:
|
|
||||||
An instance of ImplicitronRayBundle.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if bins.shape[-1] <= 1:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The last dim of bins must be at least superior or equal to 2."
|
"The last dim of bins must be at least superior or equal to 2."
|
||||||
)
|
)
|
||||||
# equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
|
|
||||||
lengths = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5)
|
|
||||||
|
|
||||||
return cls(origins, directions, lengths, xys, bins=bins, **kwargs)
|
if bins is None and lengths is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Please set either bins or lengths to initialize an ImplicitronRayBundle."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.origins = origins
|
||||||
|
self.directions = directions
|
||||||
|
self._lengths = lengths if bins is None else None
|
||||||
|
self.xys = xys
|
||||||
|
self.bins = bins
|
||||||
|
self.pixel_radii_2d = pixel_radii_2d
|
||||||
|
self.camera_ids = camera_ids
|
||||||
|
self.camera_counts = camera_counts
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lengths(self) -> torch.Tensor:
|
||||||
|
if self.bins is not None:
|
||||||
|
# equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
|
||||||
|
# pyre-ignore
|
||||||
|
return torch.lerp(self.bins[..., :-1], self.bins[..., 1:], 0.5)
|
||||||
|
return self._lengths
|
||||||
|
|
||||||
|
@lengths.setter
|
||||||
|
def lengths(self, value):
|
||||||
|
if self.bins is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"If the bins attribute is not None you cannot set the lengths attribute."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._lengths = value
|
||||||
|
|
||||||
def is_packed(self) -> bool:
|
def is_packed(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import dataclasses
|
import copy
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
@ -102,12 +102,11 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# jitter the initial depths
|
# jitter the initial depths
|
||||||
ray_bundle_t = dataclasses.replace(
|
|
||||||
ray_bundle,
|
ray_bundle_t = copy.copy(ray_bundle)
|
||||||
lengths=(
|
ray_bundle_t.lengths = (
|
||||||
ray_bundle.lengths
|
ray_bundle.lengths
|
||||||
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
|
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]
|
states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]
|
||||||
|
@ -4,6 +4,8 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import copy
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
|
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
|
||||||
@ -106,14 +108,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
|||||||
z_vals = z_samples
|
z_vals = z_samples
|
||||||
# Resort by depth.
|
# Resort by depth.
|
||||||
z_vals, _ = torch.sort(z_vals, dim=-1)
|
z_vals, _ = torch.sort(z_vals, dim=-1)
|
||||||
|
ray_bundle = copy.copy(input_ray_bundle)
|
||||||
kwargs_ray = dict(vars(input_ray_bundle))
|
|
||||||
if input_ray_bundle.bins is None:
|
if input_ray_bundle.bins is None:
|
||||||
kwargs_ray["lengths"] = z_vals
|
ray_bundle.lengths = z_vals
|
||||||
return ImplicitronRayBundle(**kwargs_ray)
|
else:
|
||||||
kwargs_ray["bins"] = z_vals
|
ray_bundle.bins = z_vals
|
||||||
del kwargs_ray["lengths"]
|
|
||||||
return ImplicitronRayBundle.from_bins(**kwargs_ray)
|
return ray_bundle
|
||||||
|
|
||||||
|
|
||||||
def apply_blurpool_on_weights(weights) -> torch.Tensor:
|
def apply_blurpool_on_weights(weights) -> torch.Tensor:
|
||||||
|
@ -236,11 +236,12 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
|||||||
elif self.cast_ray_bundle_as_cone:
|
elif self.cast_ray_bundle_as_cone:
|
||||||
pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
|
pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
|
||||||
pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
|
pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
|
||||||
return ImplicitronRayBundle.from_bins(
|
return ImplicitronRayBundle(
|
||||||
directions=ray_bundle.directions,
|
directions=ray_bundle.directions,
|
||||||
origins=ray_bundle.origins,
|
origins=ray_bundle.origins,
|
||||||
bins=ray_bundle.lengths,
|
lengths=None,
|
||||||
xys=ray_bundle.xys,
|
xys=ray_bundle.xys,
|
||||||
|
bins=ray_bundle.lengths,
|
||||||
pixel_radii_2d=pixel_radii_2d,
|
pixel_radii_2d=pixel_radii_2d,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -25,23 +25,62 @@ from tests.common_testing import TestCaseMixin
|
|||||||
class TestRendererBase(TestCaseMixin, unittest.TestCase):
|
class TestRendererBase(TestCaseMixin, unittest.TestCase):
|
||||||
def test_implicitron_from_bins(self) -> None:
|
def test_implicitron_from_bins(self) -> None:
|
||||||
bins = torch.randn(2, 3, 4, 5)
|
bins = torch.randn(2, 3, 4, 5)
|
||||||
ray_bundle = ImplicitronRayBundle.from_bins(
|
ray_bundle = ImplicitronRayBundle(
|
||||||
origins=None,
|
origins=None,
|
||||||
directions=None,
|
directions=None,
|
||||||
|
lengths=None,
|
||||||
xys=None,
|
xys=None,
|
||||||
bins=bins,
|
bins=bins,
|
||||||
)
|
)
|
||||||
self.assertClose(ray_bundle.lengths, 0.5 * (bins[..., 1:] + bins[..., :-1]))
|
self.assertClose(ray_bundle.lengths, 0.5 * (bins[..., 1:] + bins[..., :-1]))
|
||||||
self.assertClose(ray_bundle.bins, bins)
|
self.assertClose(ray_bundle.bins, bins)
|
||||||
|
|
||||||
def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None:
|
def test_implicitron_raise_value_error_bins_is_set_and_try_to_set_lengths(
|
||||||
with self.assertRaises(ValueError):
|
self,
|
||||||
ImplicitronRayBundle.from_bins(
|
) -> None:
|
||||||
|
with self.assertRaises(ValueError) as context:
|
||||||
|
ray_bundle = ImplicitronRayBundle(
|
||||||
origins=torch.rand(2, 3, 4, 3),
|
origins=torch.rand(2, 3, 4, 3),
|
||||||
directions=torch.rand(2, 3, 4, 3),
|
directions=torch.rand(2, 3, 4, 3),
|
||||||
|
lengths=None,
|
||||||
xys=torch.rand(2, 3, 4, 2),
|
xys=torch.rand(2, 3, 4, 2),
|
||||||
bins=torch.rand(2, 3, 4, 1),
|
bins=torch.rand(2, 3, 4, 1),
|
||||||
)
|
)
|
||||||
|
ray_bundle.lengths = torch.empty(2)
|
||||||
|
self.assertEqual(
|
||||||
|
str(context.exception),
|
||||||
|
"If the bins attribute is not None you cannot set the lengths attribute.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None:
|
||||||
|
with self.assertRaises(ValueError) as context:
|
||||||
|
ImplicitronRayBundle(
|
||||||
|
origins=torch.rand(2, 3, 4, 3),
|
||||||
|
directions=torch.rand(2, 3, 4, 3),
|
||||||
|
lengths=None,
|
||||||
|
xys=torch.rand(2, 3, 4, 2),
|
||||||
|
bins=torch.rand(2, 3, 4, 1),
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
str(context.exception),
|
||||||
|
"The last dim of bins must be at least superior or equal to 2.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_implicitron_raise_value_error_if_neither_bins_or_lengths_provided(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
|
with self.assertRaises(ValueError) as context:
|
||||||
|
ImplicitronRayBundle(
|
||||||
|
origins=torch.rand(2, 3, 4, 3),
|
||||||
|
directions=torch.rand(2, 3, 4, 3),
|
||||||
|
lengths=None,
|
||||||
|
xys=torch.rand(2, 3, 4, 2),
|
||||||
|
bins=None,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
str(context.exception),
|
||||||
|
"Please set either bins or lengths to initialize an ImplicitronRayBundle.",
|
||||||
|
)
|
||||||
|
|
||||||
def test_conical_frustum_to_gaussian(self) -> None:
|
def test_conical_frustum_to_gaussian(self) -> None:
|
||||||
origins = torch.zeros(3, 3, 3)
|
origins = torch.zeros(3, 3, 3)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user