diff --git a/docs/tutorials/implicitron_volumes.ipynb b/docs/tutorials/implicitron_volumes.ipynb index 605edae6..1af8af1a 100644 --- a/docs/tutorials/implicitron_volumes.ipynb +++ b/docs/tutorials/implicitron_volumes.ipynb @@ -145,10 +145,9 @@ "from pytorch3d.implicitron.dataset.dataset_base import FrameData\n", "from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider\n", "from pytorch3d.implicitron.models.generic_model import GenericModel\n", - "from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n", + "from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase, ImplicitronRayBundle\n", "from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n", "from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components\n", - "from pytorch3d.renderer import RayBundle\n", "from pytorch3d.renderer.implicit.renderer import VolumeSampler\n", "from pytorch3d.structures import Volumes\n", "from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene" @@ -393,7 +392,7 @@ "\n", " def forward(\n", " self,\n", - " ray_bundle: RayBundle,\n", + " ray_bundle: ImplicitronRayBundle,\n", " fun_viewpool=None,\n", " global_code=None,\n", " ):\n", diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index b3dabee2..853e84ef 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -22,6 +22,7 @@ from pytorch3d.implicitron.models.metrics import ( RegularizationMetricsBase, ViewMetricsBase, ) +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools import image_utils, vis_utils from pytorch3d.implicitron.tools.config import ( expand_args_fields, @@ -30,7 +31,8 @@ from pytorch3d.implicitron.tools.config import ( ) from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples from pytorch3d.implicitron.tools.utils import cat_dataclass -from pytorch3d.renderer import RayBundle, utils as rend_utils +from pytorch3d.renderer import utils as rend_utils + from pytorch3d.renderer.cameras import CamerasBase from visdom import Visdom @@ -387,7 +389,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 ) # (1) Sample rendering rays with the ray sampler. - ray_bundle: RayBundle = self.raysampler( # pyre-fixme[29] + ray_bundle: ImplicitronRayBundle = self.raysampler( # pyre-fixme[29] target_cameras, evaluation_mode, mask=mask_crop[:n_targets] @@ -568,14 +570,14 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 def _render( self, *, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, chunked_inputs: Dict[str, torch.Tensor], sampling_mode: RenderSamplingMode, **kwargs, ) -> RendererOutput: """ Args: - ray_bundle: A `RayBundle` object containing the parametrizations of the + ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the sampled rendering rays. chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g. SignedDistanceFunctionRenderer requires "object_mask", shape @@ -899,7 +901,7 @@ def _tensor_collator(batch, new_dims) -> torch.Tensor: def _chunk_generator( chunk_size: int, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, chunked_inputs: Dict[str, torch.Tensor], tqdm_trigger_threshold: int, *args, @@ -932,7 +934,7 @@ def _chunk_generator( for start_idx in iter: end_idx = min(start_idx + chunk_size_in_rays, n_rays) - ray_bundle_chunk = RayBundle( + ray_bundle_chunk = ImplicitronRayBundle( origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx], directions=ray_bundle.directions.reshape(batch_size, -1, 3)[ :, start_idx:end_idx diff --git a/pytorch3d/implicitron/models/implicit_function/base.py b/pytorch3d/implicitron/models/implicit_function/base.py index 2e0c7798..75bd3653 100644 --- a/pytorch3d/implicitron/models/implicit_function/base.py +++ b/pytorch3d/implicitron/models/implicit_function/base.py @@ -7,9 +7,10 @@ from abc import ABC, abstractmethod from typing import Optional +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle + from pytorch3d.implicitron.tools.config import ReplaceableBase from pytorch3d.renderer.cameras import CamerasBase -from pytorch3d.renderer.implicit import RayBundle class ImplicitFunctionBase(ABC, ReplaceableBase): @@ -20,7 +21,7 @@ class ImplicitFunctionBase(ABC, ReplaceableBase): def forward( self, *, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, global_code=None, diff --git a/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py b/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py index 557ba138..f43a2932 100644 --- a/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py +++ b/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py @@ -6,8 +6,10 @@ import math from typing import Optional, Tuple import torch +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import registry -from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle +from pytorch3d.renderer.implicit import HarmonicEmbedding + from torch import nn from .base import ImplicitFunctionBase @@ -127,7 +129,7 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module): def forward( self, *, - ray_bundle: Optional[RayBundle] = None, + ray_bundle: Optional[ImplicitronRayBundle] = None, rays_points_world: Optional[torch.Tensor] = None, fun_viewpool=None, global_code=None, diff --git a/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py b/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py index d325c798..aecd9105 100644 --- a/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py +++ b/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py @@ -9,8 +9,9 @@ from typing import Optional, Tuple import torch from pytorch3d.common.linear_with_repeat import LinearWithRepeat +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import expand_args_fields, registry -from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle +from pytorch3d.renderer import ray_bundle_to_ray_points from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.implicit import HarmonicEmbedding @@ -130,7 +131,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module): def forward( self, *, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, global_code=None, @@ -144,7 +145,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module): RGB color and opacity respectively. Args: - ray_bundle: A RayBundle object containing the following variables: + ray_bundle: An ImplicitronRayBundle object containing the following variables: origins: A tensor of shape `(minibatch, ..., 3)` denoting the origins of the sampling rays in world coords. directions: A tensor of shape `(minibatch, ..., 3)` @@ -165,11 +166,12 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module): """ # We first convert the ray parametrizations to world # coordinates with `ray_bundle_to_ray_points`. + # pyre-ignore[6] rays_points_world = ray_bundle_to_ray_points(ray_bundle) # rays_points_world.shape = [minibatch x ... x pts_per_ray x 3] embeds = create_embeddings_for_implicit_function( - xyz_world=ray_bundle_to_ray_points(ray_bundle), + xyz_world=rays_points_world, # pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]` # for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`. xyz_embedding_function=self.harmonic_embedding_xyz diff --git a/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py b/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py index c701c54c..b9e3cc1e 100644 --- a/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py +++ b/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py @@ -6,9 +6,10 @@ from typing import Any, cast, Optional, Tuple import torch from omegaconf import DictConfig from pytorch3d.common.linear_with_repeat import LinearWithRepeat +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation -from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle +from pytorch3d.renderer import ray_bundle_to_ray_points from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.implicit import HarmonicEmbedding @@ -68,7 +69,7 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module): def forward( self, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, global_code=None, @@ -76,7 +77,7 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module): ): """ Args: - ray_bundle: A RayBundle object containing the following variables: + ray_bundle: An ImplicitronRayBundle object containing the following variables: origins: A tensor of shape `(minibatch, ..., 3)` denoting the origins of the sampling rays in world coords. directions: A tensor of shape `(minibatch, ..., 3)` @@ -96,10 +97,11 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module): """ # We first convert the ray parametrizations to world # coordinates with `ray_bundle_to_ray_points`. + # pyre-ignore[6] rays_points_world = ray_bundle_to_ray_points(ray_bundle) embeds = create_embeddings_for_implicit_function( - xyz_world=ray_bundle_to_ray_points(ray_bundle), + xyz_world=rays_points_world, # pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]` # for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`. xyz_embedding_function=self._harmonic_embedding, @@ -175,7 +177,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module): def forward( self, raymarch_features: torch.Tensor, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, camera: Optional[CamerasBase] = None, **kwargs, ): @@ -183,7 +185,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module): Args: raymarch_features: Features from the raymarching network of shape `(minibatch, ..., self.in_features)` - ray_bundle: A RayBundle object containing the following variables: + ray_bundle: An ImplicitronRayBundle object containing the following variables: origins: A tensor of shape `(minibatch, ..., 3)` denoting the origins of the sampling rays in world coords. directions: A tensor of shape `(minibatch, ..., 3)` @@ -297,7 +299,7 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module): def forward( self, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, global_code=None, @@ -350,7 +352,7 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module): def forward( self, *, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, global_code=None, @@ -410,7 +412,7 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module): def forward( self, *, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, fun_viewpool=None, camera: Optional[CamerasBase] = None, global_code=None, diff --git a/pytorch3d/implicitron/models/implicit_function/utils.py b/pytorch3d/implicitron/models/implicit_function/utils.py index 9b401c48..e9b688ef 100644 --- a/pytorch3d/implicitron/models/implicit_function/utils.py +++ b/pytorch3d/implicitron/models/implicit_function/utils.py @@ -10,9 +10,9 @@ import torch import torch.nn.functional as F from pytorch3d.common.compat import prod +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.renderer import ray_bundle_to_ray_points from pytorch3d.renderer.cameras import CamerasBase -from pytorch3d.renderer.implicit import RayBundle def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor): @@ -190,7 +190,7 @@ def interpolate_volume( def get_rays_points_world( - ray_bundle: Optional[RayBundle] = None, + ray_bundle: Optional[ImplicitronRayBundle] = None, rays_points_world: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ @@ -198,7 +198,7 @@ def get_rays_points_world( and raises error if both are defined. Args: - ray_bundle: A RayBundle object or None + ray_bundle: An ImplicitronRayBundle object or None rays_points_world: A torch.Tensor representing ray points converted to world coordinates Returns: @@ -213,5 +213,6 @@ def get_rays_points_world( if rays_points_world is not None: return rays_points_world if ray_bundle is not None: + # pyre-ignore[6] return ray_bundle_to_ray_points(ray_bundle) raise ValueError("ray_bundle and rays_points_world cannot both be None") diff --git a/pytorch3d/implicitron/models/renderer/base.py b/pytorch3d/implicitron/models/renderer/base.py index f55059aa..27ee1787 100644 --- a/pytorch3d/implicitron/models/renderer/base.py +++ b/pytorch3d/implicitron/models/renderer/base.py @@ -6,6 +6,8 @@ from __future__ import annotations +import dataclasses + from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum @@ -25,6 +27,38 @@ class RenderSamplingMode(Enum): FULL_GRID = "full_grid" +@dataclasses.dataclass +class ImplicitronRayBundle: + """ + Parametrizes points along projection rays by storing ray `origins`, + `directions` vectors and `lengths` at which the ray-points are sampled. + Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well. + Note that `directions` don't have to be normalized; they define unit vectors + in the respective 1D coordinate systems; see documentation for + :func:`ray_bundle_to_ray_points` for the conversion formula. + + camera_ids: A tensor of shape (N, ) which indicates which camera + was used to sample the rays. `N` is the number of different + sampled cameras. + camera_counts: A tensor of shape (N, ) which how many times the + coresponding camera in `camera_ids` was sampled. + `sum(camera_counts)==minibatch` + """ + + origins: torch.Tensor + directions: torch.Tensor + lengths: torch.Tensor + xys: torch.Tensor + camera_ids: Optional[torch.Tensor] = None + camera_counts: Optional[torch.Tensor] = None + + def is_packed(self) -> bool: + """ + Returns whether the ImplicitronRayBundle carries data in packed state + """ + return self.camera_ids is not None and self.camera_counts is not None + + @dataclass class RendererOutput: """ @@ -85,7 +119,7 @@ class BaseRenderer(ABC, ReplaceableBase): @abstractmethod def forward( self, - ray_bundle, + ray_bundle: ImplicitronRayBundle, implicit_functions: List[ImplicitFunctionWrapper], evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, **kwargs, @@ -95,7 +129,7 @@ class BaseRenderer(ABC, ReplaceableBase): that returns an instance of RendererOutput. Args: - ray_bundle: A RayBundle object containing the following variables: + ray_bundle: An ImplicitronRayBundle object containing the following variables: origins: A tensor of shape (minibatch, ..., 3) denoting the origins of the rendering rays. directions: A tensor of shape (minibatch, ..., 3) @@ -108,6 +142,12 @@ class BaseRenderer(ABC, ReplaceableBase): xys: A tensor of shape (minibatch, ..., 2) containing the xy locations of each ray's pixel in the NDC screen space. + camera_ids: A tensor of shape (N, ) which indicates which camera + was used to sample the rays. `N` is the number of different + sampled cameras. + camera_counts: A tensor of shape (N, ) which how many times the + coresponding camera in `camera_ids` was sampled. + `sum(camera_counts)==minibatch` implicit_functions: List of ImplicitFunctionWrappers which define the implicit function methods to be used. Most Renderers only allow a single implicit function. Currently, only the diff --git a/pytorch3d/implicitron/models/renderer/lstm_renderer.py b/pytorch3d/implicitron/models/renderer/lstm_renderer.py index c5ce094f..b24c253f 100644 --- a/pytorch3d/implicitron/models/renderer/lstm_renderer.py +++ b/pytorch3d/implicitron/models/renderer/lstm_renderer.py @@ -4,12 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import dataclasses import logging from typing import List, Optional, Tuple import torch +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import registry -from pytorch3d.renderer import RayBundle from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput @@ -71,7 +72,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module): def forward( self, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, implicit_functions: List[ImplicitFunctionWrapper], evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, **kwargs, @@ -79,7 +80,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module): """ Args: - ray_bundle: A `RayBundle` object containing the parametrizations of the + ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the sampled rendering rays. implicit_functions: A single-element list of ImplicitFunctionWrappers which defines the implicit function to be used. @@ -102,9 +103,12 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module): ) # jitter the initial depths - ray_bundle_t = ray_bundle._replace( - lengths=ray_bundle.lengths - + torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std + ray_bundle_t = dataclasses.replace( + ray_bundle, + lengths=( + ray_bundle.lengths + + torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std + ), ) states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None] @@ -112,9 +116,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module): raymarch_features = None for t in range(self.num_raymarch_steps + 1): # move signed_distance along each ray - ray_bundle_t = ray_bundle_t._replace( - lengths=ray_bundle_t.lengths + signed_distance - ) + ray_bundle_t.lengths += signed_distance # eval the raymarching function raymarch_features, _ = implicit_function( diff --git a/pytorch3d/implicitron/models/renderer/multipass_ea.py b/pytorch3d/implicitron/models/renderer/multipass_ea.py index 61cf0d4c..648e7f37 100644 --- a/pytorch3d/implicitron/models/renderer/multipass_ea.py +++ b/pytorch3d/implicitron/models/renderer/multipass_ea.py @@ -7,8 +7,8 @@ from typing import List import torch +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import registry, run_auto_creation -from pytorch3d.renderer import RayBundle from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput from .ray_point_refiner import RayPointRefiner @@ -107,14 +107,14 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13 def forward( self, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, implicit_functions: List[ImplicitFunctionWrapper], evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, **kwargs, ) -> RendererOutput: """ Args: - ray_bundle: A `RayBundle` object containing the parametrizations of the + ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the sampled rendering rays. implicit_functions: List of ImplicitFunctionWrappers which define the implicit functions to be used sequentially in diff --git a/pytorch3d/implicitron/models/renderer/ray_sampler.py b/pytorch3d/implicitron/models/renderer/ray_sampler.py index 76f9f5bc..6d3723ad 100644 --- a/pytorch3d/implicitron/models/renderer/ray_sampler.py +++ b/pytorch3d/implicitron/models/renderer/ray_sampler.py @@ -9,10 +9,10 @@ from typing import Optional, Tuple import torch from pytorch3d.implicitron.tools import camera_utils from pytorch3d.implicitron.tools.config import registry, ReplaceableBase -from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle +from pytorch3d.renderer import NDCMultinomialRaysampler from pytorch3d.renderer.cameras import CamerasBase -from .base import EvaluationMode, RenderSamplingMode +from .base import EvaluationMode, ImplicitronRayBundle, RenderSamplingMode class RaySamplerBase(ReplaceableBase): @@ -28,7 +28,7 @@ class RaySamplerBase(ReplaceableBase): cameras: CamerasBase, evaluation_mode: EvaluationMode, mask: Optional[torch.Tensor] = None, - ) -> RayBundle: + ) -> ImplicitronRayBundle: """ Args: cameras: A batch of `batch_size` cameras from which the rays are emitted. @@ -42,7 +42,7 @@ class RaySamplerBase(ReplaceableBase): corresponding pixel's ray. Returns: - ray_bundle: A `RayBundle` object containing the parametrizations of the + ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the sampled rendering rays. """ raise NotImplementedError() @@ -135,7 +135,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module): cameras: CamerasBase, evaluation_mode: EvaluationMode, mask: Optional[torch.Tensor] = None, - ) -> RayBundle: + ) -> ImplicitronRayBundle: """ Args: @@ -150,7 +150,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module): corresponding pixel's ray. Returns: - ray_bundle: A `RayBundle` object containing the parametrizations of the + ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the sampled rendering rays. """ sample_mask = None @@ -180,7 +180,19 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module): max_depth=max_depth, ) - return ray_bundle + if isinstance(ray_bundle, tuple): + return ImplicitronRayBundle( + # pyre-ignore[16] + **{k: v for k, v in ray_bundle._asdict().items()} + ) + return ImplicitronRayBundle( + directions=ray_bundle.directions, + origins=ray_bundle.origins, + lengths=ray_bundle.lengths, + xys=ray_bundle.xys, + camera_ids=ray_bundle.camera_ids, + camera_counts=ray_bundle.camera_counts, + ) @registry.register diff --git a/pytorch3d/implicitron/models/renderer/rgb_net.py b/pytorch3d/implicitron/models/renderer/rgb_net.py index 47609e83..6d41d216 100644 --- a/pytorch3d/implicitron/models/renderer/rgb_net.py +++ b/pytorch3d/implicitron/models/renderer/rgb_net.py @@ -7,8 +7,10 @@ import logging from typing import List, Tuple import torch +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import enable_get_default_args -from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle +from pytorch3d.renderer.implicit import HarmonicEmbedding + from torch import nn @@ -89,7 +91,7 @@ class RayNormalColoringNetwork(torch.nn.Module): feature_vectors: torch.Tensor, points, normals, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, masks=None, pooling_fn=None, ): diff --git a/pytorch3d/implicitron/models/renderer/sdf_renderer.py b/pytorch3d/implicitron/models/renderer/sdf_renderer.py index 2f0e626c..d8782911 100644 --- a/pytorch3d/implicitron/models/renderer/sdf_renderer.py +++ b/pytorch3d/implicitron/models/renderer/sdf_renderer.py @@ -8,13 +8,13 @@ from typing import List, Optional, Tuple import torch from omegaconf import DictConfig from pytorch3d.common.compat import prod +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import ( get_default_args_field, registry, run_auto_creation, ) from pytorch3d.implicitron.tools.utils import evaluating -from pytorch3d.renderer import RayBundle from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput from .ray_tracing import RayTracing @@ -69,7 +69,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign def forward( self, - ray_bundle: RayBundle, + ray_bundle: ImplicitronRayBundle, implicit_functions: List[ImplicitFunctionWrapper], evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, object_mask: Optional[torch.Tensor] = None, @@ -77,7 +77,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign ) -> RendererOutput: """ Args: - ray_bundle: A `RayBundle` object containing the parametrizations of the + ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the sampled rendering rays. implicit_functions: single element list of ImplicitFunctionWrappers which defines the implicit function to be used. diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py index c53754e8..033f783a 100644 --- a/pytorch3d/renderer/implicit/raysampling.py +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -149,9 +149,8 @@ class MultinomialRaysampler(torch.nn.Module): in __init__. n_rays_total: How many rays in total to sample from the cameras provided. The result is as if `n_rays_total_training` cameras were sampled with replacement from the - cameras provided and for every camera one ray was sampled. If set, this disables - `n_rays_per_image` and returns the HeterogeneousRayBundle with - batch_size=n_rays_total. + cameras provided and for every camera one ray was sampled. If set, returns the + HeterogeneousRayBundle with batch_size=n_rays_total. Returns: A named tuple RayBundle or dataclass HeterogeneousRayBundle with the following fields: @@ -188,9 +187,10 @@ class MultinomialRaysampler(torch.nn.Module): """ n_rays_total = n_rays_total or self._n_rays_total n_rays_per_image = n_rays_per_image or self._n_rays_per_image - assert (n_rays_total is None) or ( - n_rays_per_image is None - ), "`n_rays_total` and `n_rays_per_image` cannot both be defined." + if (n_rays_total is not None) and (n_rays_per_image is not None): + raise ValueError( + "`n_rays_total` and `n_rays_per_image` cannot both be defined." + ) if n_rays_total: ( cameras, @@ -357,9 +357,8 @@ class MonteCarloRaysampler(torch.nn.Module): max_depth: The maximum depth of each ray-point. n_rays_total: How many rays in total to sample from the cameras provided. The result is as if `n_rays_total_training` cameras were sampled with replacement from the - cameras provided and for every camera one ray was sampled. If set, this disables - `n_rays_per_image` and returns the HeterogeneousRayBundleyBundle with - batch_size=n_rays_total. + cameras provided and for every camera one ray was sampled. If set, this returns + the HeterogeneousRayBundleyBundle with batch_size=n_rays_total. unit_directions: whether to normalize direction vectors in ray bundle. stratified_sampling: if True, performs stratified sampling in n_pts_per_ray bins for each ray; otherwise takes n_pts_per_ray deterministic points @@ -416,9 +415,14 @@ class MonteCarloRaysampler(torch.nn.Module): - camera_counts: tensor of shape (M,), where `M` is the number of unique sampled cameras. Represents how many times each camera from `camera_ids` was sampled """ - assert (self._n_rays_total is None) or ( - self._n_rays_per_image is None - ), "`self.n_rays_total` and `self.n_rays_per_image` cannot both be defined." + if ( + sum(x is not None for x in [self._n_rays_total, self._n_rays_per_image]) + != 1 + ): + raise ValueError( + "Exactly one of `self.n_rays_total` and `self.n_rays_per_image` " + "must be given." + ) if self._n_rays_total: ( diff --git a/pytorch3d/renderer/implicit/renderer.py b/pytorch3d/renderer/implicit/renderer.py index c2be5adc..56583cdb 100644 --- a/pytorch3d/renderer/implicit/renderer.py +++ b/pytorch3d/renderer/implicit/renderer.py @@ -297,6 +297,7 @@ class VolumeSampler(torch.nn.Module): """ Given an input ray parametrization, the forward function samples `self._volumes` at the respective 3D ray-points. + Can also accept ImplicitronRayBundle as argument for ray_bundle. Args: ray_bundle: A RayBundle or HeterogeneousRayBundle object with the following fields: diff --git a/pytorch3d/vis/plotly_vis.py b/pytorch3d/vis/plotly_vis.py index 776f4768..1cb4985d 100644 --- a/pytorch3d/vis/plotly_vis.py +++ b/pytorch3d/vis/plotly_vis.py @@ -11,6 +11,7 @@ import plotly.graph_objects as go import torch from plotly.subplots import make_subplots from pytorch3d.renderer import ( + HeterogeneousRayBundle, ray_bundle_to_ray_points, RayBundle, TexturesAtlas, @@ -21,14 +22,45 @@ from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.structures import join_meshes_as_scene, Meshes, Pointclouds -Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle] +Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle, HeterogeneousRayBundle] -def _get_struct_len(struct: Struct) -> int: # pragma: no cover +def _get_len(struct: Union[Struct, List[Struct]]) -> int: # pragma: no cover """ Returns the length (usually corresponds to the batch size) of the input structure. """ - return len(struct.directions) if isinstance(struct, RayBundle) else len(struct) + # pyre-ignore[6] + if not _is_ray_bundle(struct): + # pyre-ignore[6] + return len(struct) + if _is_heterogeneous_ray_bundle(struct): + # pyre-ignore[16] + return len(struct.camera_counts) + # pyre-ignore[16] + return len(struct.directions) + + +def _is_ray_bundle(struct: Struct) -> bool: + """ + Args: + struct: Struct object to test + Returns: + True if something is a RayBundle, HeterogeneousRayBundle or + ImplicitronRayBundle, else False + """ + return hasattr(struct, "directions") + + +def _is_heterogeneous_ray_bundle(struct: Union[List[Struct], Struct]) -> bool: + """ + Args: + struct :object to test + Returns: + True if something is a HeterogeneousRayBundle or ImplicitronRayBundle + and cant be reduced to RayBundle else False + """ + # pyre-ignore[16] + return hasattr(struct, "camera_counts") and struct.camera_counts is not None def get_camera_wireframe(scale: float = 0.3): # pragma: no cover @@ -301,7 +333,7 @@ def plot_scene( _add_camera_trace( fig, struct, trace_name, subplot_idx, ncols, camera_scale ) - elif isinstance(struct, RayBundle): + elif _is_ray_bundle(struct): _add_ray_bundle_trace( fig, struct, @@ -316,7 +348,7 @@ def plot_scene( else: raise ValueError( "struct {} is not a Cameras, Meshes, Pointclouds,".format(struct) - + " or RayBundle object." + + " , RayBundle or HeterogeneousRayBundle object." ) # Ensure update for every subplot. @@ -401,15 +433,16 @@ def plot_batch_individually( In addition, you can include Cameras, Meshes, Pointclouds, or RayBundle of size 1 in the input. These will either be rendered in the first subplot (if extend_struct is False), or in every subplot. + RayBundle includes ImplicitronRayBundle and HeterogeneousRaybundle. Args: - batched_structs: a list of Cameras, Meshes, Pointclouds, and RayBundle - to be rendered. Each structure's corresponding batch element will be - plotted in a single subplot, resulting in n subplots for a batch of size n. - Every struct should either have the same batch size or be of batch size 1. - See extend_struct and the description above for how batch size 1 structs - are handled. Also accepts a single Cameras, Meshes, Pointclouds, and RayBundle - object, which will have each individual element plotted in its own subplot. + batched_structs: a list of Cameras, Meshes, Pointclouds and RayBundle to be + rendered. Each structure's corresponding batch element will be plotted in a + single subplot, resulting in n subplots for a batch of size n. Every struct + should either have the same batch size or be of batch size 1. See extend_struct + and the description above for how batch size 1 structs are handled. Also accepts + a single Cameras, Meshes, Pointclouds, and RayBundle object, which will have + each individual element plotted in its own subplot. viewpoint_cameras: an instance of a Cameras object providing a location to view the plotly plot from. If the batch size is equal to the number of subplots, it is a one to one mapping. @@ -451,20 +484,20 @@ def plot_batch_individually( """ # check that every batch is the same size or is size 1 - if len(batched_structs) == 0: + if _get_len(batched_structs) == 0: msg = "No structs to plot" warnings.warn(msg) return max_size = 0 if isinstance(batched_structs, list): - max_size = max(_get_struct_len(s) for s in batched_structs) + max_size = max(_get_len(s) for s in batched_structs) for struct in batched_structs: - struct_len = _get_struct_len(struct) + struct_len = _get_len(struct) if struct_len not in (1, max_size): msg = "invalid batch size {} provided: {}".format(struct_len, struct) raise ValueError(msg) else: - max_size = _get_struct_len(batched_structs) + max_size = _get_len(batched_structs) if max_size == 0: msg = "No data is provided with at least one element" @@ -475,6 +508,14 @@ def plot_batch_individually( msg = "invalid number of subplot titles" raise ValueError(msg) + # if we are dealing with HeterogeneousRayBundle of ImplicitronRayBundle create + # first indexes for faster + first_idxs = None + if _is_heterogeneous_ray_bundle(batched_structs): + # pyre-ignore[16] + cumsum = batched_structs.camera_counts.cumsum(dim=0) + first_idxs = torch.cat((cumsum.new_zeros((1,)), cumsum)) + scene_dictionary = {} # construct the scene dictionary for scene_num in range(max_size): @@ -487,16 +528,30 @@ def plot_batch_individually( if isinstance(batched_structs, list): for i, batched_struct in enumerate(batched_structs): + first_idxs = None + if _is_heterogeneous_ray_bundle(batched_structs[i]): + # pyre-ignore[16] + cumsum = batched_struct.camera_counts.cumsum(dim=0) + first_idxs = torch.cat((cumsum.new_zeros((1,)), cumsum)) # check for whether this struct needs to be extended - batched_struct_len = _get_struct_len(batched_struct) + batched_struct_len = _get_len(batched_struct) if i >= batched_struct_len and not extend_struct: continue _add_struct_from_batch( - batched_struct, scene_num, subplot_title, scene_dictionary, i + 1 + batched_struct, + scene_num, + subplot_title, + scene_dictionary, + i + 1, + first_idxs=first_idxs, ) else: # batched_structs is a single struct _add_struct_from_batch( - batched_structs, scene_num, subplot_title, scene_dictionary + batched_structs, + scene_num, + subplot_title, + scene_dictionary, + first_idxs=first_idxs, ) return plot_scene( @@ -510,6 +565,7 @@ def _add_struct_from_batch( subplot_title: str, scene_dictionary: Dict[str, Dict[str, Struct]], trace_idx: int = 1, + first_idxs: Optional[torch.Tensor] = None, ) -> None: # pragma: no cover """ Adds the struct corresponding to the given scene_num index to @@ -544,17 +600,35 @@ def _add_struct_from_batch( # torch.Tensor, torch.nn.Module]` is not a function. T = T[t_idx].unsqueeze(0) struct = CamerasBase(device=batched_struct.device, R=R, T=T) - elif isinstance(batched_struct, RayBundle): - # for RayBundle we treat the 1st dim as the batch index - struct_idx = min(scene_num, len(batched_struct.lengths) - 1) + elif _is_ray_bundle(batched_struct) and not _is_heterogeneous_ray_bundle( + batched_struct + ): + # for RayBundle we treat the camera count as the batch index + struct_idx = min(scene_num, _get_len(batched_struct) - 1) + struct = RayBundle( **{ attr: getattr(batched_struct, attr)[struct_idx] for attr in ["origins", "directions", "lengths", "xys"] } ) + elif _is_heterogeneous_ray_bundle(batched_struct): + # for RayBundle we treat the camera count as the batch index + struct_idx = min(scene_num, _get_len(batched_struct) - 1) + + struct = RayBundle( + **{ + attr: getattr(batched_struct, attr)[ + # pyre-ignore[16] + first_idxs[struct_idx] : first_idxs[struct_idx + 1] + ] + for attr in ["origins", "directions", "lengths", "xys"] + } + ) + else: # batched meshes and pointclouds are indexable - struct_idx = min(scene_num, len(batched_struct) - 1) + struct_idx = min(scene_num, _get_len(batched_struct) - 1) + # pyre-ignore[16] struct = batched_struct[struct_idx] trace_name = "trace{}-{}".format(scene_num + 1, trace_idx) scene_dictionary[subplot_title][trace_name] = struct @@ -753,7 +827,7 @@ def _add_camera_trace( def _add_ray_bundle_trace( fig: go.Figure, - ray_bundle: RayBundle, + ray_bundle: Union[RayBundle, HeterogeneousRayBundle], trace_name: str, subplot_idx: int, ncols: int, @@ -763,12 +837,13 @@ def _add_ray_bundle_trace( line_width: int, ) -> None: # pragma: no cover """ - Adds a trace rendering a RayBundle object to the passed in figure, with - a given name and in a specific subplot. + Adds a trace rendering a ray bundle object + to the passed in figure, with a given name and in a specific subplot. Args: fig: plotly figure to add the trace within. - cameras: the Cameras object to render. It can be batched. + ray_bundle: the RayBundle, ImplicitronRayBundle or HeterogeneousRaybundle to render. + It can be batched. trace_name: name to label the trace with. subplot_idx: identifies the subplot, with 0 being the top left. ncols: the number of subplots per row. diff --git a/tests/implicitron/test_ray_point_refiner.py b/tests/implicitron/test_ray_point_refiner.py index fb512c24..9373edc2 100644 --- a/tests/implicitron/test_ray_point_refiner.py +++ b/tests/implicitron/test_ray_point_refiner.py @@ -8,7 +8,7 @@ import unittest import torch from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner -from pytorch3d.renderer import RayBundle +from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle from tests.common_testing import TestCaseMixin @@ -24,7 +24,14 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase): add_input_samples=add_input_samples, ) lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length) - bundle = RayBundle(lengths=lengths, origins=None, directions=None, xys=None) + bundle = ImplicitronRayBundle( + lengths=lengths, + origins=None, + directions=None, + xys=None, + camera_ids=None, + camera_counts=None, + ) weights = torch.ones(3, 25, length) refined = ray_point_refiner(bundle, weights) diff --git a/tests/implicitron/test_srn.py b/tests/implicitron/test_srn.py index f6905ef4..311bbaa6 100644 --- a/tests/implicitron/test_srn.py +++ b/tests/implicitron/test_srn.py @@ -13,8 +13,10 @@ from pytorch3d.implicitron.models.implicit_function.scene_representation_network SRNImplicitFunction, SRNPixelGenerator, ) +from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle from pytorch3d.implicitron.tools.config import get_default_args -from pytorch3d.renderer import PerspectiveCameras, RayBundle +from pytorch3d.renderer import PerspectiveCameras + from tests.common_testing import TestCaseMixin _BATCH_SIZE: int = 3 @@ -31,12 +33,17 @@ class TestSRN(TestCaseMixin, unittest.TestCase): def test_pixel_generator(self): SRNPixelGenerator() - def _get_bundle(self, *, device) -> RayBundle: + def _get_bundle(self, *, device) -> ImplicitronRayBundle: origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device) directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device) lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device) - bundle = RayBundle( - lengths=lengths, origins=origins, directions=directions, xys=None + bundle = ImplicitronRayBundle( + lengths=lengths, + origins=origins, + directions=directions, + xys=None, + camera_ids=None, + camera_counts=None, ) return bundle