mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
ImplicitronRayBundle
Summary: new implicitronRayBundle with added cameraIDs and camera counts. Added to enable a single raybundle inside Implicitron and easier extension in the future. Since RayBundle is named tuple and RayBundleHeterogeneous is dataclass and RayBundleHeterogeneous cannot inherit RayBundle. So if there was no ImplicitronRayBundle every function that uses RayBundle now would have to use Union[RayBundle, RaybundleHeterogeneous] which is confusing and unecessary complicated. Reviewed By: bottler, kjchalup Differential Revision: D39262999 fbshipit-source-id: ece160e32f6c88c3977e408e966789bf8307af59
This commit is contained in:
parent
6ae863f301
commit
ad8907d373
@ -145,10 +145,9 @@
|
|||||||
"from pytorch3d.implicitron.dataset.dataset_base import FrameData\n",
|
"from pytorch3d.implicitron.dataset.dataset_base import FrameData\n",
|
||||||
"from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider\n",
|
"from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider\n",
|
||||||
"from pytorch3d.implicitron.models.generic_model import GenericModel\n",
|
"from pytorch3d.implicitron.models.generic_model import GenericModel\n",
|
||||||
"from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n",
|
"from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase, ImplicitronRayBundle\n",
|
||||||
"from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n",
|
"from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n",
|
||||||
"from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components\n",
|
"from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components\n",
|
||||||
"from pytorch3d.renderer import RayBundle\n",
|
|
||||||
"from pytorch3d.renderer.implicit.renderer import VolumeSampler\n",
|
"from pytorch3d.renderer.implicit.renderer import VolumeSampler\n",
|
||||||
"from pytorch3d.structures import Volumes\n",
|
"from pytorch3d.structures import Volumes\n",
|
||||||
"from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene"
|
"from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene"
|
||||||
@ -393,7 +392,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" def forward(\n",
|
" def forward(\n",
|
||||||
" self,\n",
|
" self,\n",
|
||||||
" ray_bundle: RayBundle,\n",
|
" ray_bundle: ImplicitronRayBundle,\n",
|
||||||
" fun_viewpool=None,\n",
|
" fun_viewpool=None,\n",
|
||||||
" global_code=None,\n",
|
" global_code=None,\n",
|
||||||
" ):\n",
|
" ):\n",
|
||||||
|
@ -22,6 +22,7 @@ from pytorch3d.implicitron.models.metrics import (
|
|||||||
RegularizationMetricsBase,
|
RegularizationMetricsBase,
|
||||||
ViewMetricsBase,
|
ViewMetricsBase,
|
||||||
)
|
)
|
||||||
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
from pytorch3d.implicitron.tools import image_utils, vis_utils
|
from pytorch3d.implicitron.tools import image_utils, vis_utils
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
expand_args_fields,
|
expand_args_fields,
|
||||||
@ -30,7 +31,8 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
)
|
)
|
||||||
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
|
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
|
||||||
from pytorch3d.implicitron.tools.utils import cat_dataclass
|
from pytorch3d.implicitron.tools.utils import cat_dataclass
|
||||||
from pytorch3d.renderer import RayBundle, utils as rend_utils
|
from pytorch3d.renderer import utils as rend_utils
|
||||||
|
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from visdom import Visdom
|
from visdom import Visdom
|
||||||
|
|
||||||
@ -387,7 +389,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
)
|
)
|
||||||
|
|
||||||
# (1) Sample rendering rays with the ray sampler.
|
# (1) Sample rendering rays with the ray sampler.
|
||||||
ray_bundle: RayBundle = self.raysampler( # pyre-fixme[29]
|
ray_bundle: ImplicitronRayBundle = self.raysampler( # pyre-fixme[29]
|
||||||
target_cameras,
|
target_cameras,
|
||||||
evaluation_mode,
|
evaluation_mode,
|
||||||
mask=mask_crop[:n_targets]
|
mask=mask_crop[:n_targets]
|
||||||
@ -568,14 +570,14 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
def _render(
|
def _render(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: ImplicitronRayBundle,
|
||||||
chunked_inputs: Dict[str, torch.Tensor],
|
chunked_inputs: Dict[str, torch.Tensor],
|
||||||
sampling_mode: RenderSamplingMode,
|
sampling_mode: RenderSamplingMode,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> RendererOutput:
|
) -> RendererOutput:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
ray_bundle: A `RayBundle` object containing the parametrizations of the
|
ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the
|
||||||
sampled rendering rays.
|
sampled rendering rays.
|
||||||
chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g.
|
chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g.
|
||||||
SignedDistanceFunctionRenderer requires "object_mask", shape
|
SignedDistanceFunctionRenderer requires "object_mask", shape
|
||||||
@ -899,7 +901,7 @@ def _tensor_collator(batch, new_dims) -> torch.Tensor:
|
|||||||
|
|
||||||
def _chunk_generator(
|
def _chunk_generator(
|
||||||
chunk_size: int,
|
chunk_size: int,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: ImplicitronRayBundle,
|
||||||
chunked_inputs: Dict[str, torch.Tensor],
|
chunked_inputs: Dict[str, torch.Tensor],
|
||||||
tqdm_trigger_threshold: int,
|
tqdm_trigger_threshold: int,
|
||||||
*args,
|
*args,
|
||||||
@ -932,7 +934,7 @@ def _chunk_generator(
|
|||||||
|
|
||||||
for start_idx in iter:
|
for start_idx in iter:
|
||||||
end_idx = min(start_idx + chunk_size_in_rays, n_rays)
|
end_idx = min(start_idx + chunk_size_in_rays, n_rays)
|
||||||
ray_bundle_chunk = RayBundle(
|
ray_bundle_chunk = ImplicitronRayBundle(
|
||||||
origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
|
origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
|
||||||
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
|
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
|
||||||
:, start_idx:end_idx
|
:, start_idx:end_idx
|
||||||
|
@ -7,9 +7,10 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
|
|
||||||
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from pytorch3d.renderer.implicit import RayBundle
|
|
||||||
|
|
||||||
|
|
||||||
class ImplicitFunctionBase(ABC, ReplaceableBase):
|
class ImplicitFunctionBase(ABC, ReplaceableBase):
|
||||||
@ -20,7 +21,7 @@ class ImplicitFunctionBase(ABC, ReplaceableBase):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: ImplicitronRayBundle,
|
||||||
fun_viewpool=None,
|
fun_viewpool=None,
|
||||||
camera: Optional[CamerasBase] = None,
|
camera: Optional[CamerasBase] = None,
|
||||||
global_code=None,
|
global_code=None,
|
||||||
|
@ -6,8 +6,10 @@ import math
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
from pytorch3d.implicitron.tools.config import registry
|
from pytorch3d.implicitron.tools.config import registry
|
||||||
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
|
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from .base import ImplicitFunctionBase
|
from .base import ImplicitFunctionBase
|
||||||
@ -127,7 +129,7 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
ray_bundle: Optional[RayBundle] = None,
|
ray_bundle: Optional[ImplicitronRayBundle] = None,
|
||||||
rays_points_world: Optional[torch.Tensor] = None,
|
rays_points_world: Optional[torch.Tensor] = None,
|
||||||
fun_viewpool=None,
|
fun_viewpool=None,
|
||||||
global_code=None,
|
global_code=None,
|
||||||
|
@ -9,8 +9,9 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
||||||
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
|
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
|
||||||
from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle
|
from pytorch3d.renderer import ray_bundle_to_ray_points
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||||
|
|
||||||
@ -130,7 +131,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: ImplicitronRayBundle,
|
||||||
fun_viewpool=None,
|
fun_viewpool=None,
|
||||||
camera: Optional[CamerasBase] = None,
|
camera: Optional[CamerasBase] = None,
|
||||||
global_code=None,
|
global_code=None,
|
||||||
@ -144,7 +145,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
RGB color and opacity respectively.
|
RGB color and opacity respectively.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ray_bundle: A RayBundle object containing the following variables:
|
ray_bundle: An ImplicitronRayBundle object containing the following variables:
|
||||||
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
|
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
|
||||||
origins of the sampling rays in world coords.
|
origins of the sampling rays in world coords.
|
||||||
directions: A tensor of shape `(minibatch, ..., 3)`
|
directions: A tensor of shape `(minibatch, ..., 3)`
|
||||||
@ -165,11 +166,12 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
# We first convert the ray parametrizations to world
|
# We first convert the ray parametrizations to world
|
||||||
# coordinates with `ray_bundle_to_ray_points`.
|
# coordinates with `ray_bundle_to_ray_points`.
|
||||||
|
# pyre-ignore[6]
|
||||||
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
|
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
|
||||||
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
|
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
|
||||||
|
|
||||||
embeds = create_embeddings_for_implicit_function(
|
embeds = create_embeddings_for_implicit_function(
|
||||||
xyz_world=ray_bundle_to_ray_points(ray_bundle),
|
xyz_world=rays_points_world,
|
||||||
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
|
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
|
||||||
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
|
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
|
||||||
xyz_embedding_function=self.harmonic_embedding_xyz
|
xyz_embedding_function=self.harmonic_embedding_xyz
|
||||||
|
@ -6,9 +6,10 @@ from typing import Any, cast, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
||||||
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
|
from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
|
||||||
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
|
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
|
||||||
from pytorch3d.renderer import ray_bundle_to_ray_points, RayBundle
|
from pytorch3d.renderer import ray_bundle_to_ray_points
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||||
|
|
||||||
@ -68,7 +69,7 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: ImplicitronRayBundle,
|
||||||
fun_viewpool=None,
|
fun_viewpool=None,
|
||||||
camera: Optional[CamerasBase] = None,
|
camera: Optional[CamerasBase] = None,
|
||||||
global_code=None,
|
global_code=None,
|
||||||
@ -76,7 +77,7 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
ray_bundle: A RayBundle object containing the following variables:
|
ray_bundle: An ImplicitronRayBundle object containing the following variables:
|
||||||
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
|
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
|
||||||
origins of the sampling rays in world coords.
|
origins of the sampling rays in world coords.
|
||||||
directions: A tensor of shape `(minibatch, ..., 3)`
|
directions: A tensor of shape `(minibatch, ..., 3)`
|
||||||
@ -96,10 +97,11 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
# We first convert the ray parametrizations to world
|
# We first convert the ray parametrizations to world
|
||||||
# coordinates with `ray_bundle_to_ray_points`.
|
# coordinates with `ray_bundle_to_ray_points`.
|
||||||
|
# pyre-ignore[6]
|
||||||
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
|
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
|
||||||
|
|
||||||
embeds = create_embeddings_for_implicit_function(
|
embeds = create_embeddings_for_implicit_function(
|
||||||
xyz_world=ray_bundle_to_ray_points(ray_bundle),
|
xyz_world=rays_points_world,
|
||||||
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
|
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
|
||||||
# for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`.
|
# for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`.
|
||||||
xyz_embedding_function=self._harmonic_embedding,
|
xyz_embedding_function=self._harmonic_embedding,
|
||||||
@ -175,7 +177,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
raymarch_features: torch.Tensor,
|
raymarch_features: torch.Tensor,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: ImplicitronRayBundle,
|
||||||
camera: Optional[CamerasBase] = None,
|
camera: Optional[CamerasBase] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -183,7 +185,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
raymarch_features: Features from the raymarching network of shape
|
raymarch_features: Features from the raymarching network of shape
|
||||||
`(minibatch, ..., self.in_features)`
|
`(minibatch, ..., self.in_features)`
|
||||||
ray_bundle: A RayBundle object containing the following variables:
|
ray_bundle: An ImplicitronRayBundle object containing the following variables:
|
||||||
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
|
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
|
||||||
origins of the sampling rays in world coords.
|
origins of the sampling rays in world coords.
|
||||||
directions: A tensor of shape `(minibatch, ..., 3)`
|
directions: A tensor of shape `(minibatch, ..., 3)`
|
||||||
@ -297,7 +299,7 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: ImplicitronRayBundle,
|
||||||
fun_viewpool=None,
|
fun_viewpool=None,
|
||||||
camera: Optional[CamerasBase] = None,
|
camera: Optional[CamerasBase] = None,
|
||||||
global_code=None,
|
global_code=None,
|
||||||
@ -350,7 +352,7 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: ImplicitronRayBundle,
|
||||||
fun_viewpool=None,
|
fun_viewpool=None,
|
||||||
camera: Optional[CamerasBase] = None,
|
camera: Optional[CamerasBase] = None,
|
||||||
global_code=None,
|
global_code=None,
|
||||||
@ -410,7 +412,7 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: ImplicitronRayBundle,
|
||||||
fun_viewpool=None,
|
fun_viewpool=None,
|
||||||
camera: Optional[CamerasBase] = None,
|
camera: Optional[CamerasBase] = None,
|
||||||
global_code=None,
|
global_code=None,
|
||||||
|
@ -10,9 +10,9 @@ import torch
|
|||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from pytorch3d.common.compat import prod
|
from pytorch3d.common.compat import prod
|
||||||
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
from pytorch3d.renderer import ray_bundle_to_ray_points
|
from pytorch3d.renderer import ray_bundle_to_ray_points
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from pytorch3d.renderer.implicit import RayBundle
|
|
||||||
|
|
||||||
|
|
||||||
def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
|
def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
|
||||||
@ -190,7 +190,7 @@ def interpolate_volume(
|
|||||||
|
|
||||||
|
|
||||||
def get_rays_points_world(
|
def get_rays_points_world(
|
||||||
ray_bundle: Optional[RayBundle] = None,
|
ray_bundle: Optional[ImplicitronRayBundle] = None,
|
||||||
rays_points_world: Optional[torch.Tensor] = None,
|
rays_points_world: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@ -198,7 +198,7 @@ def get_rays_points_world(
|
|||||||
and raises error if both are defined.
|
and raises error if both are defined.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ray_bundle: A RayBundle object or None
|
ray_bundle: An ImplicitronRayBundle object or None
|
||||||
rays_points_world: A torch.Tensor representing ray points converted to
|
rays_points_world: A torch.Tensor representing ray points converted to
|
||||||
world coordinates
|
world coordinates
|
||||||
Returns:
|
Returns:
|
||||||
@ -213,5 +213,6 @@ def get_rays_points_world(
|
|||||||
if rays_points_world is not None:
|
if rays_points_world is not None:
|
||||||
return rays_points_world
|
return rays_points_world
|
||||||
if ray_bundle is not None:
|
if ray_bundle is not None:
|
||||||
|
# pyre-ignore[6]
|
||||||
return ray_bundle_to_ray_points(ray_bundle)
|
return ray_bundle_to_ray_points(ray_bundle)
|
||||||
raise ValueError("ray_bundle and rays_points_world cannot both be None")
|
raise ValueError("ray_bundle and rays_points_world cannot both be None")
|
||||||
|
@ -6,6 +6,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -25,6 +27,38 @@ class RenderSamplingMode(Enum):
|
|||||||
FULL_GRID = "full_grid"
|
FULL_GRID = "full_grid"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ImplicitronRayBundle:
|
||||||
|
"""
|
||||||
|
Parametrizes points along projection rays by storing ray `origins`,
|
||||||
|
`directions` vectors and `lengths` at which the ray-points are sampled.
|
||||||
|
Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well.
|
||||||
|
Note that `directions` don't have to be normalized; they define unit vectors
|
||||||
|
in the respective 1D coordinate systems; see documentation for
|
||||||
|
:func:`ray_bundle_to_ray_points` for the conversion formula.
|
||||||
|
|
||||||
|
camera_ids: A tensor of shape (N, ) which indicates which camera
|
||||||
|
was used to sample the rays. `N` is the number of different
|
||||||
|
sampled cameras.
|
||||||
|
camera_counts: A tensor of shape (N, ) which how many times the
|
||||||
|
coresponding camera in `camera_ids` was sampled.
|
||||||
|
`sum(camera_counts)==minibatch`
|
||||||
|
"""
|
||||||
|
|
||||||
|
origins: torch.Tensor
|
||||||
|
directions: torch.Tensor
|
||||||
|
lengths: torch.Tensor
|
||||||
|
xys: torch.Tensor
|
||||||
|
camera_ids: Optional[torch.Tensor] = None
|
||||||
|
camera_counts: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
def is_packed(self) -> bool:
|
||||||
|
"""
|
||||||
|
Returns whether the ImplicitronRayBundle carries data in packed state
|
||||||
|
"""
|
||||||
|
return self.camera_ids is not None and self.camera_counts is not None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RendererOutput:
|
class RendererOutput:
|
||||||
"""
|
"""
|
||||||
@ -85,7 +119,7 @@ class BaseRenderer(ABC, ReplaceableBase):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ray_bundle,
|
ray_bundle: ImplicitronRayBundle,
|
||||||
implicit_functions: List[ImplicitFunctionWrapper],
|
implicit_functions: List[ImplicitFunctionWrapper],
|
||||||
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -95,7 +129,7 @@ class BaseRenderer(ABC, ReplaceableBase):
|
|||||||
that returns an instance of RendererOutput.
|
that returns an instance of RendererOutput.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ray_bundle: A RayBundle object containing the following variables:
|
ray_bundle: An ImplicitronRayBundle object containing the following variables:
|
||||||
origins: A tensor of shape (minibatch, ..., 3) denoting
|
origins: A tensor of shape (minibatch, ..., 3) denoting
|
||||||
the origins of the rendering rays.
|
the origins of the rendering rays.
|
||||||
directions: A tensor of shape (minibatch, ..., 3)
|
directions: A tensor of shape (minibatch, ..., 3)
|
||||||
@ -108,6 +142,12 @@ class BaseRenderer(ABC, ReplaceableBase):
|
|||||||
xys: A tensor of shape
|
xys: A tensor of shape
|
||||||
(minibatch, ..., 2) containing the
|
(minibatch, ..., 2) containing the
|
||||||
xy locations of each ray's pixel in the NDC screen space.
|
xy locations of each ray's pixel in the NDC screen space.
|
||||||
|
camera_ids: A tensor of shape (N, ) which indicates which camera
|
||||||
|
was used to sample the rays. `N` is the number of different
|
||||||
|
sampled cameras.
|
||||||
|
camera_counts: A tensor of shape (N, ) which how many times the
|
||||||
|
coresponding camera in `camera_ids` was sampled.
|
||||||
|
`sum(camera_counts)==minibatch`
|
||||||
implicit_functions: List of ImplicitFunctionWrappers which define the
|
implicit_functions: List of ImplicitFunctionWrappers which define the
|
||||||
implicit function methods to be used. Most Renderers only allow
|
implicit function methods to be used. Most Renderers only allow
|
||||||
a single implicit function. Currently, only the
|
a single implicit function. Currently, only the
|
||||||
|
@ -4,12 +4,13 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
from pytorch3d.implicitron.tools.config import registry
|
from pytorch3d.implicitron.tools.config import registry
|
||||||
from pytorch3d.renderer import RayBundle
|
|
||||||
|
|
||||||
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
|
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
|
||||||
|
|
||||||
@ -71,7 +72,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: ImplicitronRayBundle,
|
||||||
implicit_functions: List[ImplicitFunctionWrapper],
|
implicit_functions: List[ImplicitFunctionWrapper],
|
||||||
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@ -79,7 +80,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ray_bundle: A `RayBundle` object containing the parametrizations of the
|
ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the
|
||||||
sampled rendering rays.
|
sampled rendering rays.
|
||||||
implicit_functions: A single-element list of ImplicitFunctionWrappers which
|
implicit_functions: A single-element list of ImplicitFunctionWrappers which
|
||||||
defines the implicit function to be used.
|
defines the implicit function to be used.
|
||||||
@ -102,9 +103,12 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# jitter the initial depths
|
# jitter the initial depths
|
||||||
ray_bundle_t = ray_bundle._replace(
|
ray_bundle_t = dataclasses.replace(
|
||||||
lengths=ray_bundle.lengths
|
ray_bundle,
|
||||||
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
|
lengths=(
|
||||||
|
ray_bundle.lengths
|
||||||
|
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]
|
states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]
|
||||||
@ -112,9 +116,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
raymarch_features = None
|
raymarch_features = None
|
||||||
for t in range(self.num_raymarch_steps + 1):
|
for t in range(self.num_raymarch_steps + 1):
|
||||||
# move signed_distance along each ray
|
# move signed_distance along each ray
|
||||||
ray_bundle_t = ray_bundle_t._replace(
|
ray_bundle_t.lengths += signed_distance
|
||||||
lengths=ray_bundle_t.lengths + signed_distance
|
|
||||||
)
|
|
||||||
|
|
||||||
# eval the raymarching function
|
# eval the raymarching function
|
||||||
raymarch_features, _ = implicit_function(
|
raymarch_features, _ = implicit_function(
|
||||||
|
@ -7,8 +7,8 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
|
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
|
||||||
from pytorch3d.renderer import RayBundle
|
|
||||||
|
|
||||||
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
|
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
|
||||||
from .ray_point_refiner import RayPointRefiner
|
from .ray_point_refiner import RayPointRefiner
|
||||||
@ -107,14 +107,14 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: ImplicitronRayBundle,
|
||||||
implicit_functions: List[ImplicitFunctionWrapper],
|
implicit_functions: List[ImplicitFunctionWrapper],
|
||||||
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> RendererOutput:
|
) -> RendererOutput:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
ray_bundle: A `RayBundle` object containing the parametrizations of the
|
ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the
|
||||||
sampled rendering rays.
|
sampled rendering rays.
|
||||||
implicit_functions: List of ImplicitFunctionWrappers which
|
implicit_functions: List of ImplicitFunctionWrappers which
|
||||||
define the implicit functions to be used sequentially in
|
define the implicit functions to be used sequentially in
|
||||||
|
@ -9,10 +9,10 @@ from typing import Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.tools import camera_utils
|
from pytorch3d.implicitron.tools import camera_utils
|
||||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle
|
from pytorch3d.renderer import NDCMultinomialRaysampler
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
from .base import EvaluationMode, RenderSamplingMode
|
from .base import EvaluationMode, ImplicitronRayBundle, RenderSamplingMode
|
||||||
|
|
||||||
|
|
||||||
class RaySamplerBase(ReplaceableBase):
|
class RaySamplerBase(ReplaceableBase):
|
||||||
@ -28,7 +28,7 @@ class RaySamplerBase(ReplaceableBase):
|
|||||||
cameras: CamerasBase,
|
cameras: CamerasBase,
|
||||||
evaluation_mode: EvaluationMode,
|
evaluation_mode: EvaluationMode,
|
||||||
mask: Optional[torch.Tensor] = None,
|
mask: Optional[torch.Tensor] = None,
|
||||||
) -> RayBundle:
|
) -> ImplicitronRayBundle:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
cameras: A batch of `batch_size` cameras from which the rays are emitted.
|
cameras: A batch of `batch_size` cameras from which the rays are emitted.
|
||||||
@ -42,7 +42,7 @@ class RaySamplerBase(ReplaceableBase):
|
|||||||
corresponding pixel's ray.
|
corresponding pixel's ray.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ray_bundle: A `RayBundle` object containing the parametrizations of the
|
ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the
|
||||||
sampled rendering rays.
|
sampled rendering rays.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@ -135,7 +135,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
|||||||
cameras: CamerasBase,
|
cameras: CamerasBase,
|
||||||
evaluation_mode: EvaluationMode,
|
evaluation_mode: EvaluationMode,
|
||||||
mask: Optional[torch.Tensor] = None,
|
mask: Optional[torch.Tensor] = None,
|
||||||
) -> RayBundle:
|
) -> ImplicitronRayBundle:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -150,7 +150,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
|||||||
corresponding pixel's ray.
|
corresponding pixel's ray.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ray_bundle: A `RayBundle` object containing the parametrizations of the
|
ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the
|
||||||
sampled rendering rays.
|
sampled rendering rays.
|
||||||
"""
|
"""
|
||||||
sample_mask = None
|
sample_mask = None
|
||||||
@ -180,7 +180,19 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
|||||||
max_depth=max_depth,
|
max_depth=max_depth,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ray_bundle
|
if isinstance(ray_bundle, tuple):
|
||||||
|
return ImplicitronRayBundle(
|
||||||
|
# pyre-ignore[16]
|
||||||
|
**{k: v for k, v in ray_bundle._asdict().items()}
|
||||||
|
)
|
||||||
|
return ImplicitronRayBundle(
|
||||||
|
directions=ray_bundle.directions,
|
||||||
|
origins=ray_bundle.origins,
|
||||||
|
lengths=ray_bundle.lengths,
|
||||||
|
xys=ray_bundle.xys,
|
||||||
|
camera_ids=ray_bundle.camera_ids,
|
||||||
|
camera_counts=ray_bundle.camera_counts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
|
@ -7,8 +7,10 @@ import logging
|
|||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||||
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
|
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
@ -89,7 +91,7 @@ class RayNormalColoringNetwork(torch.nn.Module):
|
|||||||
feature_vectors: torch.Tensor,
|
feature_vectors: torch.Tensor,
|
||||||
points,
|
points,
|
||||||
normals,
|
normals,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: ImplicitronRayBundle,
|
||||||
masks=None,
|
masks=None,
|
||||||
pooling_fn=None,
|
pooling_fn=None,
|
||||||
):
|
):
|
||||||
|
@ -8,13 +8,13 @@ from typing import List, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from pytorch3d.common.compat import prod
|
from pytorch3d.common.compat import prod
|
||||||
|
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
get_default_args_field,
|
get_default_args_field,
|
||||||
registry,
|
registry,
|
||||||
run_auto_creation,
|
run_auto_creation,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.tools.utils import evaluating
|
from pytorch3d.implicitron.tools.utils import evaluating
|
||||||
from pytorch3d.renderer import RayBundle
|
|
||||||
|
|
||||||
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
|
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
|
||||||
from .ray_tracing import RayTracing
|
from .ray_tracing import RayTracing
|
||||||
@ -69,7 +69,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: ImplicitronRayBundle,
|
||||||
implicit_functions: List[ImplicitFunctionWrapper],
|
implicit_functions: List[ImplicitFunctionWrapper],
|
||||||
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||||
object_mask: Optional[torch.Tensor] = None,
|
object_mask: Optional[torch.Tensor] = None,
|
||||||
@ -77,7 +77,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
|
|||||||
) -> RendererOutput:
|
) -> RendererOutput:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
ray_bundle: A `RayBundle` object containing the parametrizations of the
|
ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the
|
||||||
sampled rendering rays.
|
sampled rendering rays.
|
||||||
implicit_functions: single element list of ImplicitFunctionWrappers which
|
implicit_functions: single element list of ImplicitFunctionWrappers which
|
||||||
defines the implicit function to be used.
|
defines the implicit function to be used.
|
||||||
|
@ -149,9 +149,8 @@ class MultinomialRaysampler(torch.nn.Module):
|
|||||||
in __init__.
|
in __init__.
|
||||||
n_rays_total: How many rays in total to sample from the cameras provided. The result
|
n_rays_total: How many rays in total to sample from the cameras provided. The result
|
||||||
is as if `n_rays_total_training` cameras were sampled with replacement from the
|
is as if `n_rays_total_training` cameras were sampled with replacement from the
|
||||||
cameras provided and for every camera one ray was sampled. If set, this disables
|
cameras provided and for every camera one ray was sampled. If set, returns the
|
||||||
`n_rays_per_image` and returns the HeterogeneousRayBundle with
|
HeterogeneousRayBundle with batch_size=n_rays_total.
|
||||||
batch_size=n_rays_total.
|
|
||||||
Returns:
|
Returns:
|
||||||
A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
|
A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
|
||||||
following fields:
|
following fields:
|
||||||
@ -188,9 +187,10 @@ class MultinomialRaysampler(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
n_rays_total = n_rays_total or self._n_rays_total
|
n_rays_total = n_rays_total or self._n_rays_total
|
||||||
n_rays_per_image = n_rays_per_image or self._n_rays_per_image
|
n_rays_per_image = n_rays_per_image or self._n_rays_per_image
|
||||||
assert (n_rays_total is None) or (
|
if (n_rays_total is not None) and (n_rays_per_image is not None):
|
||||||
n_rays_per_image is None
|
raise ValueError(
|
||||||
), "`n_rays_total` and `n_rays_per_image` cannot both be defined."
|
"`n_rays_total` and `n_rays_per_image` cannot both be defined."
|
||||||
|
)
|
||||||
if n_rays_total:
|
if n_rays_total:
|
||||||
(
|
(
|
||||||
cameras,
|
cameras,
|
||||||
@ -357,9 +357,8 @@ class MonteCarloRaysampler(torch.nn.Module):
|
|||||||
max_depth: The maximum depth of each ray-point.
|
max_depth: The maximum depth of each ray-point.
|
||||||
n_rays_total: How many rays in total to sample from the cameras provided. The result
|
n_rays_total: How many rays in total to sample from the cameras provided. The result
|
||||||
is as if `n_rays_total_training` cameras were sampled with replacement from the
|
is as if `n_rays_total_training` cameras were sampled with replacement from the
|
||||||
cameras provided and for every camera one ray was sampled. If set, this disables
|
cameras provided and for every camera one ray was sampled. If set, this returns
|
||||||
`n_rays_per_image` and returns the HeterogeneousRayBundleyBundle with
|
the HeterogeneousRayBundleyBundle with batch_size=n_rays_total.
|
||||||
batch_size=n_rays_total.
|
|
||||||
unit_directions: whether to normalize direction vectors in ray bundle.
|
unit_directions: whether to normalize direction vectors in ray bundle.
|
||||||
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
|
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
|
||||||
bins for each ray; otherwise takes n_pts_per_ray deterministic points
|
bins for each ray; otherwise takes n_pts_per_ray deterministic points
|
||||||
@ -416,9 +415,14 @@ class MonteCarloRaysampler(torch.nn.Module):
|
|||||||
- camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
|
- camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
|
||||||
cameras. Represents how many times each camera from `camera_ids` was sampled
|
cameras. Represents how many times each camera from `camera_ids` was sampled
|
||||||
"""
|
"""
|
||||||
assert (self._n_rays_total is None) or (
|
if (
|
||||||
self._n_rays_per_image is None
|
sum(x is not None for x in [self._n_rays_total, self._n_rays_per_image])
|
||||||
), "`self.n_rays_total` and `self.n_rays_per_image` cannot both be defined."
|
!= 1
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Exactly one of `self.n_rays_total` and `self.n_rays_per_image` "
|
||||||
|
"must be given."
|
||||||
|
)
|
||||||
|
|
||||||
if self._n_rays_total:
|
if self._n_rays_total:
|
||||||
(
|
(
|
||||||
|
@ -297,6 +297,7 @@ class VolumeSampler(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
Given an input ray parametrization, the forward function samples
|
Given an input ray parametrization, the forward function samples
|
||||||
`self._volumes` at the respective 3D ray-points.
|
`self._volumes` at the respective 3D ray-points.
|
||||||
|
Can also accept ImplicitronRayBundle as argument for ray_bundle.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ray_bundle: A RayBundle or HeterogeneousRayBundle object with the following fields:
|
ray_bundle: A RayBundle or HeterogeneousRayBundle object with the following fields:
|
||||||
|
@ -11,6 +11,7 @@ import plotly.graph_objects as go
|
|||||||
import torch
|
import torch
|
||||||
from plotly.subplots import make_subplots
|
from plotly.subplots import make_subplots
|
||||||
from pytorch3d.renderer import (
|
from pytorch3d.renderer import (
|
||||||
|
HeterogeneousRayBundle,
|
||||||
ray_bundle_to_ray_points,
|
ray_bundle_to_ray_points,
|
||||||
RayBundle,
|
RayBundle,
|
||||||
TexturesAtlas,
|
TexturesAtlas,
|
||||||
@ -21,14 +22,45 @@ from pytorch3d.renderer.cameras import CamerasBase
|
|||||||
from pytorch3d.structures import join_meshes_as_scene, Meshes, Pointclouds
|
from pytorch3d.structures import join_meshes_as_scene, Meshes, Pointclouds
|
||||||
|
|
||||||
|
|
||||||
Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle]
|
Struct = Union[CamerasBase, Meshes, Pointclouds, RayBundle, HeterogeneousRayBundle]
|
||||||
|
|
||||||
|
|
||||||
def _get_struct_len(struct: Struct) -> int: # pragma: no cover
|
def _get_len(struct: Union[Struct, List[Struct]]) -> int: # pragma: no cover
|
||||||
"""
|
"""
|
||||||
Returns the length (usually corresponds to the batch size) of the input structure.
|
Returns the length (usually corresponds to the batch size) of the input structure.
|
||||||
"""
|
"""
|
||||||
return len(struct.directions) if isinstance(struct, RayBundle) else len(struct)
|
# pyre-ignore[6]
|
||||||
|
if not _is_ray_bundle(struct):
|
||||||
|
# pyre-ignore[6]
|
||||||
|
return len(struct)
|
||||||
|
if _is_heterogeneous_ray_bundle(struct):
|
||||||
|
# pyre-ignore[16]
|
||||||
|
return len(struct.camera_counts)
|
||||||
|
# pyre-ignore[16]
|
||||||
|
return len(struct.directions)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_ray_bundle(struct: Struct) -> bool:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
struct: Struct object to test
|
||||||
|
Returns:
|
||||||
|
True if something is a RayBundle, HeterogeneousRayBundle or
|
||||||
|
ImplicitronRayBundle, else False
|
||||||
|
"""
|
||||||
|
return hasattr(struct, "directions")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_heterogeneous_ray_bundle(struct: Union[List[Struct], Struct]) -> bool:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
struct :object to test
|
||||||
|
Returns:
|
||||||
|
True if something is a HeterogeneousRayBundle or ImplicitronRayBundle
|
||||||
|
and cant be reduced to RayBundle else False
|
||||||
|
"""
|
||||||
|
# pyre-ignore[16]
|
||||||
|
return hasattr(struct, "camera_counts") and struct.camera_counts is not None
|
||||||
|
|
||||||
|
|
||||||
def get_camera_wireframe(scale: float = 0.3): # pragma: no cover
|
def get_camera_wireframe(scale: float = 0.3): # pragma: no cover
|
||||||
@ -301,7 +333,7 @@ def plot_scene(
|
|||||||
_add_camera_trace(
|
_add_camera_trace(
|
||||||
fig, struct, trace_name, subplot_idx, ncols, camera_scale
|
fig, struct, trace_name, subplot_idx, ncols, camera_scale
|
||||||
)
|
)
|
||||||
elif isinstance(struct, RayBundle):
|
elif _is_ray_bundle(struct):
|
||||||
_add_ray_bundle_trace(
|
_add_ray_bundle_trace(
|
||||||
fig,
|
fig,
|
||||||
struct,
|
struct,
|
||||||
@ -316,7 +348,7 @@ def plot_scene(
|
|||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"struct {} is not a Cameras, Meshes, Pointclouds,".format(struct)
|
"struct {} is not a Cameras, Meshes, Pointclouds,".format(struct)
|
||||||
+ " or RayBundle object."
|
+ " , RayBundle or HeterogeneousRayBundle object."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure update for every subplot.
|
# Ensure update for every subplot.
|
||||||
@ -401,15 +433,16 @@ def plot_batch_individually(
|
|||||||
In addition, you can include Cameras, Meshes, Pointclouds, or RayBundle of size 1 in
|
In addition, you can include Cameras, Meshes, Pointclouds, or RayBundle of size 1 in
|
||||||
the input. These will either be rendered in the first subplot
|
the input. These will either be rendered in the first subplot
|
||||||
(if extend_struct is False), or in every subplot.
|
(if extend_struct is False), or in every subplot.
|
||||||
|
RayBundle includes ImplicitronRayBundle and HeterogeneousRaybundle.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
batched_structs: a list of Cameras, Meshes, Pointclouds, and RayBundle
|
batched_structs: a list of Cameras, Meshes, Pointclouds and RayBundle to be
|
||||||
to be rendered. Each structure's corresponding batch element will be
|
rendered. Each structure's corresponding batch element will be plotted in a
|
||||||
plotted in a single subplot, resulting in n subplots for a batch of size n.
|
single subplot, resulting in n subplots for a batch of size n. Every struct
|
||||||
Every struct should either have the same batch size or be of batch size 1.
|
should either have the same batch size or be of batch size 1. See extend_struct
|
||||||
See extend_struct and the description above for how batch size 1 structs
|
and the description above for how batch size 1 structs are handled. Also accepts
|
||||||
are handled. Also accepts a single Cameras, Meshes, Pointclouds, and RayBundle
|
a single Cameras, Meshes, Pointclouds, and RayBundle object, which will have
|
||||||
object, which will have each individual element plotted in its own subplot.
|
each individual element plotted in its own subplot.
|
||||||
viewpoint_cameras: an instance of a Cameras object providing a location
|
viewpoint_cameras: an instance of a Cameras object providing a location
|
||||||
to view the plotly plot from. If the batch size is equal
|
to view the plotly plot from. If the batch size is equal
|
||||||
to the number of subplots, it is a one to one mapping.
|
to the number of subplots, it is a one to one mapping.
|
||||||
@ -451,20 +484,20 @@ def plot_batch_individually(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# check that every batch is the same size or is size 1
|
# check that every batch is the same size or is size 1
|
||||||
if len(batched_structs) == 0:
|
if _get_len(batched_structs) == 0:
|
||||||
msg = "No structs to plot"
|
msg = "No structs to plot"
|
||||||
warnings.warn(msg)
|
warnings.warn(msg)
|
||||||
return
|
return
|
||||||
max_size = 0
|
max_size = 0
|
||||||
if isinstance(batched_structs, list):
|
if isinstance(batched_structs, list):
|
||||||
max_size = max(_get_struct_len(s) for s in batched_structs)
|
max_size = max(_get_len(s) for s in batched_structs)
|
||||||
for struct in batched_structs:
|
for struct in batched_structs:
|
||||||
struct_len = _get_struct_len(struct)
|
struct_len = _get_len(struct)
|
||||||
if struct_len not in (1, max_size):
|
if struct_len not in (1, max_size):
|
||||||
msg = "invalid batch size {} provided: {}".format(struct_len, struct)
|
msg = "invalid batch size {} provided: {}".format(struct_len, struct)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
else:
|
else:
|
||||||
max_size = _get_struct_len(batched_structs)
|
max_size = _get_len(batched_structs)
|
||||||
|
|
||||||
if max_size == 0:
|
if max_size == 0:
|
||||||
msg = "No data is provided with at least one element"
|
msg = "No data is provided with at least one element"
|
||||||
@ -475,6 +508,14 @@ def plot_batch_individually(
|
|||||||
msg = "invalid number of subplot titles"
|
msg = "invalid number of subplot titles"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
# if we are dealing with HeterogeneousRayBundle of ImplicitronRayBundle create
|
||||||
|
# first indexes for faster
|
||||||
|
first_idxs = None
|
||||||
|
if _is_heterogeneous_ray_bundle(batched_structs):
|
||||||
|
# pyre-ignore[16]
|
||||||
|
cumsum = batched_structs.camera_counts.cumsum(dim=0)
|
||||||
|
first_idxs = torch.cat((cumsum.new_zeros((1,)), cumsum))
|
||||||
|
|
||||||
scene_dictionary = {}
|
scene_dictionary = {}
|
||||||
# construct the scene dictionary
|
# construct the scene dictionary
|
||||||
for scene_num in range(max_size):
|
for scene_num in range(max_size):
|
||||||
@ -487,16 +528,30 @@ def plot_batch_individually(
|
|||||||
|
|
||||||
if isinstance(batched_structs, list):
|
if isinstance(batched_structs, list):
|
||||||
for i, batched_struct in enumerate(batched_structs):
|
for i, batched_struct in enumerate(batched_structs):
|
||||||
|
first_idxs = None
|
||||||
|
if _is_heterogeneous_ray_bundle(batched_structs[i]):
|
||||||
|
# pyre-ignore[16]
|
||||||
|
cumsum = batched_struct.camera_counts.cumsum(dim=0)
|
||||||
|
first_idxs = torch.cat((cumsum.new_zeros((1,)), cumsum))
|
||||||
# check for whether this struct needs to be extended
|
# check for whether this struct needs to be extended
|
||||||
batched_struct_len = _get_struct_len(batched_struct)
|
batched_struct_len = _get_len(batched_struct)
|
||||||
if i >= batched_struct_len and not extend_struct:
|
if i >= batched_struct_len and not extend_struct:
|
||||||
continue
|
continue
|
||||||
_add_struct_from_batch(
|
_add_struct_from_batch(
|
||||||
batched_struct, scene_num, subplot_title, scene_dictionary, i + 1
|
batched_struct,
|
||||||
|
scene_num,
|
||||||
|
subplot_title,
|
||||||
|
scene_dictionary,
|
||||||
|
i + 1,
|
||||||
|
first_idxs=first_idxs,
|
||||||
)
|
)
|
||||||
else: # batched_structs is a single struct
|
else: # batched_structs is a single struct
|
||||||
_add_struct_from_batch(
|
_add_struct_from_batch(
|
||||||
batched_structs, scene_num, subplot_title, scene_dictionary
|
batched_structs,
|
||||||
|
scene_num,
|
||||||
|
subplot_title,
|
||||||
|
scene_dictionary,
|
||||||
|
first_idxs=first_idxs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return plot_scene(
|
return plot_scene(
|
||||||
@ -510,6 +565,7 @@ def _add_struct_from_batch(
|
|||||||
subplot_title: str,
|
subplot_title: str,
|
||||||
scene_dictionary: Dict[str, Dict[str, Struct]],
|
scene_dictionary: Dict[str, Dict[str, Struct]],
|
||||||
trace_idx: int = 1,
|
trace_idx: int = 1,
|
||||||
|
first_idxs: Optional[torch.Tensor] = None,
|
||||||
) -> None: # pragma: no cover
|
) -> None: # pragma: no cover
|
||||||
"""
|
"""
|
||||||
Adds the struct corresponding to the given scene_num index to
|
Adds the struct corresponding to the given scene_num index to
|
||||||
@ -544,17 +600,35 @@ def _add_struct_from_batch(
|
|||||||
# torch.Tensor, torch.nn.Module]` is not a function.
|
# torch.Tensor, torch.nn.Module]` is not a function.
|
||||||
T = T[t_idx].unsqueeze(0)
|
T = T[t_idx].unsqueeze(0)
|
||||||
struct = CamerasBase(device=batched_struct.device, R=R, T=T)
|
struct = CamerasBase(device=batched_struct.device, R=R, T=T)
|
||||||
elif isinstance(batched_struct, RayBundle):
|
elif _is_ray_bundle(batched_struct) and not _is_heterogeneous_ray_bundle(
|
||||||
# for RayBundle we treat the 1st dim as the batch index
|
batched_struct
|
||||||
struct_idx = min(scene_num, len(batched_struct.lengths) - 1)
|
):
|
||||||
|
# for RayBundle we treat the camera count as the batch index
|
||||||
|
struct_idx = min(scene_num, _get_len(batched_struct) - 1)
|
||||||
|
|
||||||
struct = RayBundle(
|
struct = RayBundle(
|
||||||
**{
|
**{
|
||||||
attr: getattr(batched_struct, attr)[struct_idx]
|
attr: getattr(batched_struct, attr)[struct_idx]
|
||||||
for attr in ["origins", "directions", "lengths", "xys"]
|
for attr in ["origins", "directions", "lengths", "xys"]
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
elif _is_heterogeneous_ray_bundle(batched_struct):
|
||||||
|
# for RayBundle we treat the camera count as the batch index
|
||||||
|
struct_idx = min(scene_num, _get_len(batched_struct) - 1)
|
||||||
|
|
||||||
|
struct = RayBundle(
|
||||||
|
**{
|
||||||
|
attr: getattr(batched_struct, attr)[
|
||||||
|
# pyre-ignore[16]
|
||||||
|
first_idxs[struct_idx] : first_idxs[struct_idx + 1]
|
||||||
|
]
|
||||||
|
for attr in ["origins", "directions", "lengths", "xys"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
else: # batched meshes and pointclouds are indexable
|
else: # batched meshes and pointclouds are indexable
|
||||||
struct_idx = min(scene_num, len(batched_struct) - 1)
|
struct_idx = min(scene_num, _get_len(batched_struct) - 1)
|
||||||
|
# pyre-ignore[16]
|
||||||
struct = batched_struct[struct_idx]
|
struct = batched_struct[struct_idx]
|
||||||
trace_name = "trace{}-{}".format(scene_num + 1, trace_idx)
|
trace_name = "trace{}-{}".format(scene_num + 1, trace_idx)
|
||||||
scene_dictionary[subplot_title][trace_name] = struct
|
scene_dictionary[subplot_title][trace_name] = struct
|
||||||
@ -753,7 +827,7 @@ def _add_camera_trace(
|
|||||||
|
|
||||||
def _add_ray_bundle_trace(
|
def _add_ray_bundle_trace(
|
||||||
fig: go.Figure,
|
fig: go.Figure,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: Union[RayBundle, HeterogeneousRayBundle],
|
||||||
trace_name: str,
|
trace_name: str,
|
||||||
subplot_idx: int,
|
subplot_idx: int,
|
||||||
ncols: int,
|
ncols: int,
|
||||||
@ -763,12 +837,13 @@ def _add_ray_bundle_trace(
|
|||||||
line_width: int,
|
line_width: int,
|
||||||
) -> None: # pragma: no cover
|
) -> None: # pragma: no cover
|
||||||
"""
|
"""
|
||||||
Adds a trace rendering a RayBundle object to the passed in figure, with
|
Adds a trace rendering a ray bundle object
|
||||||
a given name and in a specific subplot.
|
to the passed in figure, with a given name and in a specific subplot.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
fig: plotly figure to add the trace within.
|
fig: plotly figure to add the trace within.
|
||||||
cameras: the Cameras object to render. It can be batched.
|
ray_bundle: the RayBundle, ImplicitronRayBundle or HeterogeneousRaybundle to render.
|
||||||
|
It can be batched.
|
||||||
trace_name: name to label the trace with.
|
trace_name: name to label the trace with.
|
||||||
subplot_idx: identifies the subplot, with 0 being the top left.
|
subplot_idx: identifies the subplot, with 0 being the top left.
|
||||||
ncols: the number of subplots per row.
|
ncols: the number of subplots per row.
|
||||||
|
@ -8,7 +8,7 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
|
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
|
||||||
from pytorch3d.renderer import RayBundle
|
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
|
||||||
from tests.common_testing import TestCaseMixin
|
from tests.common_testing import TestCaseMixin
|
||||||
|
|
||||||
|
|
||||||
@ -24,7 +24,14 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
|
|||||||
add_input_samples=add_input_samples,
|
add_input_samples=add_input_samples,
|
||||||
)
|
)
|
||||||
lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
|
lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
|
||||||
bundle = RayBundle(lengths=lengths, origins=None, directions=None, xys=None)
|
bundle = ImplicitronRayBundle(
|
||||||
|
lengths=lengths,
|
||||||
|
origins=None,
|
||||||
|
directions=None,
|
||||||
|
xys=None,
|
||||||
|
camera_ids=None,
|
||||||
|
camera_counts=None,
|
||||||
|
)
|
||||||
weights = torch.ones(3, 25, length)
|
weights = torch.ones(3, 25, length)
|
||||||
refined = ray_point_refiner(bundle, weights)
|
refined = ray_point_refiner(bundle, weights)
|
||||||
|
|
||||||
|
@ -13,8 +13,10 @@ from pytorch3d.implicitron.models.implicit_function.scene_representation_network
|
|||||||
SRNImplicitFunction,
|
SRNImplicitFunction,
|
||||||
SRNPixelGenerator,
|
SRNPixelGenerator,
|
||||||
)
|
)
|
||||||
|
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
|
||||||
from pytorch3d.implicitron.tools.config import get_default_args
|
from pytorch3d.implicitron.tools.config import get_default_args
|
||||||
from pytorch3d.renderer import PerspectiveCameras, RayBundle
|
from pytorch3d.renderer import PerspectiveCameras
|
||||||
|
|
||||||
from tests.common_testing import TestCaseMixin
|
from tests.common_testing import TestCaseMixin
|
||||||
|
|
||||||
_BATCH_SIZE: int = 3
|
_BATCH_SIZE: int = 3
|
||||||
@ -31,12 +33,17 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
|
|||||||
def test_pixel_generator(self):
|
def test_pixel_generator(self):
|
||||||
SRNPixelGenerator()
|
SRNPixelGenerator()
|
||||||
|
|
||||||
def _get_bundle(self, *, device) -> RayBundle:
|
def _get_bundle(self, *, device) -> ImplicitronRayBundle:
|
||||||
origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
|
origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
|
||||||
directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
|
directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
|
||||||
lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
|
lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
|
||||||
bundle = RayBundle(
|
bundle = ImplicitronRayBundle(
|
||||||
lengths=lengths, origins=origins, directions=directions, xys=None
|
lengths=lengths,
|
||||||
|
origins=origins,
|
||||||
|
directions=directions,
|
||||||
|
xys=None,
|
||||||
|
camera_ids=None,
|
||||||
|
camera_counts=None,
|
||||||
)
|
)
|
||||||
return bundle
|
return bundle
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user