mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Avoid to keep in memory lengths and bins for ImplicitronRayBundle
Summary: Convert ImplicitronRayBundle to a "classic" class instead of a dataclass. This change is introduced as a way to preserve the ImplicitronRayBundle interface while allowing two outcomes: - init lengths arguments is now a Optional[torch.Tensor] instead of torch.Tensor - lengths is now a property which returns a `torch.Tensor`. The lengths property will either recompute lengths from bins or return the stored _lengths. `_lenghts` is None if bins is set. It saves us a bit of memory. Reviewed By: shapovalov Differential Revision: D46686094 fbshipit-source-id: 3c75c0947216476ebff542b6f552d311024a679b
This commit is contained in:
		
							parent
							
								
									3d011a9198
								
							
						
					
					
						commit
						9446d91fae
					
				@ -6,8 +6,6 @@
 | 
			
		||||
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import dataclasses
 | 
			
		||||
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
from enum import Enum
 | 
			
		||||
@ -29,7 +27,6 @@ class RenderSamplingMode(Enum):
 | 
			
		||||
    FULL_GRID = "full_grid"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
class ImplicitronRayBundle:
 | 
			
		||||
    """
 | 
			
		||||
    Parametrizes points along projection rays by storing ray `origins`,
 | 
			
		||||
@ -69,53 +66,58 @@ class ImplicitronRayBundle:
 | 
			
		||||
            lengths should be equal to the midpoints of bins `(..., num_points_per_ray)`.
 | 
			
		||||
        pixel_radii_2d: An optional tensor of shape `(..., 1)`
 | 
			
		||||
            base radii of the conical frustums.
 | 
			
		||||
 | 
			
		||||
    Raises:
 | 
			
		||||
        ValueError: If either bins or lengths are not provided.
 | 
			
		||||
        ValueError: If bins is provided and the last dim is inferior or equal to 1.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    origins: torch.Tensor
 | 
			
		||||
    directions: torch.Tensor
 | 
			
		||||
    lengths: torch.Tensor
 | 
			
		||||
    xys: torch.Tensor
 | 
			
		||||
    camera_ids: Optional[torch.LongTensor] = None
 | 
			
		||||
    camera_counts: Optional[torch.LongTensor] = None
 | 
			
		||||
    bins: Optional[torch.Tensor] = None
 | 
			
		||||
    pixel_radii_2d: Optional[torch.Tensor] = None
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_bins(
 | 
			
		||||
        cls,
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        origins: torch.Tensor,
 | 
			
		||||
        directions: torch.Tensor,
 | 
			
		||||
        bins: torch.Tensor,
 | 
			
		||||
        lengths: Optional[torch.Tensor],
 | 
			
		||||
        xys: torch.Tensor,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> "ImplicitronRayBundle":
 | 
			
		||||
        """
 | 
			
		||||
        Creates a new instance from bins instead of lengths.
 | 
			
		||||
 | 
			
		||||
        Attributes:
 | 
			
		||||
            origins: A tensor of shape `(..., 3)` denoting the
 | 
			
		||||
                origins of the sampling rays in world coords.
 | 
			
		||||
            directions: A tensor of shape `(..., 3)` containing the direction
 | 
			
		||||
                vectors of sampling rays in world coords. They don't have to be normalized;
 | 
			
		||||
                they define unit vectors in the respective 1D coordinate systems; see
 | 
			
		||||
                documentation for :func:`ray_bundle_to_ray_points` for the conversion formula.
 | 
			
		||||
            bins: A tensor of shape `(..., num_points_per_ray + 1)`
 | 
			
		||||
                containing the bins at which the rays are sampled. In this case
 | 
			
		||||
                lengths is equal to the midpoints of bins `(..., num_points_per_ray)`.
 | 
			
		||||
            xys: A tensor of shape `(..., 2)`, the xy-locations (`xys`) of the ray pixels
 | 
			
		||||
            kwargs: Additional arguments passed to the constructor of ImplicitronRayBundle
 | 
			
		||||
        Returns:
 | 
			
		||||
            An instance of ImplicitronRayBundle.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        if bins.shape[-1] <= 1:
 | 
			
		||||
        camera_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
        camera_counts: Optional[torch.LongTensor] = None,
 | 
			
		||||
        bins: Optional[torch.Tensor] = None,
 | 
			
		||||
        pixel_radii_2d: Optional[torch.Tensor] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        if bins is not None and bins.shape[-1] <= 1:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "The last dim of bins must be at least superior or equal to 2."
 | 
			
		||||
            )
 | 
			
		||||
        # equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
 | 
			
		||||
        lengths = torch.lerp(bins[..., 1:], bins[..., :-1], 0.5)
 | 
			
		||||
 | 
			
		||||
        return cls(origins, directions, lengths, xys, bins=bins, **kwargs)
 | 
			
		||||
        if bins is None and lengths is None:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "Please set either bins or lengths to initialize an ImplicitronRayBundle."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self.origins = origins
 | 
			
		||||
        self.directions = directions
 | 
			
		||||
        self._lengths = lengths if bins is None else None
 | 
			
		||||
        self.xys = xys
 | 
			
		||||
        self.bins = bins
 | 
			
		||||
        self.pixel_radii_2d = pixel_radii_2d
 | 
			
		||||
        self.camera_ids = camera_ids
 | 
			
		||||
        self.camera_counts = camera_counts
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def lengths(self) -> torch.Tensor:
 | 
			
		||||
        if self.bins is not None:
 | 
			
		||||
            # equivalent to: 0.5 * (bins[..., 1:] + bins[..., :-1]) but more efficient
 | 
			
		||||
            # pyre-ignore
 | 
			
		||||
            return torch.lerp(self.bins[..., :-1], self.bins[..., 1:], 0.5)
 | 
			
		||||
        return self._lengths
 | 
			
		||||
 | 
			
		||||
    @lengths.setter
 | 
			
		||||
    def lengths(self, value):
 | 
			
		||||
        if self.bins is not None:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "If the bins attribute is not None you cannot set the lengths attribute."
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            self._lengths = value
 | 
			
		||||
 | 
			
		||||
    def is_packed(self) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,7 @@
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import dataclasses
 | 
			
		||||
