ImplicitronRayBundle

Summary: New ImplicitronRayBundle with added camera IDs and camera counts. It was added to give Implicitron a single ray bundle type and to make future extension easier. RayBundle is a named tuple and HeterogeneousRayBundle is a dataclass, so HeterogeneousRayBundle cannot inherit from RayBundle. Without ImplicitronRayBundle, every function that currently uses RayBundle would have to accept Union[RayBundle, HeterogeneousRayBundle], which is confusing and unnecessarily complicated.

Reviewed By: bottler, kjchalup

Differential Revision: D39262999

fbshipit-source-id: ece160e32f6c88c3977e408e966789bf8307af59
Darijan Gudelj 2022-10-03 08:36:47 -07:00 committed by Facebook GitHub Bot
parent 6ae863f301
commit ad8907d373
18 changed files with 259 additions and 100 deletions
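
For context, a rough sketch of the typing problem the summary describes (field lists simplified, not the full library definitions): RayBundle is a NamedTuple while HeterogeneousRayBundle is a dataclass, so one cannot inherit from the other, and every downstream annotation would otherwise need a Union.

from dataclasses import dataclass
from typing import NamedTuple, Optional, Union

import torch

class RayBundle(NamedTuple):
    # immutable named tuple: origins/directions/lengths/xys only
    origins: torch.Tensor
    directions: torch.Tensor
    lengths: torch.Tensor
    xys: torch.Tensor

@dataclass
class HeterogeneousRayBundle:
    # dataclass: cannot subclass the NamedTuple above, adds camera bookkeeping
    origins: torch.Tensor
    directions: torch.Tensor
    lengths: torch.Tensor
    xys: torch.Tensor
    camera_ids: Optional[torch.Tensor] = None
    camera_counts: Optional[torch.Tensor] = None

def render(ray_bundle: Union[RayBundle, HeterogeneousRayBundle]) -> None:
    # every Implicitron entry point would need this Union without ImplicitronRayBundle
    ...

ImplicitronRayBundle, defined in pytorch3d/implicitron/models/renderer/base.py below, carries the superset of fields, so Implicitron code annotates and constructs a single type instead.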

View File

@ -145,10 +145,9 @@
"from pytorch3d.implicitron.dataset.dataset_base import FrameData\n", "from pytorch3d.implicitron.dataset.dataset_base import FrameData\n",
"from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider\n", "from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider\n",
"from pytorch3d.implicitron.models.generic_model import GenericModel\n", "from pytorch3d.implicitron.models.generic_model import GenericModel\n",
"from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n", "from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase, ImplicitronRayBundle\n",
"from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n", "from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n",
"from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components\n", "from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components\n",
"from pytorch3d.renderer import RayBundle\n",
"from pytorch3d.renderer.implicit.renderer import VolumeSampler\n", "from pytorch3d.renderer.implicit.renderer import VolumeSampler\n",
"from pytorch3d.structures import Volumes\n", "from pytorch3d.structures import Volumes\n",
"from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene" "from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene"
@ -393,7 +392,7 @@
"\n", "\n",
" def forward(\n", " def forward(\n",
" self,\n", " self,\n",
" ray_bundle: RayBundle,\n", " ray_bundle: ImplicitronRayBundle,\n",
" fun_viewpool=None,\n", " fun_viewpool=None,\n",
" global_code=None,\n", " global_code=None,\n",
" ):\n", " ):\n",

View File

@ -22,6 +22,7 @@ from pytorch3d.implicitron.models.metrics import (
RegularizationMetricsBase, RegularizationMetricsBase,
ViewMetricsBase, ViewMetricsBase,
) )
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools import image_utils, vis_utils from pytorch3d.implicitron.tools import image_utils, vis_utils
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
expand_args_fields, expand_args_fields,
@ -30,7 +31,8 @@ from pytorch3d.implicitron.tools.config import (
) )
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
from pytorch3d.implicitron.tools.utils import cat_dataclass from pytorch3d.implicitron.tools.utils import cat_dataclass
from pytorch3d.renderer import RayBundle, utils as rend_utils from pytorch3d.renderer import utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from visdom import Visdom from visdom import Visdom
@ -387,7 +389,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
) )
# (1) Sample rendering rays with the ray sampler. # (1) Sample rendering rays with the ray sampler.
ray_bundle: RayBundle = self.raysampler( # pyre-fixme[29] ray_bundle: ImplicitronRayBundle = self.raysampler( # pyre-fixme[29]
target_cameras, target_cameras,
evaluation_mode, evaluation_mode,
mask=mask_crop[:n_targets] mask=mask_crop[:n_targets]
@ -568,14 +570,14 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
def _render( def _render(
self, self,
*, *,
ray_bundle: RayBundle, ray_bundle: ImplicitronRayBundle,
chunked_inputs: Dict[str, torch.Tensor], chunked_inputs: Dict[str, torch.Tensor],
sampling_mode: RenderSamplingMode, sampling_mode: RenderSamplingMode,
**kwargs, **kwargs,
) -> RendererOutput: ) -> RendererOutput:
""" """
Args: Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the ray_bundle: An `ImplicitronRayBundle` object containing the parametrizations of the
sampled rendering rays. sampled rendering rays.
chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g. chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g.
SignedDistanceFunctionRenderer requires "object_mask", shape SignedDistanceFunctionRenderer requires "object_mask", shape
@ -899,7 +901,7 @@ def _tensor_collator(batch, new_dims) -> torch.Tensor:
def _chunk_generator( def _chunk_generator(
chunk_size: int, chunk_size: int,
ray_bundle: RayBundle, ray_bundle: ImplicitronRayBundle,
chunked_inputs: Dict[str, torch.Tensor], chunked_inputs: Dict[str, torch.Tensor],
tqdm_trigger_threshold: int, tqdm_trigger_threshold: int,
*args, *args,
@ -932,7 +934,7 @@ def _chunk_generator(
for start_idx in iter: for start_idx in iter:
end_idx = min(start_idx + chunk_size_in_rays, n_rays) end_idx = min(start_idx + chunk_size_in_rays, n_rays)
ray_bundle_chunk = RayBundle( ray_bundle_chunk = ImplicitronRayBundle(
origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx], origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[ directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
:, start_idx:end_idx :, start_idx:end_idx
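
The `_chunk_generator` change above builds each chunk as an ImplicitronRayBundle by reshaping the ray tensors to (batch, n_rays, ...) and slicing a window of rays. A minimal standalone sketch of that slicing step (hypothetical helper name; the real generator also chunks the auxiliary inputs):

import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle

def slice_ray_bundle(
    ray_bundle: ImplicitronRayBundle, start: int, end: int, batch_size: int
) -> ImplicitronRayBundle:
    # flatten all spatial ray dimensions, then keep rays [start, end) of every batch element
    n_pts = ray_bundle.lengths.shape[-1]
    return ImplicitronRayBundle(
        origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start:end],
        directions=ray_bundle.directions.reshape(batch_size, -1, 3)[:, start:end],
        lengths=ray_bundle.lengths.reshape(batch_size, -1, n_pts)[:, start:end],
        xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start:end],
    )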

View File

@ -7,9 +7,10 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import ReplaceableBase from pytorch3d.implicitron.tools.config import ReplaceableBase
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import RayBundle
class ImplicitFunctionBase(ABC, ReplaceableBase): class ImplicitFunctionBase(ABC, ReplaceableBase):
@ -20,7 +21,7 @@ class ImplicitFunctionBase(ABC, ReplaceableBase):
def forward( def forward(
self, self,
*, *,
ray_bundle: RayBundle, ray_bundle: ImplicitronRayBundle,
fun_viewpool=None, fun_viewpool=None,
camera: Optional[CamerasBase] = None, camera: Optional[CamerasBase] = None,
global_code=None, global_code=None,
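
With the new annotation, a user-defined implicit function takes an ImplicitronRayBundle keyword argument. A minimal sketch of a subclass with a NeRF-style constant output (class name made up; registration and config plumbing as in the tutorial notebook are omitted):

import torch
from pytorch3d.implicitron.models.implicit_function.base import (
    ImplicitFunctionBase,
    ImplicitronRayBundle,
)
from pytorch3d.renderer import ray_bundle_to_ray_points

class ConstantImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
    def forward(
        self,
        *,
        ray_bundle: ImplicitronRayBundle,
        fun_viewpool=None,
        camera=None,
        global_code=None,
        **kwargs,
    ):
        # (minibatch, ..., pts_per_ray, 3) world-space points along the rays;
        # ray_bundle_to_ray_points accepts the dataclass by duck typing (pyre needs an ignore)
        points = ray_bundle_to_ray_points(ray_bundle)  # pyre-ignore[6]
        densities = points.new_ones(*points.shape[:-1], 1)
        colors = points.new_full((*points.shape[:-1], 3), 0.5)
        return densities, colors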

View File

@ -6,8 +6,10 @@ import math
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import registry from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle from pytorch3d.renderer.implicit import HarmonicEmbedding
from torch import nn from torch import nn
from .base import ImplicitFunctionBase from .base import ImplicitFunctionBase
@ -127,7 +129,7 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
def forward( def forward(
self, self,
*, *,
ray_bundle: Optional[RayBundle] = None, ray_bundle: Optional[ImplicitronRayBundle] = None,
rays_points_world: Optional[torch.Tensor] = None, rays_points_world: Optional[torch.Tensor] = None,
fun_viewpool=None, fun_viewpool=None,
global_code=None, global_code=None,

View File

@ -9,8 +9,9 @@ from typing import Optional, Tuple
import torch import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import expand_args_fields, registry from pytorch3d.implicitron.tools.config import expand_args_fields, registry
from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding from pytorch3d.renderer.implicit import HarmonicEmbedding
@ -130,7 +131,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
def forward( def forward(
self, self,
*, *,
ray_bundle: RayBundle, ray_bundle: ImplicitronRayBundle,
fun_viewpool=None, fun_viewpool=None,
camera: Optional[CamerasBase] = None, camera: Optional[CamerasBase] = None,
global_code=None, global_code=None,
@ -144,7 +145,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
RGB color and opacity respectively. RGB color and opacity respectively.
Args: Args:
ray_bundle: A RayBundle object containing the following variables: ray_bundle: An ImplicitronRayBundle object containing the following variables:
origins: A tensor of shape `(minibatch, ..., 3)` denoting the origins: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords. origins of the sampling rays in world coords.
directions: A tensor of shape `(minibatch, ..., 3)` directions: A tensor of shape `(minibatch, ..., 3)`
@ -165,11 +166,12 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
""" """
# We first convert the ray parametrizations to world # We first convert the ray parametrizations to world
# coordinates with `ray_bundle_to_ray_points`. # coordinates with `ray_bundle_to_ray_points`.
# pyre-ignore[6]
rays_points_world = ray_bundle_to_ray_points(ray_bundle) rays_points_world = ray_bundle_to_ray_points(ray_bundle)
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3] # rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
embeds = create_embeddings_for_implicit_function( embeds = create_embeddings_for_implicit_function(
xyz_world=ray_bundle_to_ray_points(ray_bundle), xyz_world=rays_points_world,
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]` # pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`. # for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
xyz_embedding_function=self.harmonic_embedding_xyz xyz_embedding_function=self.harmonic_embedding_xyz

View File

@ -6,9 +6,10 @@ from typing import Any, cast, Optional, Tuple
import torch import torch
from omegaconf import DictConfig from omegaconf import DictConfig
from pytorch3d.common.linear_with_repeat import LinearWithRepeat from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding from pytorch3d.renderer.implicit import HarmonicEmbedding
@ -68,7 +69,7 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
def forward( def forward(
self, self,
ray_bundle: RayBundle, ray_bundle: ImplicitronRayBundle,
fun_viewpool=None, fun_viewpool=None,
camera: Optional[CamerasBase] = None, camera: Optional[CamerasBase] = None,
global_code=None, global_code=None,
@ -76,7 +77,7 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
): ):
""" """
Args: Args:
ray_bundle: A RayBundle object containing the following variables: ray_bundle: An ImplicitronRayBundle object containing the following variables:
origins: A tensor of shape `(minibatch, ..., 3)` denoting the origins: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords. origins of the sampling rays in world coords.
directions: A tensor of shape `(minibatch, ..., 3)` directions: A tensor of shape `(minibatch, ..., 3)`
@ -96,10 +97,11 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
""" """
# We first convert the ray parametrizations to world # We first convert the ray parametrizations to world
# coordinates with `ray_bundle_to_ray_points`. # coordinates with `ray_bundle_to_ray_points`.
# pyre-ignore[6]
rays_points_world = ray_bundle_to_ray_points(ray_bundle) rays_points_world = ray_bundle_to_ray_points(ray_bundle)
embeds = create_embeddings_for_implicit_function( embeds = create_embeddings_for_implicit_function(
xyz_world=ray_bundle_to_ray_points(ray_bundle), xyz_world=rays_points_world,
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]` # pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
# for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`. # for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`.
xyz_embedding_function=self._harmonic_embedding, xyz_embedding_function=self._harmonic_embedding,
@ -175,7 +177,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
def forward( def forward(
self, self,
raymarch_features: torch.Tensor, raymarch_features: torch.Tensor,
ray_bundle: RayBundle, ray_bundle: ImplicitronRayBundle,
camera: Optional[CamerasBase] = None, camera: Optional[CamerasBase] = None,
**kwargs, **kwargs,
): ):
@ -183,7 +185,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
Args: Args:
raymarch_features: Features from the raymarching network of shape raymarch_features: Features from the raymarching network of shape
`(minibatch, ..., self.in_features)` `(minibatch, ..., self.in_features)`
ray_bundle: A RayBundle object containing the following variables: ray_bundle: An ImplicitronRayBundle object containing the following variables:
origins: A tensor of shape `(minibatch, ..., 3)` denoting the origins: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords. origins of the sampling rays in world coords.
directions: A tensor of shape `(minibatch, ..., 3)` directions: A tensor of shape `(minibatch, ..., 3)`
@ -297,7 +299,7 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
def forward( def forward(
self, self,
ray_bundle: RayBundle, ray_bundle: ImplicitronRayBundle,
fun_viewpool=None, fun_viewpool=None,
camera: Optional[CamerasBase] = None, camera: Optional[CamerasBase] = None,
global_code=None, global_code=None,
@ -350,7 +352,7 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
def forward( def forward(
self, self,
*, *,
ray_bundle: RayBundle, ray_bundle: ImplicitronRayBundle,
fun_viewpool=None, fun_viewpool=None,
camera: Optional[CamerasBase] = None, camera: Optional[CamerasBase] = None,
global_code=None, global_code=None,
@ -410,7 +412,7 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
def forward( def forward(
self, self,
*, *,
ray_bundle: RayBundle, ray_bundle: ImplicitronRayBundle,
fun_viewpool=None, fun_viewpool=None,
camera: Optional[CamerasBase] = None, camera: Optional[CamerasBase] = None,
global_code=None, global_code=None,

View File

@ -10,9 +10,9 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from pytorch3d.common.compat import prod from pytorch3d.common.compat import prod
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.renderer import ray_bundle_to_ray_points from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import RayBundle
def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor): def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
@ -190,7 +190,7 @@ def interpolate_volume(
def get_rays_points_world( def get_rays_points_world(
ray_bundle: Optional[RayBundle] = None, ray_bundle: Optional[ImplicitronRayBundle] = None,
rays_points_world: Optional[torch.Tensor] = None, rays_points_world: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
@ -198,7 +198,7 @@ def get_rays_points_world(
and raises error if both are defined. and raises error if both are defined.
Args: Args:
ray_bundle: A RayBundle object or None ray_bundle: An ImplicitronRayBundle object or None
rays_points_world: A torch.Tensor representing ray points converted to rays_points_world: A torch.Tensor representing ray points converted to
world coordinates world coordinates
Returns: Returns:
@ -213,5 +213,6 @@ def get_rays_points_world(
if rays_points_world is not None: if rays_points_world is not None:
return rays_points_world return rays_points_world
if ray_bundle is not None: if ray_bundle is not None:
# pyre-ignore[6]
return ray_bundle_to_ray_points(ray_bundle) return ray_bundle_to_ray_points(ray_bundle)
raise ValueError("ray_bundle and rays_points_world cannot both be None") raise ValueError("ray_bundle and rays_points_world cannot both be None")
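
A short usage sketch of get_rays_points_world with the new bundle type (the module path pytorch3d.implicitron.models.implicit_function.utils is assumed, since the hunk does not name the file; tensor sizes are made up):

import torch
from pytorch3d.implicitron.models.implicit_function.utils import get_rays_points_world
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle

bundle = ImplicitronRayBundle(
    origins=torch.zeros(2, 8, 3),
    directions=torch.ones(2, 8, 3),
    lengths=torch.linspace(0.1, 1.0, 5).expand(2, 8, 5),
    xys=torch.zeros(2, 8, 2),
)
points = get_rays_points_world(ray_bundle=bundle)       # (2, 8, 5, 3) world points
same = get_rays_points_world(rays_points_world=points)  # passthrough when points are given
# calling it with neither argument raises ValueError, as shown above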

View File

@ -6,6 +6,8 @@
from __future__ import annotations from __future__ import annotations
import dataclasses
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
@ -25,6 +27,38 @@ class RenderSamplingMode(Enum):
FULL_GRID = "full_grid" FULL_GRID = "full_grid"
@dataclasses.dataclass
class ImplicitronRayBundle:
"""
Parametrizes points along projection rays by storing ray `origins`,
`directions` vectors and `lengths` at which the ray-points are sampled.
Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well.
Note that `directions` don't have to be normalized; they define unit vectors
in the respective 1D coordinate systems; see documentation for
:func:`ray_bundle_to_ray_points` for the conversion formula.
camera_ids: A tensor of shape (N, ) which indicates which camera
was used to sample the rays. `N` is the number of different
sampled cameras.
camera_counts: A tensor of shape (N, ) which indicates how many times the
corresponding camera in `camera_ids` was sampled.
`sum(camera_counts)==minibatch`
"""
origins: torch.Tensor
directions: torch.Tensor
lengths: torch.Tensor
xys: torch.Tensor
camera_ids: Optional[torch.Tensor] = None
camera_counts: Optional[torch.Tensor] = None
def is_packed(self) -> bool:
"""
Returns whether the ImplicitronRayBundle carries data in packed state
"""
return self.camera_ids is not None and self.camera_counts is not None
@dataclass @dataclass
class RendererOutput: class RendererOutput:
""" """
@ -85,7 +119,7 @@ class BaseRenderer(ABC, ReplaceableBase):
@abstractmethod @abstractmethod
def forward( def forward(
self, self,
ray_bundle, ray_bundle: ImplicitronRayBundle,
implicit_functions: List[ImplicitFunctionWrapper], implicit_functions: List[ImplicitFunctionWrapper],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs, **kwargs,
@ -95,7 +129,7 @@ class BaseRenderer(ABC, ReplaceableBase):
that returns an instance of RendererOutput. that returns an instance of RendererOutput.
Args: Args:
ray_bundle: A RayBundle object containing the following variables: ray_bundle: An ImplicitronRayBundle object containing the following variables:
origins: A tensor of shape (minibatch, ..., 3) denoting origins: A tensor of shape (minibatch, ..., 3) denoting
the origins of the rendering rays. the origins of the rendering rays.
directions: A tensor of shape (minibatch, ..., 3) directions: A tensor of shape (minibatch, ..., 3)
@ -108,6 +142,12 @@ class BaseRenderer(ABC, ReplaceableBase):
xys: A tensor of shape xys: A tensor of shape
(minibatch, ..., 2) containing the (minibatch, ..., 2) containing the
xy locations of each ray's pixel in the NDC screen space. xy locations of each ray's pixel in the NDC screen space.
camera_ids: A tensor of shape (N, ) which indicates which camera
was used to sample the rays. `N` is the number of different
sampled cameras.
camera_counts: A tensor of shape (N, ) which indicates how many times the
corresponding camera in `camera_ids` was sampled.
`sum(camera_counts)==minibatch`
implicit_functions: List of ImplicitFunctionWrappers which define the implicit_functions: List of ImplicitFunctionWrappers which define the
implicit function methods to be used. Most Renderers only allow implicit function methods to be used. Most Renderers only allow
a single implicit function. Currently, only the a single implicit function. Currently, only the
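
A small sketch of the packed-state convention that camera_ids/camera_counts encode (numbers are made up): two distinct cameras were sampled, producing 2 and 3 rays, so the packed batch has 5 rays in total.

import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle

n_rays_total = 5
bundle = ImplicitronRayBundle(
    origins=torch.zeros(n_rays_total, 1, 3),
    directions=torch.ones(n_rays_total, 1, 3),
    lengths=torch.linspace(0.1, 2.0, 8).expand(n_rays_total, 1, 8),
    xys=torch.zeros(n_rays_total, 1, 2),
    camera_ids=torch.tensor([0, 3]),     # which cameras were sampled
    camera_counts=torch.tensor([2, 3]),  # how many rays each of them produced
)
assert bundle.is_packed()
assert int(bundle.camera_counts.sum()) == n_rays_total  # sum(camera_counts) == minibatch

# a plain (non-packed) bundle simply leaves the two new fields at None
plain = ImplicitronRayBundle(
    origins=bundle.origins, directions=bundle.directions,
    lengths=bundle.lengths, xys=bundle.xys,
)
assert not plain.is_packed()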

View File

@ -4,12 +4,13 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import dataclasses
import logging import logging
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import registry from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer import RayBundle
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
@ -71,7 +72,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
def forward( def forward(
self, self,
ray_bundle: RayBundle, ray_bundle: ImplicitronRayBundle,
implicit_functions: List[ImplicitFunctionWrapper], implicit_functions: List[ImplicitFunctionWrapper],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs, **kwargs,
@ -79,7 +80,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
""" """
Args: Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the ray_bundle: An `ImplicitronRayBundle` object containing the parametrizations of the
sampled rendering rays. sampled rendering rays.
implicit_functions: A single-element list of ImplicitFunctionWrappers which implicit_functions: A single-element list of ImplicitFunctionWrappers which
defines the implicit function to be used. defines the implicit function to be used.
@ -102,9 +103,12 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
) )
# jitter the initial depths # jitter the initial depths
ray_bundle_t = ray_bundle._replace( ray_bundle_t = dataclasses.replace(
lengths=ray_bundle.lengths ray_bundle,
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std lengths=(
ray_bundle.lengths
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
),
) )
states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None] states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]
@ -112,9 +116,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
raymarch_features = None raymarch_features = None
for t in range(self.num_raymarch_steps + 1): for t in range(self.num_raymarch_steps + 1):
# move signed_distance along each ray # move signed_distance along each ray
ray_bundle_t = ray_bundle_t._replace( ray_bundle_t.lengths += signed_distance
lengths=ray_bundle_t.lengths + signed_distance
)
# eval the raymarching function # eval the raymarching function
raymarch_features, _ = implicit_function( raymarch_features, _ = implicit_function(
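
Because ImplicitronRayBundle is a dataclass rather than a named tuple, the LSTM renderer swaps `_replace` for `dataclasses.replace` and can also mutate `lengths` in place, as shown above. A minimal sketch of both patterns with toy tensors:

import dataclasses

import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle

bundle = ImplicitronRayBundle(
    origins=torch.zeros(2, 4, 3),
    directions=torch.ones(2, 4, 3),
    lengths=torch.zeros(2, 4, 1),
    xys=torch.zeros(2, 4, 2),
)
# copy-with-change: the dataclass counterpart of NamedTuple._replace
jittered = dataclasses.replace(
    bundle, lengths=bundle.lengths + 0.01 * torch.randn_like(bundle.lengths)
)
# unlike a NamedTuple, the dataclass is mutable, so fields can be updated in place
jittered.lengths += 0.5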

View File

@ -7,8 +7,8 @@
from typing import List from typing import List
import torch import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import registry, run_auto_creation from pytorch3d.implicitron.tools.config import registry, run_auto_creation
from pytorch3d.renderer import RayBundle
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
from .ray_point_refiner import RayPointRefiner from .ray_point_refiner import RayPointRefiner
@ -107,14 +107,14 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
def forward( def forward(
self, self,
ray_bundle: RayBundle, ray_bundle: ImplicitronRayBundle,
implicit_functions: List[ImplicitFunctionWrapper], implicit_functions: List[ImplicitFunctionWrapper],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs, **kwargs,
) -> RendererOutput: ) -> RendererOutput:
""" """
Args: Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the ray_bundle: An `ImplicitronRayBundle` object containing the parametrizations of the
sampled rendering rays. sampled rendering rays.
implicit_functions: List of ImplicitFunctionWrappers which implicit_functions: List of ImplicitFunctionWrappers which
define the implicit functions to be used sequentially in define the implicit functions to be used sequentially in

View File

@ -9,10 +9,10 @@ from typing import Optional, Tuple
import torch import torch
from pytorch3d.implicitron.tools import camera_utils from pytorch3d.implicitron.tools import camera_utils
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle from pytorch3d.renderer import NDCMultinomialRaysampler
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from .base import EvaluationMode, RenderSamplingMode from .base import EvaluationMode, ImplicitronRayBundle, RenderSamplingMode
class RaySamplerBase(ReplaceableBase): class RaySamplerBase(ReplaceableBase):
@ -28,7 +28,7 @@ class RaySamplerBase(ReplaceableBase):
cameras: CamerasBase, cameras: CamerasBase,
evaluation_mode: EvaluationMode, evaluation_mode: EvaluationMode,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
) -> RayBundle: ) -> ImplicitronRayBundle:
""" """
Args: Args:
cameras: A batch of `batch_size` cameras from which the rays are emitted. cameras: A batch of `batch_size` cameras from which the rays are emitted.
@ -42,7 +42,7 @@ class RaySamplerBase(ReplaceableBase):
corresponding pixel's ray. corresponding pixel's ray.
Returns: Returns:
ray_bundle: A `RayBundle` object containing the parametrizations of the ray_bundle: An `ImplicitronRayBundle` object containing the parametrizations of the
sampled rendering rays. sampled rendering rays.
""" """
raise NotImplementedError() raise NotImplementedError()
@ -135,7 +135,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
cameras: CamerasBase, cameras: CamerasBase,
evaluation_mode: EvaluationMode, evaluation_mode: EvaluationMode,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
) -> RayBundle: ) -> ImplicitronRayBundle:
""" """
Args: Args:
@ -150,7 +150,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
corresponding pixel's ray. corresponding pixel's ray.
Returns: Returns:
ray_bundle: A `RayBundle` object containing the parametrizations of the ray_bundle: An `ImplicitronRayBundle` object containing the parametrizations of the
sampled rendering rays. sampled rendering rays.
""" """
sample_mask = None sample_mask = None
@ -180,7 +180,19 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
max_depth=max_depth, max_depth=max_depth,
) )
return ray_bundle if isinstance(ray_bundle, tuple):
return ImplicitronRayBundle(
# pyre-ignore[16]
**{k: v for k, v in ray_bundle._asdict().items()}
)
return ImplicitronRayBundle(
directions=ray_bundle.directions,
origins=ray_bundle.origins,
lengths=ray_bundle.lengths,
xys=ray_bundle.xys,
camera_ids=ray_bundle.camera_ids,
camera_counts=ray_bundle.camera_counts,
)
@registry.register @registry.register
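
The return block above adapts whatever the underlying raysampler produced (a RayBundle named tuple when sampling per image, a HeterogeneousRayBundle when sampling across cameras) into an ImplicitronRayBundle. Read as a standalone adapter it is roughly the following (hypothetical helper name):

from typing import Union

from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.renderer import HeterogeneousRayBundle, RayBundle

def to_implicitron_ray_bundle(
    ray_bundle: Union[RayBundle, HeterogeneousRayBundle]
) -> ImplicitronRayBundle:
    if isinstance(ray_bundle, RayBundle):  # named tuple: no camera bookkeeping to copy
        return ImplicitronRayBundle(**ray_bundle._asdict())
    # HeterogeneousRayBundle: carry camera_ids / camera_counts across as well
    return ImplicitronRayBundle(
        origins=ray_bundle.origins,
        directions=ray_bundle.directions,
        lengths=ray_bundle.lengths,
        xys=ray_bundle.xys,
        camera_ids=ray_bundle.camera_ids,
        camera_counts=ray_bundle.camera_counts,
    )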

View File

@ -7,8 +7,10 @@ import logging
from typing import List, Tuple from typing import List, Tuple
import torch import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import enable_get_default_args from pytorch3d.implicitron.tools.config import enable_get_default_args
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle from pytorch3d.renderer.implicit import HarmonicEmbedding
from torch import nn from torch import nn
@ -89,7 +91,7 @@ class RayNormalColoringNetwork(torch.nn.Module):
feature_vectors: torch.Tensor, feature_vectors: torch.Tensor,
points, points,
normals, normals,
ray_bundle: RayBundle, ray_bundle: ImplicitronRayBundle,
masks=None, masks=None,
pooling_fn=None, pooling_fn=None,
): ):

View File

@ -8,13 +8,13 @@ from typing import List, Optional, Tuple
import torch import torch
from omegaconf import DictConfig from omegaconf import DictConfig
from pytorch3d.common.compat import prod from pytorch3d.common.compat import prod
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
get_default_args_field, get_default_args_field,
registry, registry,
run_auto_creation, run_auto_creation,
) )
from pytorch3d.implicitron.tools.utils import evaluating from pytorch3d.implicitron.tools.utils import evaluating
from pytorch3d.renderer import RayBundle
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
from .ray_tracing import RayTracing from .ray_tracing import RayTracing
@ -69,7 +69,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
def forward( def forward(
self, self,
ray_bundle: RayBundle, ray_bundle: ImplicitronRayBundle,
implicit_functions: List[ImplicitFunctionWrapper], implicit_functions: List[ImplicitFunctionWrapper],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
object_mask: Optional[torch.Tensor] = None, object_mask: Optional[torch.Tensor] = None,
@ -77,7 +77,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
) -> RendererOutput: ) -> RendererOutput:
""" """
Args: Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the ray_bundle: An `ImplicitronRayBundle` object containing the parametrizations of the
sampled rendering rays. sampled rendering rays.
implicit_functions: single element list of ImplicitFunctionWrappers which implicit_functions: single element list of ImplicitFunctionWrappers which
defines the implicit function to be used. defines the implicit function to be used.

View File

@ -149,9 +149,8 @@ class MultinomialRaysampler(torch.nn.Module):
in __init__. in __init__.
n_rays_total: How many rays in total to sample from the cameras provided. The result n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total_training` cameras were sampled with replacement from the is as if `n_rays_total_training` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set, this disables cameras provided and for every camera one ray was sampled. If set, returns the
`n_rays_per_image` and returns the HeterogeneousRayBundle with HeterogeneousRayBundle with batch_size=n_rays_total.
batch_size=n_rays_total.
Returns: Returns:
A named tuple RayBundle or dataclass HeterogeneousRayBundle with the A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
following fields: following fields:
@ -188,9 +187,10 @@ class MultinomialRaysampler(torch.nn.Module):
""" """
n_rays_total = n_rays_total or self._n_rays_total n_rays_total = n_rays_total or self._n_rays_total
n_rays_per_image = n_rays_per_image or self._n_rays_per_image n_rays_per_image = n_rays_per_image or self._n_rays_per_image
assert (n_rays_total is None) or ( if (n_rays_total is not None) and (n_rays_per_image is not None):
n_rays_per_image is None raise ValueError(
), "`n_rays_total` and `n_rays_per_image` cannot both be defined." "`n_rays_total` and `n_rays_per_image` cannot both be defined."
)
if n_rays_total: if n_rays_total:
( (
cameras, cameras,
@ -357,9 +357,8 @@ class MonteCarloRaysampler(torch.nn.Module):
max_depth: The maximum depth of each ray-point. max_depth: The maximum depth of each ray-point.
n_rays_total: How many rays in total to sample from the cameras provided. The result n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total_training` cameras were sampled with replacement from the is as if `n_rays_total_training` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set, this disables cameras provided and for every camera one ray was sampled. If set, this returns
`n_rays_per_image` and returns the HeterogeneousRayBundle with the HeterogeneousRayBundle with batch_size=n_rays_total.
batch_size=n_rays_total.
unit_directions: whether to normalize direction vectors in ray bundle. unit_directions: whether to normalize direction vectors in ray bundle.
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
bins for each ray; otherwise takes n_pts_per_ray deterministic points bins for each ray; otherwise takes n_pts_per_ray deterministic points
@ -416,9 +415,14 @@ class MonteCarloRaysampler(torch.nn.Module):
- camera_counts: tensor of shape (M,), where `M` is the number of unique sampled - camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
cameras. Represents how many times each camera from `camera_ids` was sampled cameras. Represents how many times each camera from `camera_ids` was sampled
""" """
assert (self._n_rays_total is None) or ( if (
self._n_rays_per_image is None sum(x is not None for x in [self._n_rays_total, self._n_rays_per_image])
), "`self.n_rays_total` and `self.n_rays_per_image` cannot both be defined." != 1
):
raise ValueError(
"Exactly one of `self.n_rays_total` and `self.n_rays_per_image` "
"must be given."
)
if self._n_rays_total: if self._n_rays_total:
( (
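
A usage sketch of the heterogeneous path this validation guards (constructor arguments shown by keyword; the exact argument set is assumed from the library at this commit): with `n_rays_total` set and `n_rays_per_image` unset, the sampler draws (camera, ray) pairs with replacement and returns a HeterogeneousRayBundle whose batch size is `n_rays_total`.

import torch
from pytorch3d.renderer import MonteCarloRaysampler, PerspectiveCameras

cameras = PerspectiveCameras(R=torch.eye(3)[None].repeat(4, 1, 1), T=torch.zeros(4, 3))
raysampler = MonteCarloRaysampler(
    min_x=-1.0, max_x=1.0,
    min_y=-1.0, max_y=1.0,
    n_rays_per_image=None,   # must stay unset when n_rays_total is used
    n_pts_per_ray=16,
    min_depth=0.1, max_depth=2.0,
    n_rays_total=32,         # 32 (camera, ray) pairs sampled with replacement
)
bundle = raysampler(cameras)
# bundle.origins has shape (n_rays_total, 1, 3): one ray per sampled camera
assert int(bundle.camera_counts.sum()) == 32  # counts cover the whole packed batch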

View File

@ -297,6 +297,7 @@ class VolumeSampler(torch.nn.Module):
""" """
Given an input ray parametrization, the forward function samples Given an input ray parametrization, the forward function samples
`self._volumes` at the respective 3D ray-points. `self._volumes` at the respective 3D ray-points.
Can also accept an ImplicitronRayBundle as the ray_bundle argument.
Args: Args:
ray_bundle: A RayBundle or HeterogeneousRayBundle object with the following fields: ray_bundle: A RayBundle or HeterogeneousRayBundle object with the following fields:

View File

@ -11,6 +11,7 @@ import plotly.graph_objects as go
import torch import torch
from plotly.subplots import make_subplots from plotly.subplots import make_subplots
from pytorch3d.renderer import ( from pytorch3d.renderer import (
HeterogeneousRayBundle,
ray_bundle_to_ray_points, ray_bundle_to_ray_points,
RayBundle, RayBundle,
TexturesAtlas, TexturesAtlas,
@ -21,14 +22,45 @@ from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import join_meshes_as_scene, Meshes, Pointclouds from pytorch3d.structures import join_meshes_as_scene, Meshes, Pointclouds
Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle] Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle, HeterogeneousRayBundle]
def _get_struct_len(struct: Struct) -> int: # pragma: no cover def _get_len(struct: Union[Struct, List[Struct]]) -> int: # pragma: no cover
""" """
Returns the length (usually corresponds to the batch size) of the input structure. Returns the length (usually corresponds to the batch size) of the input structure.
""" """
return len(struct.directions) if isinstance(struct, RayBundle) else len(struct) # pyre-ignore[6]
if not _is_ray_bundle(struct):
# pyre-ignore[6]
return len(struct)
if _is_heterogeneous_ray_bundle(struct):
# pyre-ignore[16]
return len(struct.camera_counts)
# pyre-ignore[16]
return len(struct.directions)
def _is_ray_bundle(struct: Struct) -> bool:
"""
Args:
struct: Struct object to test
Returns:
True if something is a RayBundle, HeterogeneousRayBundle or
ImplicitronRayBundle, else False
"""
return hasattr(struct, "directions")
def _is_heterogeneous_ray_bundle(struct: Union[List[Struct], Struct]) -> bool:
"""
Args:
struct: object to test
Returns:
True if something is a HeterogeneousRayBundle or ImplicitronRayBundle
and cannot be reduced to a RayBundle, else False
"""
# pyre-ignore[16]
return hasattr(struct, "camera_counts") and struct.camera_counts is not None
def get_camera_wireframe(scale: float = 0.3): # pragma: no cover def get_camera_wireframe(scale: float = 0.3): # pragma: no cover
@ -301,7 +333,7 @@ def plot_scene(
_add_camera_trace( _add_camera_trace(
fig, struct, trace_name, subplot_idx, ncols, camera_scale fig, struct, trace_name, subplot_idx, ncols, camera_scale
) )
elif isinstance(struct, RayBundle): elif _is_ray_bundle(struct):
_add_ray_bundle_trace( _add_ray_bundle_trace(
fig, fig,
struct, struct,
@ -316,7 +348,7 @@ def plot_scene(
else: else:
raise ValueError( raise ValueError(
"struct {} is not a Cameras, Meshes, Pointclouds,".format(struct) "struct {} is not a Cameras, Meshes, Pointclouds,".format(struct)
+ " or RayBundle object." + " , RayBundle or HeterogeneousRayBundle object."
) )
# Ensure update for every subplot. # Ensure update for every subplot.
@ -401,15 +433,16 @@ def plot_batch_individually(
In addition, you can include Cameras, Meshes, Pointclouds, or RayBundle of size 1 in In addition, you can include Cameras, Meshes, Pointclouds, or RayBundle of size 1 in
the input. These will either be rendered in the first subplot the input. These will either be rendered in the first subplot
(if extend_struct is False), or in every subplot. (if extend_struct is False), or in every subplot.
Here, RayBundle includes ImplicitronRayBundle and HeterogeneousRayBundle.
Args: Args:
batched_structs: a list of Cameras, Meshes, Pointclouds, and RayBundle batched_structs: a list of Cameras, Meshes, Pointclouds and RayBundle to be
to be rendered. Each structure's corresponding batch element will be rendered. Each structure's corresponding batch element will be plotted in a
plotted in a single subplot, resulting in n subplots for a batch of size n. single subplot, resulting in n subplots for a batch of size n. Every struct
Every struct should either have the same batch size or be of batch size 1. should either have the same batch size or be of batch size 1. See extend_struct
See extend_struct and the description above for how batch size 1 structs and the description above for how batch size 1 structs are handled. Also accepts
are handled. Also accepts a single Cameras, Meshes, Pointclouds, and RayBundle a single Cameras, Meshes, Pointclouds, and RayBundle object, which will have
object, which will have each individual element plotted in its own subplot. each individual element plotted in its own subplot.
viewpoint_cameras: an instance of a Cameras object providing a location viewpoint_cameras: an instance of a Cameras object providing a location
to view the plotly plot from. If the batch size is equal to view the plotly plot from. If the batch size is equal
to the number of subplots, it is a one to one mapping. to the number of subplots, it is a one to one mapping.
@ -451,20 +484,20 @@ def plot_batch_individually(
""" """
# check that every batch is the same size or is size 1 # check that every batch is the same size or is size 1
if len(batched_structs) == 0: if _get_len(batched_structs) == 0:
msg = "No structs to plot" msg = "No structs to plot"
warnings.warn(msg) warnings.warn(msg)
return return
max_size = 0 max_size = 0
if isinstance(batched_structs, list): if isinstance(batched_structs, list):
max_size = max(_get_struct_len(s) for s in batched_structs) max_size = max(_get_len(s) for s in batched_structs)
for struct in batched_structs: for struct in batched_structs:
struct_len = _get_struct_len(struct) struct_len = _get_len(struct)
if struct_len not in (1, max_size): if struct_len not in (1, max_size):
msg = "invalid batch size {} provided: {}".format(struct_len, struct) msg = "invalid batch size {} provided: {}".format(struct_len, struct)
raise ValueError(msg) raise ValueError(msg)
else: else:
max_size = _get_struct_len(batched_structs) max_size = _get_len(batched_structs)
if max_size == 0: if max_size == 0:
msg = "No data is provided with at least one element" msg = "No data is provided with at least one element"
@ -475,6 +508,14 @@ def plot_batch_individually(
msg = "invalid number of subplot titles" msg = "invalid number of subplot titles"
raise ValueError(msg) raise ValueError(msg)
# if we are dealing with a HeterogeneousRayBundle or ImplicitronRayBundle, create
# first indexes for faster per-camera slicing
first_idxs = None
if _is_heterogeneous_ray_bundle(batched_structs):
# pyre-ignore[16]
cumsum = batched_structs.camera_counts.cumsum(dim=0)
first_idxs = torch.cat((cumsum.new_zeros((1,)), cumsum))
scene_dictionary = {} scene_dictionary = {}
# construct the scene dictionary # construct the scene dictionary
for scene_num in range(max_size): for scene_num in range(max_size):
@ -487,16 +528,30 @@ def plot_batch_individually(
if isinstance(batched_structs, list): if isinstance(batched_structs, list):
for i, batched_struct in enumerate(batched_structs): for i, batched_struct in enumerate(batched_structs):
first_idxs = None
if _is_heterogeneous_ray_bundle(batched_structs[i]):
# pyre-ignore[16]
cumsum = batched_struct.camera_counts.cumsum(dim=0)
first_idxs = torch.cat((cumsum.new_zeros((1,)), cumsum))
# check for whether this struct needs to be extended # check for whether this struct needs to be extended
batched_struct_len = _get_struct_len(batched_struct) batched_struct_len = _get_len(batched_struct)
if i >= batched_struct_len and not extend_struct: if i >= batched_struct_len and not extend_struct:
continue continue
_add_struct_from_batch( _add_struct_from_batch(
batched_struct, scene_num, subplot_title, scene_dictionary, i + 1 batched_struct,
scene_num,
subplot_title,
scene_dictionary,
i + 1,
first_idxs=first_idxs,
) )
else: # batched_structs is a single struct else: # batched_structs is a single struct
_add_struct_from_batch( _add_struct_from_batch(
batched_structs, scene_num, subplot_title, scene_dictionary batched_structs,
scene_num,
subplot_title,
scene_dictionary,
first_idxs=first_idxs,
) )
return plot_scene( return plot_scene(
@ -510,6 +565,7 @@ def _add_struct_from_batch(
subplot_title: str, subplot_title: str,
scene_dictionary: Dict[str, Dict[str, Struct]], scene_dictionary: Dict[str, Dict[str, Struct]],
trace_idx: int = 1, trace_idx: int = 1,
first_idxs: Optional[torch.Tensor] = None,
) -> None: # pragma: no cover ) -> None: # pragma: no cover
""" """
Adds the struct corresponding to the given scene_num index to Adds the struct corresponding to the given scene_num index to
@ -544,17 +600,35 @@ def _add_struct_from_batch(
# torch.Tensor, torch.nn.Module]` is not a function. # torch.Tensor, torch.nn.Module]` is not a function.
T = T[t_idx].unsqueeze(0) T = T[t_idx].unsqueeze(0)
struct = CamerasBase(device=batched_struct.device, R=R, T=T) struct = CamerasBase(device=batched_struct.device, R=R, T=T)
elif isinstance(batched_struct, RayBundle): elif _is_ray_bundle(batched_struct) and not _is_heterogeneous_ray_bundle(
# for RayBundle we treat the 1st dim as the batch index batched_struct
struct_idx = min(scene_num, len(batched_struct.lengths) - 1) ):
# for a plain RayBundle we treat the 1st dim as the batch index
struct_idx = min(scene_num, _get_len(batched_struct) - 1)
struct = RayBundle( struct = RayBundle(
**{ **{
attr: getattr(batched_struct, attr)[struct_idx] attr: getattr(batched_struct, attr)[struct_idx]
for attr in ["origins", "directions", "lengths", "xys"] for attr in ["origins", "directions", "lengths", "xys"]
} }
) )
elif _is_heterogeneous_ray_bundle(batched_struct):
# for RayBundle we treat the camera count as the batch index
struct_idx = min(scene_num, _get_len(batched_struct) - 1)
struct = RayBundle(
**{
attr: getattr(batched_struct, attr)[
# pyre-ignore[16]
first_idxs[struct_idx] : first_idxs[struct_idx + 1]
]
for attr in ["origins", "directions", "lengths", "xys"]
}
)
else: # batched meshes and pointclouds are indexable else: # batched meshes and pointclouds are indexable
struct_idx = min(scene_num, len(batched_struct) - 1) struct_idx = min(scene_num, _get_len(batched_struct) - 1)
# pyre-ignore[16]
struct = batched_struct[struct_idx] struct = batched_struct[struct_idx]
trace_name = "trace{}-{}".format(scene_num + 1, trace_idx) trace_name = "trace{}-{}".format(scene_num + 1, trace_idx)
scene_dictionary[subplot_title][trace_name] = struct scene_dictionary[subplot_title][trace_name] = struct
@ -753,7 +827,7 @@ def _add_camera_trace(
def _add_ray_bundle_trace( def _add_ray_bundle_trace(
fig: go.Figure, fig: go.Figure,
ray_bundle: RayBundle, ray_bundle: Union[RayBundle, HeterogeneousRayBundle],
trace_name: str, trace_name: str,
subplot_idx: int, subplot_idx: int,
ncols: int, ncols: int,
@ -763,12 +837,13 @@ def _add_ray_bundle_trace(
line_width: int, line_width: int,
) -> None: # pragma: no cover ) -> None: # pragma: no cover
""" """
Adds a trace rendering a RayBundle object to the passed in figure, with Adds a trace rendering a ray bundle object
a given name and in a specific subplot. to the passed in figure, with a given name and in a specific subplot.
Args: Args:
fig: plotly figure to add the trace within. fig: plotly figure to add the trace within.
cameras: the Cameras object to render. It can be batched. ray_bundle: the RayBundle, ImplicitronRayBundle or HeterogeneousRayBundle to render.
It can be batched.
trace_name: name to label the trace with. trace_name: name to label the trace with.
subplot_idx: identifies the subplot, with 0 being the top left. subplot_idx: identifies the subplot, with 0 being the top left.
ncols: the number of subplots per row. ncols: the number of subplots per row.
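
The per-camera slicing in `_add_struct_from_batch` relies on cumulative `camera_counts`; a tiny sketch of the `first_idxs` computation used above (made-up counts):

import torch

camera_counts = torch.tensor([2, 3, 1])  # rays contributed by each sampled camera
cumsum = camera_counts.cumsum(dim=0)
first_idxs = torch.cat((cumsum.new_zeros((1,)), cumsum))  # tensor([0, 2, 5, 6])

# packed rays of the i-th sampled camera are rows first_idxs[i] .. first_idxs[i + 1]
scene_num = 1
start, end = int(first_idxs[scene_num]), int(first_idxs[scene_num + 1])  # 2, 5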

View File

@ -8,7 +8,7 @@ import unittest
import torch import torch
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
from pytorch3d.renderer import RayBundle from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
from tests.common_testing import TestCaseMixin from tests.common_testing import TestCaseMixin
@ -24,7 +24,14 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
add_input_samples=add_input_samples, add_input_samples=add_input_samples,
) )
lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length) lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
bundle = RayBundle(lengths=lengths, origins=None, directions=None, xys=None) bundle = ImplicitronRayBundle(
lengths=lengths,
origins=None,
directions=None,
xys=None,
camera_ids=None,
camera_counts=None,
)
weights = torch.ones(3, 25, length) weights = torch.ones(3, 25, length)
refined = ray_point_refiner(bundle, weights) refined = ray_point_refiner(bundle, weights)

View File

@ -13,8 +13,10 @@ from pytorch3d.implicitron.models.implicit_function.scene_representation_network
SRNImplicitFunction, SRNImplicitFunction,
SRNPixelGenerator, SRNPixelGenerator,
) )
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import get_default_args from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer import PerspectiveCameras, RayBundle from pytorch3d.renderer import PerspectiveCameras
from tests.common_testing import TestCaseMixin from tests.common_testing import TestCaseMixin
_BATCH_SIZE: int = 3 _BATCH_SIZE: int = 3
@ -31,12 +33,17 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
def test_pixel_generator(self): def test_pixel_generator(self):
SRNPixelGenerator() SRNPixelGenerator()
def _get_bundle(self, *, device) -> RayBundle: def _get_bundle(self, *, device) -> ImplicitronRayBundle:
origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device) origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device) directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device) lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
bundle = RayBundle( bundle = ImplicitronRayBundle(
lengths=lengths, origins=origins, directions=directions, xys=None lengths=lengths,
origins=origins,
directions=directions,
xys=None,
camera_ids=None,
camera_counts=None,
) )
return bundle return bundle