mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
Avoid to keep in memory lengths and bins for ImplicitronRayBundle
Summary: Convert ImplicitronRayBundle to a "classic" class instead of a dataclass. This change is introduced as a way to preserve the ImplicitronRayBundle interface while allowing two outcomes: - init lengths arguments is now a Optional[torch.Tensor] instead of torch.Tensor - lengths is now a property which returns a `torch.Tensor`. The lengths property will either recompute lengths from bins or return the stored _lengths. `_lenghts` is None if bins is set. It saves us a bit of memory. Reviewed By: shapovalov Differential Revision: D46686094 fbshipit-source-id: 3c75c0947216476ebff542b6f552d311024a679b
This commit is contained in:
parent
3d011a9198
commit
9446d91fae
@ -6,8 +6,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
@ -29,7 +27,6 @@ class RenderSamplingMode(Enum):
|
||||
FULL_GRID = "full_grid"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ImplicitronRayBundle:
|
||||
"""
|
||||
Parametrizes points along projection rays by storing ray `origins`,
|
||||
@ -69,53 +66,58 @@ class ImplicitronRayBundle:
|
||||
lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`.
|
||||
pixel_radii_2d: An optional tensor of shape `(..., 1)`
|
||||
base radii of the conical frustums.
|
||||
|
||||
Raises:
|
||||
ValueError: If either bins or lengths are not provided.
|
||||
ValueError: If bins is provided and the last dim is inferior or equal to 1.
|
||||
"""
|
||||
|
||||
origins: torch.Tensor
|
||||
directions: torch.Tensor
|
||||
lengths: torch.Tensor
|
||||
xys: torch.Tensor
|
||||
camera_ids: Optional[torch.LongTensor] = None
|
||||
camera_counts: Optional[torch.LongTensor] = None
|
||||
bins: Optional[torch.Tensor] = None
|
||||
pixel_radii_2d: Optional[torch.Tensor] = None
|
||||
|
||||
@classmethod
|
||||
def from_bins(
|
||||
cls,
|
||||
def __init__(
|
||||
self,
|
||||
origins: torch.Tensor,
|
||||
directions: torch.Tensor,
|
||||
bins: torch.Tensor,
|
||||
lengths: Optional[torch.Tensor],
|
||||
xys: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> "ImplicitronRayBundle":
|
||||
"""
|
||||
Creates a new instance from bins instead of lengths.
|
||||
|
||||
Attributes:
|
||||
origins: A tensor of shape `(..., 3)` denoting the
|
||||
origins of the sampling rays in world coords.
|
||||
directions: A tensor of shape `(..., 3)` containing the direction
|
||||
vectors of sampling rays in world coords. They don't have to be normalized;
|
||||
they define unit vectors in the respective 1D coordinate systems; see
|
||||
documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
|
||||
bins: A tensor of shape `(..., num_points_per_ray + 1)`
|
||||
containing the bins at which the rays are sampled. In this case
|
||||
lengths is equal to the midpoints of bins `(..., num_points_per_ray)`.
|
||||
xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
|
||||
kwargs: Additional arguments passed to the constructor of ImplicitronRayBundle
|
||||
Returns:
|
||||
An instance of ImplicitronRayBundle.
|
||||
"""
|
||||
|
||||
if bins.shape[-1] <= 1:
|
||||
camera_ids: Optional[torch.LongTensor] = None,
|
||||
camera_counts: Optional[torch.LongTensor] = None,
|
||||
bins: Optional[torch.Tensor] = None,
|
||||
pixel_radii_2d: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if bins is not None and bins.shape[-1] <= 1:
|
||||
raise ValueError(
|
||||
"The last dim of bins must be at least superior or equal to 2."
|
||||
)
|
||||
# equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
|
||||
lengths = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5)
|
||||
|
||||
return cls(origins, directions, lengths, xys, bins=bins, **kwargs)
|
||||
if bins is None and lengths is None:
|
||||
raise ValueError(
|
||||
"Please set either bins or lengths to initialize an ImplicitronRayBundle."
|
||||
)
|
||||
|
||||
self.origins = origins
|
||||
self.directions = directions
|
||||
self._lengths = lengths if bins is None else None
|
||||
self.xys = xys
|
||||
self.bins = bins
|
||||
self.pixel_radii_2d = pixel_radii_2d
|
||||
self.camera_ids = camera_ids
|
||||
self.camera_counts = camera_counts
|
||||
|
||||
@property
|
||||
def lengths(self) -> torch.Tensor:
|
||||
if self.bins is not None:
|
||||
# equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
|
||||
# pyre-ignore
|
||||
return torch.lerp(self.bins[..., :-1], self.bins[..., 1:], 0.5)
|
||||
return self._lengths
|
||||
|
||||
@lengths.setter
|
||||
def lengths(self, value):
|
||||
if self.bins is not None:
|
||||
raise ValueError(
|
||||
"If the bins attribute is not None you cannot set the lengths attribute."
|
||||
)
|
||||
else:
|
||||
self._lengths = value
|
||||
|
||||
def is_packed(self) -> bool:
|
||||
"""
|
||||
|
@ -4,7 +4,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import dataclasses
|
||||
import copy
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@ -102,12 +102,11 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
)
|
||||
|
||||
# jitter the initial depths
|
||||
ray_bundle_t = dataclasses.replace(
|
||||
ray_bundle,
|
||||
lengths=(
|
||||
ray_bundle.lengths
|
||||
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
|
||||
),
|
||||
|
||||
ray_bundle_t = copy.copy(ray_bundle)
|
||||
ray_bundle_t.lengths = (
|
||||
ray_bundle.lengths
|
||||
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
|
||||
)
|
||||
|
||||
states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]
|
||||
|
@ -4,6 +4,8 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
|
||||
@ -106,14 +108,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
|
||||
z_vals = z_samples
|
||||
# Resort by depth.
|
||||
z_vals, _ = torch.sort(z_vals, dim=-1)
|
||||
|
||||
kwargs_ray = dict(vars(input_ray_bundle))
|
||||
ray_bundle = copy.copy(input_ray_bundle)
|
||||
if input_ray_bundle.bins is None:
|
||||
kwargs_ray["lengths"] = z_vals
|
||||
return ImplicitronRayBundle(**kwargs_ray)
|
||||
kwargs_ray["bins"] = z_vals
|
||||
del kwargs_ray["lengths"]
|
||||
return ImplicitronRayBundle.from_bins(**kwargs_ray)
|
||||
ray_bundle.lengths = z_vals
|
||||
else:
|
||||
ray_bundle.bins = z_vals
|
||||
|
||||
return ray_bundle
|
||||
|
||||
|
||||
def apply_blurpool_on_weights(weights) -> torch.Tensor:
|
||||
|
@ -236,11 +236,12 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
elif self.cast_ray_bundle_as_cone:
|
||||
pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
|
||||
pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
|
||||
return ImplicitronRayBundle.from_bins(
|
||||
return ImplicitronRayBundle(
|
||||
directions=ray_bundle.directions,
|
||||
origins=ray_bundle.origins,
|
||||
bins=ray_bundle.lengths,
|
||||
lengths=None,
|
||||
xys=ray_bundle.xys,
|
||||
bins=ray_bundle.lengths,
|
||||
pixel_radii_2d=pixel_radii_2d,
|
||||
)
|
||||
|
||||
|
@ -25,23 +25,62 @@ from tests.common_testing import TestCaseMixin
|
||||
class TestRendererBase(TestCaseMixin, unittest.TestCase):
|
||||
def test_implicitron_from_bins(self) -> None:
|
||||
bins = torch.randn(2, 3, 4, 5)
|
||||
ray_bundle = ImplicitronRayBundle.from_bins(
|
||||
ray_bundle = ImplicitronRayBundle(
|
||||
origins=None,
|
||||
directions=None,
|
||||
lengths=None,
|
||||
xys=None,
|
||||
bins=bins,
|
||||
)
|
||||
self.assertClose(ray_bundle.lengths, 0.5 * (bins[..., 1:] + bins[..., :-1]))
|
||||
self.assertClose(ray_bundle.bins, bins)
|
||||
|
||||
def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None:
|
||||
with self.assertRaises(ValueError):
|
||||
ImplicitronRayBundle.from_bins(
|
||||
def test_implicitron_raise_value_error_bins_is_set_and_try_to_set_lengths(
|
||||
self,
|
||||
) -> None:
|
||||
with self.assertRaises(ValueError) as context:
|
||||
ray_bundle = ImplicitronRayBundle(
|
||||
origins=torch.rand(2, 3, 4, 3),
|
||||
directions=torch.rand(2, 3, 4, 3),
|
||||
lengths=None,
|
||||
xys=torch.rand(2, 3, 4, 2),
|
||||
bins=torch.rand(2, 3, 4, 1),
|
||||
)
|
||||
ray_bundle.lengths = torch.empty(2)
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"If the bins attribute is not None you cannot set the lengths attribute.",
|
||||
)
|
||||
|
||||
def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None:
|
||||
with self.assertRaises(ValueError) as context:
|
||||
ImplicitronRayBundle(
|
||||
origins=torch.rand(2, 3, 4, 3),
|
||||
directions=torch.rand(2, 3, 4, 3),
|
||||
lengths=None,
|
||||
xys=torch.rand(2, 3, 4, 2),
|
||||
bins=torch.rand(2, 3, 4, 1),
|
||||
)
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"The last dim of bins must be at least superior or equal to 2.",
|
||||
)
|
||||
|
||||
def test_implicitron_raise_value_error_if_neither_bins_or_lengths_provided(
|
||||
self,
|
||||
) -> None:
|
||||
with self.assertRaises(ValueError) as context:
|
||||
ImplicitronRayBundle(
|
||||
origins=torch.rand(2, 3, 4, 3),
|
||||
directions=torch.rand(2, 3, 4, 3),
|
||||
lengths=None,
|
||||
xys=torch.rand(2, 3, 4, 2),
|
||||
bins=None,
|
||||
)
|
||||
self.assertEqual(
|
||||
str(context.exception),
|
||||
"Please set either bins or lengths to initialize an ImplicitronRayBundle.",
|
||||
)
|
||||
|
||||
def test_conical_frustum_to_gaussian(self) -> None:
|
||||
origins = torch.zeros(3, 3, 3)
|
||||
|
Loading…
x
Reference in New Issue
Block a user