import copy
 | 
			
		||||
import logging
 | 
			
		||||
from typing import List, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
@ -102,12 +102,11 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # jitter the initial depths
 | 
			
		||||
        ray_bundle_t = dataclasses.replace(
 | 
			
		||||
            ray_bundle,
 | 
			
		||||
            lengths=(
 | 
			
		||||
                ray_bundle.lengths
 | 
			
		||||
                + torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
 | 
			
		||||
            ),
 | 
			
		||||
 | 
			
		||||
        ray_bundle_t = copy.copy(ray_bundle)
 | 
			
		||||
        ray_bundle_t.lengths = (
 | 
			
		||||
            ray_bundle.lengths
 | 
			
		||||
            + torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]
 | 
			
		||||
 | 
			
		||||
@ -4,6 +4,8 @@
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import copy
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
 | 
			
		||||
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
 | 
			
		||||
@ -106,14 +108,13 @@ class RayPointRefiner(Configurable, torch.nn.Module):
 | 
			
		||||
            z_vals = z_samples
 | 
			
		||||
        # Resort by depth.
 | 
			
		||||
        z_vals, _ = torch.sort(z_vals, dim=-1)
 | 
			
		||||
 | 
			
		||||
        kwargs_ray = dict(vars(input_ray_bundle))
 | 
			
		||||
        ray_bundle = copy.copy(input_ray_bundle)
 | 
			
		||||
        if input_ray_bundle.bins is None:
 | 
			
		||||
            kwargs_ray["lengths"] = z_vals
 | 
			
		||||
            return ImplicitronRayBundle(**kwargs_ray)
 | 
			
		||||
        kwargs_ray["bins"] = z_vals
 | 
			
		||||
        del kwargs_ray["lengths"]
 | 
			
		||||
        return ImplicitronRayBundle.from_bins(**kwargs_ray)
 | 
			
		||||
            ray_bundle.lengths = z_vals
 | 
			
		||||
        else:
 | 
			
		||||
            ray_bundle.bins = z_vals
 | 
			
		||||
 | 
			
		||||
        return ray_bundle
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_blurpool_on_weights(weights) -> torch.Tensor:
 | 
			
		||||
 | 
			
		||||
@ -236,11 +236,12 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
 | 
			
		||||
        elif self.cast_ray_bundle_as_cone:
 | 
			
		||||
            pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
 | 
			
		||||
            pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
 | 
			
		||||
            return ImplicitronRayBundle.from_bins(
 | 
			
		||||
            return ImplicitronRayBundle(
 | 
			
		||||
                directions=ray_bundle.directions,
 | 
			
		||||
                origins=ray_bundle.origins,
 | 
			
		||||
                bins=ray_bundle.lengths,
 | 
			
		||||
                lengths=None,
 | 
			
		||||
                xys=ray_bundle.xys,
 | 
			
		||||
                bins=ray_bundle.lengths,
 | 
			
		||||
                pixel_radii_2d=pixel_radii_2d,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -25,23 +25,62 @@ from tests.common_testing import TestCaseMixin
 | 
			
		||||
class TestRendererBase(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def test_implicitron_from_bins(self) -> None:
 | 
			
		||||
        bins = torch.randn(2, 3, 4, 5)
 | 
			
		||||
        ray_bundle = ImplicitronRayBundle.from_bins(
 | 
			
		||||
        ray_bundle = ImplicitronRayBundle(
 | 
			
		||||
            origins=None,
 | 
			
		||||
            directions=None,
 | 
			
		||||
            lengths=None,
 | 
			
		||||
            xys=None,
 | 
			
		||||
            bins=bins,
 | 
			
		||||
        )
 | 
			
		||||
        self.assertClose(ray_bundle.lengths, 0.5 * (bins[..., 1:] + bins[..., :-1]))
 | 
			
		||||
        self.assertClose(ray_bundle.bins, bins)
 | 
			
		||||
 | 
			
		||||
    def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None:
 | 
			
		||||
        with self.assertRaises(ValueError):
 | 
			
		||||
            ImplicitronRayBundle.from_bins(
 | 
			
		||||
    def test_implicitron_raise_value_error_bins_is_set_and_try_to_set_lengths(
 | 
			
		||||
        self,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        with self.assertRaises(ValueError) as context:
 | 
			
		||||
            ray_bundle = ImplicitronRayBundle(
 | 
			
		||||
                origins=torch.rand(2, 3, 4, 3),
 | 
			
		||||
                directions=torch.rand(2, 3, 4, 3),
 | 
			
		||||
                lengths=None,
 | 
			
		||||
                xys=torch.rand(2, 3, 4, 2),
 | 
			
		||||
                bins=torch.rand(2, 3, 4, 1),
 | 
			
		||||
            )
 | 
			
		||||
            ray_bundle.lengths = torch.empty(2)
 | 
			
		||||
            self.assertEqual(
 | 
			
		||||
                str(context.exception),
 | 
			
		||||
                "If the bins attribute is not None you cannot set the lengths attribute.",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_implicitron_raise_value_error_if_bins_dim_equal_1(self) -> None:
 | 
			
		||||
        with self.assertRaises(ValueError) as context:
 | 
			
		||||
            ImplicitronRayBundle(
 | 
			
		||||
                origins=torch.rand(2, 3, 4, 3),
 | 
			
		||||
                directions=torch.rand(2, 3, 4, 3),
 | 
			
		||||
                lengths=None,
 | 
			
		||||
                xys=torch.rand(2, 3, 4, 2),
 | 
			
		||||
                bins=torch.rand(2, 3, 4, 1),
 | 
			
		||||
            )
 | 
			
		||||
            self.assertEqual(
 | 
			
		||||
                str(context.exception),
 | 
			
		||||
                "The last dim of bins must be at least superior or equal to 2.",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_implicitron_raise_value_error_if_neither_bins_or_lengths_provided(
 | 
			
		||||
        self,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        with self.assertRaises(ValueError) as context:
 | 
			
		||||
            ImplicitronRayBundle(
 | 
			
		||||
                origins=torch.rand(2, 3, 4, 3),
 | 
			
		||||
                directions=torch.rand(2, 3, 4, 3),
 | 
			
		||||
                lengths=None,
 | 
			
		||||
                xys=torch.rand(2, 3, 4, 2),
 | 
			
		||||
                bins=None,
 | 
			
		||||
            )
 | 
			
		||||
            self.assertEqual(
 | 
			
		||||
                str(context.exception),
 | 
			
		||||
                "Please set either bins or lengths to initialize an ImplicitronRayBundle.",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def test_conical_frustum_to_gaussian(self) -> None:
 | 
			
		||||
        origins = torch.zeros(3, 3, 3)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user