mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Implicit/Volume renderer
Summary: Implements the `ImplicitRenderer` and `VolumeRenderer`. Reviewed By: gkioxari Differential Revision: D24418791 fbshipit-source-id: 127f21186d8e210895db1dcd0681f09f230d81a4
This commit is contained in:
parent
e6a32bfc37
commit
b466c381da
@ -24,9 +24,12 @@ from .implicit import (
|
|||||||
AbsorptionOnlyRaymarcher,
|
AbsorptionOnlyRaymarcher,
|
||||||
EmissionAbsorptionRaymarcher,
|
EmissionAbsorptionRaymarcher,
|
||||||
GridRaysampler,
|
GridRaysampler,
|
||||||
|
ImplicitRenderer,
|
||||||
MonteCarloRaysampler,
|
MonteCarloRaysampler,
|
||||||
NDCGridRaysampler,
|
NDCGridRaysampler,
|
||||||
RayBundle,
|
RayBundle,
|
||||||
|
VolumeRenderer,
|
||||||
|
VolumeSampler,
|
||||||
ray_bundle_to_ray_points,
|
ray_bundle_to_ray_points,
|
||||||
ray_bundle_variables_to_ray_points,
|
ray_bundle_variables_to_ray_points,
|
||||||
)
|
)
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
|
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
|
||||||
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
|
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
|
||||||
|
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
|
||||||
from .utils import (
|
from .utils import (
|
||||||
RayBundle,
|
RayBundle,
|
||||||
ray_bundle_to_ray_points,
|
ray_bundle_to_ray_points,
|
||||||
|
372
pytorch3d/renderer/implicit/renderer.py
Normal file
372
pytorch3d/renderer/implicit/renderer.py
Normal file
@ -0,0 +1,372 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
from typing import Callable, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...ops.utils import eyes
|
||||||
|
from ...structures import Volumes
|
||||||
|
from ...transforms import Transform3d
|
||||||
|
from ..cameras import CamerasBase
|
||||||
|
from .raysampling import RayBundle
|
||||||
|
from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points
|
||||||
|
|
||||||
|
|
||||||
|
# The implicit renderer class should be initialized with a
|
||||||
|
# function for raysampling and a function for raymarching.
|
||||||
|
|
||||||
|
# During the forward pass:
|
||||||
|
# 1) The raysampler:
|
||||||
|
# - samples rays from input cameras
|
||||||
|
# - transforms the rays to world coordinates
|
||||||
|
# 2) The volumetric_function (which is a callable argument of the forwad pass)
|
||||||
|
# evaluates ray_densities and ray_features at the sampled ray-points.
|
||||||
|
# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching
|
||||||
|
# algorithm to render each ray.
|
||||||
|
|
||||||
|
|
||||||
|
class ImplicitRenderer(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
A class for rendering a batch of implicit surfaces. The class should
|
||||||
|
be initialized with a raysampler and raymarcher class which both have
|
||||||
|
to be a `Callable`.
|
||||||
|
|
||||||
|
VOLUMETRIC_FUNCTION
|
||||||
|
|
||||||
|
The `forward` function of the renderer accepts as input the rendering cameras as well
|
||||||
|
as the `volumetric_function` `Callable`, which defines a field of opacity
|
||||||
|
and feature vectors over the 3D domain of the scene.
|
||||||
|
|
||||||
|
A standard `volumetric_function` has the following signature:
|
||||||
|
```
|
||||||
|
def volumetric_function(ray_bundle: RayBundle) -> Tuple[torch.Tensor, torch.Tensor]
|
||||||
|
```
|
||||||
|
With the following arguments:
|
||||||
|
`ray_bundle`: A RayBundle object containing the following variables:
|
||||||
|
`rays_origins`: A tensor of shape `(minibatch, ..., 3)` denoting
|
||||||
|
the origins of the rendering rays.
|
||||||
|
`rays_directions`: A tensor of shape `(minibatch, ..., 3)`
|
||||||
|
containing the direction vectors of rendering rays.
|
||||||
|
`rays_lengths`: A tensor of shape
|
||||||
|
`(minibatch, ..., num_points_per_ray)`containing the
|
||||||
|
lengths at which the ray points are sampled.
|
||||||
|
Calling `volumetric_function` then returns the following:
|
||||||
|
`rays_densities`: A tensor of shape
|
||||||
|
`(minibatch, ..., num_points_per_ray, opacity_dim)` containing
|
||||||
|
the an opacity vector for each ray point.
|
||||||
|
`rays_features`: A tensor of shape
|
||||||
|
`(minibatch, ..., num_points_per_ray, feature_dim)` containing
|
||||||
|
the an feature vector for each ray point.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
A simple volumetric function of a 0-centered
|
||||||
|
RGB sphere with a unit diameter is defined as follows:
|
||||||
|
```
|
||||||
|
def volumetric_function(
|
||||||
|
ray_bundle: RayBundle,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
# first convert the ray origins, directions and lengths
|
||||||
|
# to 3D ray point locations in world coords
|
||||||
|
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
|
||||||
|
|
||||||
|
# set the densities as an inverse sigmoid of the
|
||||||
|
# ray point distance from the sphere centroid
|
||||||
|
rays_densities = torch.sigmoid(
|
||||||
|
-100.0 * rays_points_world.norm(dim=-1, keepdim=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# set the ray features to RGB colors proportional
|
||||||
|
# to the 3D location of the projection of ray points
|
||||||
|
# on the sphere surface
|
||||||
|
rays_features = torch.nn.functional.normalize(
|
||||||
|
rays_points_world, dim=-1
|
||||||
|
) * 0.5 + 0.5
|
||||||
|
|
||||||
|
return rays_densities, rays_features
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, raysampler: Callable, raymarcher: Callable):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
raysampler: A `Callable` that takes as input scene cameras
|
||||||
|
(an instance of `CamerasBase`) and returns a `RayBundle` that
|
||||||
|
describes the rays emitted from the cameras.
|
||||||
|
raymarcher: A `Callable` that receives the response of the
|
||||||
|
`volumetric_function` (an input to `self.forward`) evaluated
|
||||||
|
along the sampled rays, and renders the rays with a
|
||||||
|
ray-marching algorithm.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if not callable(raysampler):
|
||||||
|
raise ValueError('"raysampler" has to be a "Callable" object.')
|
||||||
|
if not callable(raymarcher):
|
||||||
|
raise ValueError('"raymarcher" has to be a "Callable" object.')
|
||||||
|
|
||||||
|
self.raysampler = raysampler
|
||||||
|
self.raymarcher = raymarcher
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, cameras: CamerasBase, volumetric_function: Callable, **kwargs
|
||||||
|
) -> Tuple[torch.Tensor, RayBundle]:
|
||||||
|
"""
|
||||||
|
Render a batch of images using a volumetric function
|
||||||
|
represented as a callable (e.g. a Pytorch module).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cameras: A batch of cameras that render the scene. A `self.raysampler`
|
||||||
|
takes the cameras as input and samples rays that pass through the
|
||||||
|
domain of the volumentric function.
|
||||||
|
volumetric_function: A `Callable` that accepts the parametrizations
|
||||||
|
of the rendering rays and returns the densities and features
|
||||||
|
at the respective 3D of the rendering rays. Please refer to
|
||||||
|
the main class documentation for details.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
|
||||||
|
containing the result of the rendering.
|
||||||
|
ray_bundle: A `RayBundle` containing the parametrizations of the
|
||||||
|
sampled rendering rays.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not callable(volumetric_function):
|
||||||
|
raise ValueError('"volumetric_function" has to be a "Callable" object.')
|
||||||
|
|
||||||
|
# first call the ray sampler that returns the RayBundle parametrizing
|
||||||
|
# the rendering rays.
|
||||||
|
ray_bundle = self.raysampler(
|
||||||
|
cameras=cameras, volumetric_function=volumetric_function, **kwargs
|
||||||
|
)
|
||||||
|
# ray_bundle.origins - minibatch x ... x 3
|
||||||
|
# ray_bundle.directions - minibatch x ... x 3
|
||||||
|
# ray_bundle.lengths - minibatch x ... x n_pts_per_ray
|
||||||
|
# ray_bundle.xys - minibatch x ... x 2
|
||||||
|
|
||||||
|
# given sampled rays, call the volumetric function that
|
||||||
|
# evaluates the densities and features at the locations of the
|
||||||
|
# ray points
|
||||||
|
rays_densities, rays_features = volumetric_function(
|
||||||
|
ray_bundle=ray_bundle, cameras=cameras, **kwargs
|
||||||
|
)
|
||||||
|
# ray_densities - minibatch x ... x n_pts_per_ray x density_dim
|
||||||
|
# ray_features - minibatch x ... x n_pts_per_ray x feature_dim
|
||||||
|
|
||||||
|
# finally, march along the sampled rays to obtain the renders
|
||||||
|
images = self.raymarcher(
|
||||||
|
rays_densities=rays_densities,
|
||||||
|
rays_features=rays_features,
|
||||||
|
ray_bundle=ray_bundle,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
# images - minibatch x ... x (feature_dim + opacity_dim)
|
||||||
|
|
||||||
|
return images, ray_bundle
|
||||||
|
|
||||||
|
|
||||||
|
# The volume renderer class should be initialized with a
|
||||||
|
# function for raysampling and a function for raymarching.
|
||||||
|
|
||||||
|
# During the forward pass:
|
||||||
|
# 1) The raysampler:
|
||||||
|
# - samples rays from input cameras
|
||||||
|
# - transforms the rays to world coordinates
|
||||||
|
# 2) The scene volumes (which are an argument of the forward function)
|
||||||
|
# are then sampled at the locations of the ray-points to generate
|
||||||
|
# ray_densities and ray_features.
|
||||||
|
# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching
|
||||||
|
# algorithm to render each ray.
|
||||||
|
|
||||||
|
|
||||||
|
class VolumeRenderer(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
A class for rendering a batch of Volumes. The class should
|
||||||
|
be initialized with a raysampler and a raymarcher class which both have
|
||||||
|
to be a `Callable`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, raysampler: Callable, raymarcher: Callable, sample_mode: str = "bilinear"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
raysampler: A `Callable` that takes as input scene cameras
|
||||||
|
(an instance of `CamerasBase`) and returns a `RayBundle` that
|
||||||
|
describes the rays emitted from the cameras.
|
||||||
|
raymarcher: A `Callable` that receives the `volumes`
|
||||||
|
(an instance of `Volumes` input to `self.forward`)
|
||||||
|
sampled at the ray-points, and renders the rays with a
|
||||||
|
ray-marching algorithm.
|
||||||
|
sample_mode: Defines the algorithm used to sample the volumetric
|
||||||
|
voxel grid. Can be either "bilinear" or "nearest".
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.renderer = ImplicitRenderer(raysampler, raymarcher)
|
||||||
|
self._sample_mode = sample_mode
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, cameras: CamerasBase, volumes: Volumes, **kwargs
|
||||||
|
) -> Tuple[torch.Tensor, RayBundle]:
|
||||||
|
"""
|
||||||
|
Render a batch of images using raymarching over rays cast through
|
||||||
|
input `Volumes`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cameras: A batch of cameras that render the scene. A `self.raysampler`
|
||||||
|
takes the cameras as input and samples rays that pass through the
|
||||||
|
domain of the volumentric function.
|
||||||
|
volumes: An instance of the `Volumes` class representing a
|
||||||
|
batch of volumes that are being rendered.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
images: A tensor of shape `(minibatch, ..., (feature_dim + opacity_dim)`
|
||||||
|
containing the result of the rendering.
|
||||||
|
ray_bundle: A `RayBundle` containing the parametrizations of the
|
||||||
|
sampled rendering rays.
|
||||||
|
"""
|
||||||
|
volumetric_function = VolumeSampler(volumes, sample_mode=self._sample_mode)
|
||||||
|
return self.renderer(
|
||||||
|
cameras=cameras, volumetric_function=volumetric_function, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VolumeSampler(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
A class that allows to sample a batch of volumes `Volumes`
|
||||||
|
at 3D points sampled along projection rays.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, volumes: Volumes, sample_mode: str = "bilinear"):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
volumes: An instance of the `Volumes` class representing a
|
||||||
|
batch if volumes that are being rendered.
|
||||||
|
sample_mode: Defines the algorithm used to sample the volumetric
|
||||||
|
voxel grid. Can be either "bilinear" or "nearest".
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if not isinstance(volumes, Volumes):
|
||||||
|
raise ValueError("'volumes' have to be an instance of the 'Volumes' class.")
|
||||||
|
self._volumes = volumes
|
||||||
|
self._sample_mode = sample_mode
|
||||||
|
|
||||||
|
def _get_ray_directions_transform(self):
|
||||||
|
"""
|
||||||
|
Compose the ray-directions transform by removing the translation component
|
||||||
|
from the volume global-to-local coords transform.
|
||||||
|
"""
|
||||||
|
world2local = self._volumes.get_world_to_local_coords_transform().get_matrix()
|
||||||
|
directions_transform_matrix = eyes(
|
||||||
|
4,
|
||||||
|
N=world2local.shape[0],
|
||||||
|
device=world2local.device,
|
||||||
|
dtype=world2local.dtype,
|
||||||
|
)
|
||||||
|
directions_transform_matrix[:, :3, :3] = world2local[:, :3, :3]
|
||||||
|
directions_transform = Transform3d(matrix=directions_transform_matrix)
|
||||||
|
return directions_transform
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, ray_bundle: RayBundle, **kwargs
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Given an input ray parametrization, the forward function samples
|
||||||
|
`self._volumes` at the respective 3D ray-points.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ray_bundle: A RayBundle object with the following fields:
|
||||||
|
rays_origins_world: A tensor of shape `(minibatch, ..., 3)` denoting the
|
||||||
|
origins of the sampling rays in world coords.
|
||||||
|
rays_directions_world: A tensor of shape `(minibatch, ..., 3)`
|
||||||
|
containing the direction vectors of sampling rays in world coords.
|
||||||
|
rays_lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
|
||||||
|
containing the lengths at which the rays are sampled.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
rays_densities: A tensor of shape
|
||||||
|
`(minibatch, ..., num_points_per_ray, opacity_dim)` containing the
|
||||||
|
densitity vectors sampled from the volume at the locations of
|
||||||
|
the ray points.
|
||||||
|
rays_features: A tensor of shape
|
||||||
|
`(minibatch, ..., num_points_per_ray, feature_dim)` containing the
|
||||||
|
feature vectors sampled from the volume at the locations of
|
||||||
|
the ray points.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# take out the interesting parts of ray_bundle
|
||||||
|
rays_origins_world = ray_bundle.origins
|
||||||
|
rays_directions_world = ray_bundle.directions
|
||||||
|
rays_lengths = ray_bundle.lengths
|
||||||
|
|
||||||
|
# validate the inputs
|
||||||
|
_validate_ray_bundle_variables(
|
||||||
|
rays_origins_world, rays_directions_world, rays_lengths
|
||||||
|
)
|
||||||
|
if self._volumes.densities().shape[0] != rays_origins_world.shape[0]:
|
||||||
|
raise ValueError("Input volumes have to have the same batch size as rays.")
|
||||||
|
|
||||||
|
#########################################################
|
||||||
|
# 1) convert the origins/directions to the local coords #
|
||||||
|
#########################################################
|
||||||
|
|
||||||
|
# origins are mapped with the world_to_local transform of the volumes
|
||||||
|
rays_origins_local = self._volumes.world_to_local_coords(rays_origins_world)
|
||||||
|
|
||||||
|
# obtain the Transform3d object that transforms ray directions to local coords
|
||||||
|
directions_transform = self._get_ray_directions_transform()
|
||||||
|
|
||||||
|
# transform the directions to the local coords
|
||||||
|
rays_directions_local = directions_transform.transform_points(
|
||||||
|
rays_directions_world.view(rays_lengths.shape[0], -1, 3)
|
||||||
|
).view(rays_directions_world.shape)
|
||||||
|
|
||||||
|
############################
|
||||||
|
# 2) obtain the ray points #
|
||||||
|
############################
|
||||||
|
|
||||||
|
# this op produces a fairly big tensor (minibatch, ..., n_samples_per_ray, 3)
|
||||||
|
rays_points_local = ray_bundle_variables_to_ray_points(
|
||||||
|
rays_origins_local, rays_directions_local, rays_lengths
|
||||||
|
)
|
||||||
|
|
||||||
|
########################
|
||||||
|
# 3) sample the volume #
|
||||||
|
########################
|
||||||
|
|
||||||
|
# generate the tensor for sampling
|
||||||
|
volumes_densities = self._volumes.densities()
|
||||||
|
dim_density = volumes_densities.shape[1]
|
||||||
|
volumes_features = self._volumes.features()
|
||||||
|
# adjust the volumes_features variable in case we have a feature-less volume
|
||||||
|
if volumes_features is None:
|
||||||
|
dim_feature = 0
|
||||||
|
data_to_sample = volumes_densities
|
||||||
|
else:
|
||||||
|
dim_feature = volumes_features.shape[1]
|
||||||
|
data_to_sample = torch.cat((volumes_densities, volumes_features), dim=1)
|
||||||
|
|
||||||
|
# reshape to a size which grid_sample likes
|
||||||
|
rays_points_local_flat = rays_points_local.view(
|
||||||
|
rays_points_local.shape[0], -1, 1, 1, 3
|
||||||
|
)
|
||||||
|
|
||||||
|
# run the grid sampler
|
||||||
|
data_sampled = torch.nn.functional.grid_sample(
|
||||||
|
data_to_sample,
|
||||||
|
rays_points_local_flat,
|
||||||
|
align_corners=True,
|
||||||
|
mode=self._sample_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
# permute the dimensions & reshape after sampling
|
||||||
|
data_sampled = data_sampled.permute(0, 2, 3, 4, 1).view(
|
||||||
|
*rays_points_local.shape[:-1], data_sampled.shape[1]
|
||||||
|
)
|
||||||
|
|
||||||
|
# split back to densities and features
|
||||||
|
rays_densities, rays_features = data_sampled.split(
|
||||||
|
[dim_density, dim_feature], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
return rays_densities, rays_features
|
@ -53,12 +53,12 @@ def ray_bundle_variables_to_ray_points(
|
|||||||
rays_lengths: torch.Tensor,
|
rays_lengths: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Converts rays parametrized with origins, directions
|
Converts rays parametrized with origins and directions
|
||||||
to 3D points by extending each ray according to the corresponding
|
to 3D points by extending each ray according to the corresponding
|
||||||
ray_length:
|
ray length:
|
||||||
|
|
||||||
E.g. for 2 dimensional input tensors `rays_origins`, `rays_directions`
|
E.g. for 2 dimensional input tensors `rays_origins`, `rays_directions`
|
||||||
and `rays_lengths`, the ray point at position `[i, j]` is:
|
and `rays_lengths`, the ray point at position `[i, j]` is:
|
||||||
```
|
```
|
||||||
rays_points[i, j, :] = (
|
rays_points[i, j, :] = (
|
||||||
rays_origins[i, :]
|
rays_origins[i, :]
|
||||||
@ -80,3 +80,39 @@ def ray_bundle_variables_to_ray_points(
|
|||||||
+ rays_lengths[..., :, None] * rays_directions[..., None, :]
|
+ rays_lengths[..., :, None] * rays_directions[..., None, :]
|
||||||
)
|
)
|
||||||
return rays_points
|
return rays_points
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_ray_bundle_variables(
|
||||||
|
rays_origins: torch.Tensor,
|
||||||
|
rays_directions: torch.Tensor,
|
||||||
|
rays_lengths: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Validate the shapes of RayBundle variables
|
||||||
|
`rays_origins`, `rays_directions`, and `rays_lengths`.
|
||||||
|
"""
|
||||||
|
ndim = rays_origins.ndim
|
||||||
|
if any(r.ndim != ndim for r in (rays_directions, rays_lengths)):
|
||||||
|
raise ValueError(
|
||||||
|
"rays_origins, rays_directions and rays_lengths"
|
||||||
|
+ " have to have the same number of dimensions."
|
||||||
|
)
|
||||||
|
|
||||||
|
if ndim <= 2:
|
||||||
|
raise ValueError(
|
||||||
|
"rays_origins, rays_directions and rays_lengths"
|
||||||
|
+ " have to have at least 3 dimensions."
|
||||||
|
)
|
||||||
|
|
||||||
|
spatial_size = rays_origins.shape[:-1]
|
||||||
|
if any(spatial_size != r.shape[:-1] for r in (rays_directions, rays_lengths)):
|
||||||
|
raise ValueError(
|
||||||
|
"The shapes of rays_origins, rays_directions and rays_lengths"
|
||||||
|
+ " may differ only in the last dimension."
|
||||||
|
)
|
||||||
|
|
||||||
|
if any(r.shape[-1] != 3 for r in (rays_origins, rays_directions)):
|
||||||
|
raise ValueError(
|
||||||
|
"The size of the last dimension of rays_origins/rays_directions"
|
||||||
|
+ "has to be 3."
|
||||||
|
)
|
||||||
|
22
tests/bm_render_implicit.py
Normal file
22
tests/bm_render_implicit.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
from fvcore.common.benchmark import benchmark
|
||||||
|
from pytorch3d.renderer import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
|
||||||
|
from test_render_implicit import TestRenderImplicit
|
||||||
|
|
||||||
|
|
||||||
|
def bm_render_volumes() -> None:
|
||||||
|
case_grid = {
|
||||||
|
"batch_size": [1, 5],
|
||||||
|
"raymarcher_type": [EmissionAbsorptionRaymarcher, AbsorptionOnlyRaymarcher],
|
||||||
|
"n_rays_per_image": [64 ** 2, 256 ** 2],
|
||||||
|
"n_pts_per_ray": [16, 128],
|
||||||
|
}
|
||||||
|
test_cases = itertools.product(*case_grid.values())
|
||||||
|
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
|
||||||
|
|
||||||
|
benchmark(
|
||||||
|
TestRenderImplicit.renderer, "IMPLICIT_RENDERER", kwargs_list, warmup_iters=1
|
||||||
|
)
|
24
tests/bm_render_volumes.py
Normal file
24
tests/bm_render_volumes.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
from fvcore.common.benchmark import benchmark
|
||||||
|
from pytorch3d.renderer import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
|
||||||
|
from test_render_volumes import TestRenderVolumes
|
||||||
|
|
||||||
|
|
||||||
|
def bm_render_volumes() -> None:
|
||||||
|
case_grid = {
|
||||||
|
"volume_size": [tuple([17] * 3), tuple([129] * 3)],
|
||||||
|
"batch_size": [1, 5],
|
||||||
|
"shape": ["sphere", "cube"],
|
||||||
|
"raymarcher_type": [EmissionAbsorptionRaymarcher, AbsorptionOnlyRaymarcher],
|
||||||
|
"n_rays_per_image": [64 ** 2, 256 ** 2],
|
||||||
|
"n_pts_per_ray": [16, 128],
|
||||||
|
}
|
||||||
|
test_cases = itertools.product(*case_grid.values())
|
||||||
|
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
|
||||||
|
|
||||||
|
benchmark(
|
||||||
|
TestRenderVolumes.renderer, "VOLUME_RENDERER", kwargs_list, warmup_iters=1
|
||||||
|
)
|
403
tests/test_render_implicit.py
Normal file
403
tests/test_render_implicit.py
Normal file
@ -0,0 +1,403 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from common_testing import TestCaseMixin
|
||||||
|
from pytorch3d.renderer import (
|
||||||
|
BlendParams,
|
||||||
|
EmissionAbsorptionRaymarcher,
|
||||||
|
GridRaysampler,
|
||||||
|
ImplicitRenderer,
|
||||||
|
Materials,
|
||||||
|
MeshRasterizer,
|
||||||
|
MeshRenderer,
|
||||||
|
MonteCarloRaysampler,
|
||||||
|
NDCGridRaysampler,
|
||||||
|
PointLights,
|
||||||
|
RasterizationSettings,
|
||||||
|
RayBundle,
|
||||||
|
SoftPhongShader,
|
||||||
|
TexturesVertex,
|
||||||
|
ray_bundle_to_ray_points,
|
||||||
|
)
|
||||||
|
from pytorch3d.structures import Meshes
|
||||||
|
from pytorch3d.utils import ico_sphere
|
||||||
|
from test_render_volumes import init_cameras
|
||||||
|
|
||||||
|
|
||||||
|
DEBUG = False
|
||||||
|
if DEBUG:
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def spherical_volumetric_function(
|
||||||
|
ray_bundle: RayBundle,
|
||||||
|
sphere_centroid: torch.Tensor,
|
||||||
|
sphere_diameter: float,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Volumetric function of a simple RGB sphere with diameter `sphere_diameter`
|
||||||
|
and centroid `sphere_centroid`.
|
||||||
|
"""
|
||||||
|
# convert the ray bundle to world points
|
||||||
|
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
|
||||||
|
batch_size = rays_points_world.shape[0]
|
||||||
|
|
||||||
|
# surface_vectors = vectors from world coords towards the sphere centroid
|
||||||
|
surface_vectors = (
|
||||||
|
rays_points_world.view(batch_size, -1, 3) - sphere_centroid[:, None]
|
||||||
|
)
|
||||||
|
|
||||||
|
# the squared distance of each ray point to the centroid of the sphere
|
||||||
|
surface_dist = (
|
||||||
|
(surface_vectors ** 2)
|
||||||
|
.sum(-1, keepdim=True)
|
||||||
|
.view(*rays_points_world.shape[:-1], 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# set all ray densities within the sphere_diameter distance from the centroid to 1
|
||||||
|
rays_densities = torch.sigmoid(-100.0 * (surface_dist - sphere_diameter ** 2))
|
||||||
|
|
||||||
|
# ray colors are proportional to the normalized surface_vectors
|
||||||
|
rays_features = (
|
||||||
|
torch.nn.functional.normalize(
|
||||||
|
surface_vectors.view(rays_points_world.shape), dim=-1
|
||||||
|
)
|
||||||
|
* 0.5
|
||||||
|
+ 0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
return rays_densities, rays_features
|
||||||
|
|
||||||
|
|
||||||
|
class TestRenderImplicit(TestCaseMixin, unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
super().setUp()
|
||||||
|
torch.manual_seed(42)
|
||||||
|
np.random.seed(42)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def renderer(
|
||||||
|
batch_size=10,
|
||||||
|
raymarcher_type=EmissionAbsorptionRaymarcher,
|
||||||
|
n_rays_per_image=10,
|
||||||
|
n_pts_per_ray=10,
|
||||||
|
sphere_diameter=0.75,
|
||||||
|
):
|
||||||
|
# generate NDC camera extrinsics and intrinsics
|
||||||
|
cameras = init_cameras(batch_size, image_size=None, ndc=True)
|
||||||
|
|
||||||
|
# get rand offset of the volume
|
||||||
|
sphere_centroid = torch.randn(batch_size, 3, device=cameras.device) * 0.1
|
||||||
|
|
||||||
|
# init the mc raysampler
|
||||||
|
raysampler = MonteCarloRaysampler(
|
||||||
|
min_x=-1.0,
|
||||||
|
max_x=1.0,
|
||||||
|
min_y=-1.0,
|
||||||
|
max_y=1.0,
|
||||||
|
n_rays_per_image=n_rays_per_image,
|
||||||
|
n_pts_per_ray=n_pts_per_ray,
|
||||||
|
min_depth=0.1,
|
||||||
|
max_depth=2.0,
|
||||||
|
).to(cameras.device)
|
||||||
|
|
||||||
|
# get the raymarcher
|
||||||
|
raymarcher = raymarcher_type()
|
||||||
|
|
||||||
|
# get the implicit renderer
|
||||||
|
renderer = ImplicitRenderer(raysampler=raysampler, raymarcher=raymarcher)
|
||||||
|
|
||||||
|
def run_renderer():
|
||||||
|
renderer(
|
||||||
|
cameras=cameras,
|
||||||
|
volumetric_function=spherical_volumetric_function,
|
||||||
|
sphere_centroid=sphere_centroid,
|
||||||
|
sphere_diameter=sphere_diameter,
|
||||||
|
)
|
||||||
|
|
||||||
|
return run_renderer
|
||||||
|
|
||||||
|
def test_input_types(self):
|
||||||
|
"""
|
||||||
|
Check that ValueErrors are thrown where expected.
|
||||||
|
"""
|
||||||
|
# check the constructor
|
||||||
|
for bad_raysampler in (None, 5, []):
|
||||||
|
for bad_raymarcher in (None, 5, []):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
ImplicitRenderer(
|
||||||
|
raysampler=bad_raysampler, raymarcher=bad_raymarcher
|
||||||
|
)
|
||||||
|
|
||||||
|
# init a trivial renderer
|
||||||
|
renderer = ImplicitRenderer(
|
||||||
|
raysampler=NDCGridRaysampler(
|
||||||
|
image_width=100,
|
||||||
|
image_height=100,
|
||||||
|
n_pts_per_ray=10,
|
||||||
|
min_depth=0.1,
|
||||||
|
max_depth=1.0,
|
||||||
|
),
|
||||||
|
raymarcher=EmissionAbsorptionRaymarcher(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# get default cameras
|
||||||
|
cameras = init_cameras()
|
||||||
|
|
||||||
|
for bad_volumetric_function in (None, 5, []):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
renderer(cameras=cameras, volumetric_function=bad_volumetric_function)
|
||||||
|
|
||||||
|
def test_compare_with_meshes_renderer(
|
||||||
|
self, batch_size=11, image_size=100, sphere_diameter=0.6
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate a spherical RGB volumetric function and its corresponding mesh
|
||||||
|
and check whether MeshesRenderer returns the same images as the
|
||||||
|
corresponding ImplicitRenderer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# generate NDC camera extrinsics and intrinsics
|
||||||
|
cameras = init_cameras(
|
||||||
|
batch_size, image_size=[image_size, image_size], ndc=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# get rand offset of the volume
|
||||||
|
sphere_centroid = torch.randn(batch_size, 3, device=cameras.device) * 0.1
|
||||||
|
sphere_centroid.requires_grad = True
|
||||||
|
|
||||||
|
# init the grid raysampler with the ndc grid
|
||||||
|
raysampler = NDCGridRaysampler(
|
||||||
|
image_width=image_size,
|
||||||
|
image_height=image_size,
|
||||||
|
n_pts_per_ray=256,
|
||||||
|
min_depth=0.1,
|
||||||
|
max_depth=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the EA raymarcher
|
||||||
|
raymarcher = EmissionAbsorptionRaymarcher()
|
||||||
|
|
||||||
|
# jitter the camera intrinsics a bit for each render
|
||||||
|
cameras_randomized = cameras.clone()
|
||||||
|
cameras_randomized.principal_point = (
|
||||||
|
torch.randn_like(cameras.principal_point) * 0.3
|
||||||
|
)
|
||||||
|
cameras_randomized.focal_length = (
|
||||||
|
cameras.focal_length + torch.randn_like(cameras.focal_length) * 0.2
|
||||||
|
)
|
||||||
|
|
||||||
|
# the list of differentiable camera vars
|
||||||
|
cam_vars = ("R", "T", "focal_length", "principal_point")
|
||||||
|
# enable the gradient caching for the camera variables
|
||||||
|
for cam_var in cam_vars:
|
||||||
|
getattr(cameras_randomized, cam_var).requires_grad = True
|
||||||
|
|
||||||
|
# get the implicit renderer
|
||||||
|
images_opacities = ImplicitRenderer(
|
||||||
|
raysampler=raysampler, raymarcher=raymarcher
|
||||||
|
)(
|
||||||
|
cameras=cameras_randomized,
|
||||||
|
volumetric_function=spherical_volumetric_function,
|
||||||
|
sphere_centroid=sphere_centroid,
|
||||||
|
sphere_diameter=sphere_diameter,
|
||||||
|
)[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
|
||||||
|
# check that the renderer does not erase gradients
|
||||||
|
loss = images_opacities.sum()
|
||||||
|
loss.backward()
|
||||||
|
for check_var in (
|
||||||
|
*[getattr(cameras_randomized, cam_var) for cam_var in cam_vars],
|
||||||
|
sphere_centroid,
|
||||||
|
):
|
||||||
|
self.assertIsNotNone(check_var.grad)
|
||||||
|
|
||||||
|
# instantiate the corresponding spherical mesh
|
||||||
|
ico = ico_sphere(level=4, device=cameras.device).extend(batch_size)
|
||||||
|
verts = (
|
||||||
|
torch.nn.functional.normalize(ico.verts_padded(), dim=-1) * sphere_diameter
|
||||||
|
+ sphere_centroid[:, None]
|
||||||
|
)
|
||||||
|
meshes = Meshes(
|
||||||
|
verts=verts,
|
||||||
|
faces=ico.faces_padded(),
|
||||||
|
textures=TexturesVertex(
|
||||||
|
verts_features=(
|
||||||
|
torch.nn.functional.normalize(verts, dim=-1) * 0.5 + 0.5
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# instantiate the corresponding mesh renderer
|
||||||
|
lights = PointLights(device=cameras.device, location=[[0.0, 0.0, 0.0]])
|
||||||
|
renderer_textured = MeshRenderer(
|
||||||
|
rasterizer=MeshRasterizer(
|
||||||
|
cameras=cameras_randomized,
|
||||||
|
raster_settings=RasterizationSettings(
|
||||||
|
image_size=image_size, blur_radius=1e-3, faces_per_pixel=10
|
||||||
|
),
|
||||||
|
),
|
||||||
|
shader=SoftPhongShader(
|
||||||
|
device=cameras.device,
|
||||||
|
cameras=cameras_randomized,
|
||||||
|
lights=lights,
|
||||||
|
materials=Materials(
|
||||||
|
ambient_color=((2.0, 2.0, 2.0),),
|
||||||
|
diffuse_color=((0.0, 0.0, 0.0),),
|
||||||
|
specular_color=((0.0, 0.0, 0.0),),
|
||||||
|
shininess=64,
|
||||||
|
device=cameras.device,
|
||||||
|
),
|
||||||
|
blend_params=BlendParams(
|
||||||
|
sigma=1e-3, gamma=1e-4, background_color=(0.0, 0.0, 0.0)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the mesh render
|
||||||
|
images_opacities_meshes = renderer_textured(
|
||||||
|
meshes, cameras=cameras_randomized, lights=lights
|
||||||
|
)
|
||||||
|
|
||||||
|
if DEBUG:
|
||||||
|
outdir = tempfile.gettempdir() + "/test_implicit_vs_mesh_renderer"
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
|
||||||
|
frames = []
|
||||||
|
for (image_opacity, image_opacity_mesh) in zip(
|
||||||
|
images_opacities, images_opacities_meshes
|
||||||
|
):
|
||||||
|
image, opacity = image_opacity.split([3, 1], dim=-1)
|
||||||
|
image_mesh, opacity_mesh = image_opacity_mesh.split([3, 1], dim=-1)
|
||||||
|
diff_image = (
|
||||||
|
((image - image_mesh) * 0.5 + 0.5)
|
||||||
|
.mean(dim=2, keepdim=True)
|
||||||
|
.repeat(1, 1, 3)
|
||||||
|
)
|
||||||
|
image_pil = Image.fromarray(
|
||||||
|
(
|
||||||
|
torch.cat(
|
||||||
|
(
|
||||||
|
image,
|
||||||
|
image_mesh,
|
||||||
|
diff_image,
|
||||||
|
opacity.repeat(1, 1, 3),
|
||||||
|
opacity_mesh.repeat(1, 1, 3),
|
||||||
|
),
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
.detach()
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
* 255.0
|
||||||
|
).astype(np.uint8)
|
||||||
|
)
|
||||||
|
frames.append(image_pil)
|
||||||
|
|
||||||
|
# export gif
|
||||||
|
outfile = os.path.join(outdir, "implicit_vs_mesh_render.gif")
|
||||||
|
frames[0].save(
|
||||||
|
outfile,
|
||||||
|
save_all=True,
|
||||||
|
append_images=frames[1:],
|
||||||
|
duration=batch_size // 15,
|
||||||
|
loop=0,
|
||||||
|
)
|
||||||
|
print(f"exported {outfile}")
|
||||||
|
|
||||||
|
# export concatenated frames
|
||||||
|
outfile_cat = os.path.join(outdir, "implicit_vs_mesh_render.png")
|
||||||
|
Image.fromarray(np.concatenate([np.array(f) for f in frames], axis=0)).save(
|
||||||
|
outfile_cat
|
||||||
|
)
|
||||||
|
print(f"exported {outfile_cat}")
|
||||||
|
|
||||||
|
# compare the renders
|
||||||
|
diff = (images_opacities - images_opacities_meshes).abs().mean(dim=-1)
|
||||||
|
mu_diff = diff.mean(dim=(1, 2))
|
||||||
|
std_diff = diff.std(dim=(1, 2))
|
||||||
|
self.assertClose(mu_diff, torch.zeros_like(mu_diff), atol=5e-2)
|
||||||
|
self.assertClose(std_diff, torch.zeros_like(std_diff), atol=6e-2)
|
||||||
|
|
||||||
|
def test_rotating_gif(
|
||||||
|
self, n_frames=50, fps=15, image_size=(100, 100), sphere_diameter=0.5
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Render a gif animation of a rotating sphere (runs only if `DEBUG==True`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not DEBUG:
|
||||||
|
# do not run this if debug is False
|
||||||
|
return
|
||||||
|
|
||||||
|
# generate camera extrinsics and intrinsics
|
||||||
|
cameras = init_cameras(n_frames, image_size=image_size)
|
||||||
|
|
||||||
|
# init the grid raysampler
|
||||||
|
raysampler = GridRaysampler(
|
||||||
|
min_x=0.5,
|
||||||
|
max_x=image_size[1] - 0.5,
|
||||||
|
min_y=0.5,
|
||||||
|
max_y=image_size[0] - 0.5,
|
||||||
|
image_width=image_size[1],
|
||||||
|
image_height=image_size[0],
|
||||||
|
n_pts_per_ray=256,
|
||||||
|
min_depth=0.1,
|
||||||
|
max_depth=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the EA raymarcher
|
||||||
|
raymarcher = EmissionAbsorptionRaymarcher()
|
||||||
|
|
||||||
|
# get the implicit render
|
||||||
|
renderer = ImplicitRenderer(raysampler=raysampler, raymarcher=raymarcher)
|
||||||
|
|
||||||
|
# get the (0) centroid of the sphere
|
||||||
|
sphere_centroid = torch.zeros(n_frames, 3, device=cameras.device) * 0.1
|
||||||
|
|
||||||
|
# run the renderer
|
||||||
|
images_opacities = renderer(
|
||||||
|
cameras=cameras,
|
||||||
|
volumetric_function=spherical_volumetric_function,
|
||||||
|
sphere_centroid=sphere_centroid,
|
||||||
|
sphere_diameter=sphere_diameter,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# split output to the alpha channel and rendered images
|
||||||
|
images, opacities = images_opacities[..., :3], images_opacities[..., 3]
|
||||||
|
|
||||||
|
# export the gif
|
||||||
|
outdir = tempfile.gettempdir() + "/test_implicit_renderer_gifs"
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
frames = []
|
||||||
|
for image, opacity in zip(images, opacities):
|
||||||
|
image_pil = Image.fromarray(
|
||||||
|
(
|
||||||
|
torch.cat(
|
||||||
|
(image, opacity[..., None].clamp(0.0, 1.0).repeat(1, 1, 3)),
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
.detach()
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
* 255.0
|
||||||
|
).astype(np.uint8)
|
||||||
|
)
|
||||||
|
frames.append(image_pil)
|
||||||
|
outfile = os.path.join(outdir, "rotating_sphere.gif")
|
||||||
|
frames[0].save(
|
||||||
|
outfile,
|
||||||
|
save_all=True,
|
||||||
|
append_images=frames[1:],
|
||||||
|
duration=n_frames // fps,
|
||||||
|
loop=0,
|
||||||
|
)
|
||||||
|
print(f"exported {outfile}")
|
711
tests/test_render_volumes.py
Normal file
711
tests/test_render_volumes.py
Normal file
@ -0,0 +1,711 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
import unittest
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from common_testing import TestCaseMixin
|
||||||
|
from pytorch3d.ops import knn_points
|
||||||
|
from pytorch3d.renderer import (
|
||||||
|
AbsorptionOnlyRaymarcher,
|
||||||
|
AlphaCompositor,
|
||||||
|
EmissionAbsorptionRaymarcher,
|
||||||
|
GridRaysampler,
|
||||||
|
MonteCarloRaysampler,
|
||||||
|
NDCGridRaysampler,
|
||||||
|
PerspectiveCameras,
|
||||||
|
PointsRasterizationSettings,
|
||||||
|
PointsRasterizer,
|
||||||
|
PointsRenderer,
|
||||||
|
RayBundle,
|
||||||
|
VolumeRenderer,
|
||||||
|
VolumeSampler,
|
||||||
|
)
|
||||||
|
from pytorch3d.renderer.implicit.utils import _validate_ray_bundle_variables
|
||||||
|
from pytorch3d.structures import Pointclouds, Volumes
|
||||||
|
from test_points_to_volumes import init_uniform_y_rotations
|
||||||
|
|
||||||
|
|
||||||
|
DEBUG = False
|
||||||
|
if DEBUG:
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
ZERO_TRANSLATION = torch.zeros(1, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def init_boundary_volume(
|
||||||
|
batch_size: int,
|
||||||
|
volume_size: Tuple[int, int, int],
|
||||||
|
border_offset: int = 2,
|
||||||
|
shape: str = "cube",
|
||||||
|
volume_translation: torch.Tensor = ZERO_TRANSLATION,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate a volume with sides colored with distinct colors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
# first center the volume for the purpose of generating the canonical shape
|
||||||
|
volume_translation_tmp = (0.0, 0.0, 0.0)
|
||||||
|
|
||||||
|
# set the voxel size to 1 / (volume_size-1)
|
||||||
|
volume_voxel_size = 1 / (volume_size[0] - 1.0)
|
||||||
|
|
||||||
|
# colors of the sides of the cube
|
||||||
|
clr_sides = torch.tensor(
|
||||||
|
[
|
||||||
|
[1.0, 1.0, 1.0],
|
||||||
|
[1.0, 0.0, 0.0],
|
||||||
|
[1.0, 0.0, 1.0],
|
||||||
|
[1.0, 1.0, 0.0],
|
||||||
|
[0.0, 1.0, 0.0],
|
||||||
|
[0.0, 1.0, 1.0],
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the coord grid of the volume
|
||||||
|
coord_grid = Volumes(
|
||||||
|
densities=torch.zeros(1, 1, *volume_size, device=device),
|
||||||
|
voxel_size=volume_voxel_size,
|
||||||
|
volume_translation=volume_translation_tmp,
|
||||||
|
).get_coord_grid()[0]
|
||||||
|
|
||||||
|
# extract the boundary points and their colors of the cube
|
||||||
|
if shape == "cube":
|
||||||
|
boundary_points, boundary_colors = [], []
|
||||||
|
for side, clr_side in enumerate(clr_sides):
|
||||||
|
first = side % 2
|
||||||
|
dim = side // 2
|
||||||
|
slices = [slice(border_offset, -border_offset, 1)] * 3
|
||||||
|
slices[dim] = int(border_offset * (2 * first - 1))
|
||||||
|
slices.append(slice(0, 3, 1))
|
||||||
|
boundary_points_ = coord_grid[slices].reshape(-1, 3)
|
||||||
|
boundary_points.append(boundary_points_)
|
||||||
|
boundary_colors.append(clr_side[None].expand_as(boundary_points_))
|
||||||
|
# set the internal part of the volume to be completely opaque
|
||||||
|
volume_densities = torch.zeros(*volume_size, device=device)
|
||||||
|
volume_densities[[slice(border_offset, -border_offset, 1)] * 3] = 1.0
|
||||||
|
boundary_points, boundary_colors = [
|
||||||
|
torch.cat(p, dim=0) for p in [boundary_points, boundary_colors]
|
||||||
|
]
|
||||||
|
# color the volume voxels with the nearest boundary points' color
|
||||||
|
_, idx, _ = knn_points(
|
||||||
|
coord_grid.view(1, -1, 3), boundary_points.view(1, -1, 3)
|
||||||
|
)
|
||||||
|
volume_colors = (
|
||||||
|
boundary_colors[idx.view(-1)].view(*volume_size, 3).permute(3, 0, 1, 2)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif shape == "sphere":
|
||||||
|
# set all voxels within a certain distance from the origin to be opaque
|
||||||
|
volume_densities = (
|
||||||
|
coord_grid.norm(dim=-1)
|
||||||
|
<= 0.5 * volume_voxel_size * (volume_size[0] - border_offset)
|
||||||
|
).float()
|
||||||
|
# color each voxel with the standrd spherical color
|
||||||
|
volume_colors = (
|
||||||
|
(torch.nn.functional.normalize(coord_grid, dim=-1) + 1.0) * 0.5
|
||||||
|
).permute(3, 0, 1, 2)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(shape)
|
||||||
|
|
||||||
|
volume_voxel_size = torch.ones((batch_size, 1), device=device) * volume_voxel_size
|
||||||
|
volume_translation = volume_translation.expand(batch_size, 3)
|
||||||
|
volumes = Volumes(
|
||||||
|
densities=volume_densities[None, None].expand(batch_size, 1, *volume_size),
|
||||||
|
features=volume_colors[None].expand(batch_size, 3, *volume_size),
|
||||||
|
voxel_size=volume_voxel_size,
|
||||||
|
volume_translation=volume_translation,
|
||||||
|
)
|
||||||
|
|
||||||
|
return volumes, volume_voxel_size, volume_translation
|
||||||
|
|
||||||
|
|
||||||
|
def init_cameras(
|
||||||
|
batch_size: int = 10,
|
||||||
|
image_size: Optional[Tuple[int, int]] = (50, 50),
|
||||||
|
ndc: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize a batch of cameras whose extrinsics rotate the cameras around
|
||||||
|
the world's y axis.
|
||||||
|
Depending on whether we want an NDC-space (`ndc==True`) or a screen-space camera,
|
||||||
|
the camera's focal length and principal point are initialized accordingly:
|
||||||
|
For `ndc==False`, p0=focal_length=image_size/2.
|
||||||
|
For `ndc==True`, focal_length=1.0, p0 = 0.0.
|
||||||
|
The the z-coordinate of the translation vector of each camera is fixed to 1.5.
|
||||||
|
"""
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
|
||||||
|
# trivial rotations
|
||||||
|
R = init_uniform_y_rotations(batch_size).to(device)
|
||||||
|
|
||||||
|
# move camera 1.5 m away from the scene center
|
||||||
|
T = torch.zeros((batch_size, 3), device=device)
|
||||||
|
T[:, 2] = 1.5
|
||||||
|
|
||||||
|
if ndc:
|
||||||
|
p0 = torch.zeros(batch_size, 2, device=device)
|
||||||
|
focal = torch.ones(batch_size, device=device)
|
||||||
|
else:
|
||||||
|
p0 = torch.ones(batch_size, 2, device=device)
|
||||||
|
p0[:, 0] *= image_size[1] * 0.5
|
||||||
|
p0[:, 1] *= image_size[0] * 0.5
|
||||||
|
focal = image_size[0] * torch.ones(batch_size, device=device)
|
||||||
|
|
||||||
|
# convert to a Camera object
|
||||||
|
cameras = PerspectiveCameras(focal, p0, R=R, T=T, device=device)
|
||||||
|
return cameras
|
||||||
|
|
||||||
|
|
||||||
|
class TestRenderVolumes(TestCaseMixin, unittest.TestCase):
|
||||||
|
def setUp(self) -> None:
|
||||||
|
super().setUp()
|
||||||
|
torch.manual_seed(42)
|
||||||
|
np.random.seed(42)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def renderer(
|
||||||
|
volume_size=(25, 25, 25),
|
||||||
|
batch_size=10,
|
||||||
|
shape="sphere",
|
||||||
|
raymarcher_type=EmissionAbsorptionRaymarcher,
|
||||||
|
n_rays_per_image=10,
|
||||||
|
n_pts_per_ray=10,
|
||||||
|
):
|
||||||
|
# get the volumes
|
||||||
|
volumes = init_boundary_volume(
|
||||||
|
volume_size=volume_size, batch_size=batch_size, shape=shape
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# init the mc raysampler
|
||||||
|
raysampler = MonteCarloRaysampler(
|
||||||
|
min_x=-1.0,
|
||||||
|
max_x=1.0,
|
||||||
|
min_y=-1.0,
|
||||||
|
max_y=1.0,
|
||||||
|
n_rays_per_image=n_rays_per_image,
|
||||||
|
n_pts_per_ray=n_pts_per_ray,
|
||||||
|
min_depth=0.1,
|
||||||
|
max_depth=2.0,
|
||||||
|
).to(volumes.device)
|
||||||
|
|
||||||
|
# get the raymarcher
|
||||||
|
raymarcher = raymarcher_type()
|
||||||
|
|
||||||
|
renderer = VolumeRenderer(
|
||||||
|
raysampler=raysampler, raymarcher=raymarcher, sample_mode="bilinear"
|
||||||
|
)
|
||||||
|
|
||||||
|
# generate NDC camera extrinsics and intrinsics
|
||||||
|
cameras = init_cameras(batch_size, image_size=None, ndc=True)
|
||||||
|
|
||||||
|
def run_renderer():
|
||||||
|
renderer(cameras=cameras, volumes=volumes)
|
||||||
|
|
||||||
|
return run_renderer
|
||||||
|
|
||||||
|
def test_input_types(self, batch_size: int = 10):
|
||||||
|
"""
|
||||||
|
Check that ValueErrors are thrown where expected.
|
||||||
|
"""
|
||||||
|
# check the constructor
|
||||||
|
for bad_raysampler in (None, 5, []):
|
||||||
|
for bad_raymarcher in (None, 5, []):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
VolumeRenderer(raysampler=bad_raysampler, raymarcher=bad_raymarcher)
|
||||||
|
|
||||||
|
raysampler = NDCGridRaysampler(
|
||||||
|
image_width=100,
|
||||||
|
image_height=100,
|
||||||
|
n_pts_per_ray=10,
|
||||||
|
min_depth=0.1,
|
||||||
|
max_depth=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# init a trivial renderer
|
||||||
|
renderer = VolumeRenderer(
|
||||||
|
raysampler=raysampler, raymarcher=EmissionAbsorptionRaymarcher()
|
||||||
|
)
|
||||||
|
|
||||||
|
# get cameras
|
||||||
|
cameras = init_cameras(batch_size=batch_size)
|
||||||
|
|
||||||
|
# get volumes
|
||||||
|
volumes = init_boundary_volume(volume_size=(10, 10, 10), batch_size=batch_size)[
|
||||||
|
0
|
||||||
|
]
|
||||||
|
|
||||||
|
# different batch sizes for cameras / volumes
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
renderer(cameras=cameras, volumes=volumes[:-1])
|
||||||
|
|
||||||
|
# ray checks for VolumeSampler
|
||||||
|
volume_sampler = VolumeSampler(volumes=volumes)
|
||||||
|
n_rays = 100
|
||||||
|
for bad_ray_bundle in (
|
||||||
|
(
|
||||||
|
torch.rand(batch_size, n_rays, 3),
|
||||||
|
torch.rand(batch_size, n_rays + 1, 3),
|
||||||
|
torch.rand(batch_size, n_rays, 10),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
torch.rand(batch_size + 1, n_rays, 3),
|
||||||
|
torch.rand(batch_size, n_rays, 3),
|
||||||
|
torch.rand(batch_size, n_rays, 10),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
torch.rand(batch_size, n_rays, 3),
|
||||||
|
torch.rand(batch_size, n_rays, 2),
|
||||||
|
torch.rand(batch_size, n_rays, 10),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
torch.rand(batch_size, n_rays, 3),
|
||||||
|
torch.rand(batch_size, n_rays, 3),
|
||||||
|
torch.rand(batch_size, n_rays),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
ray_bundle = RayBundle(
|
||||||
|
**dict(
|
||||||
|
zip(
|
||||||
|
("origins", "directions", "lengths"),
|
||||||
|
[r.to(cameras.device) for r in bad_ray_bundle],
|
||||||
|
)
|
||||||
|
),
|
||||||
|
xys=None,
|
||||||
|
)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
volume_sampler(ray_bundle)
|
||||||
|
|
||||||
|
# check also explicitly the ray bundle validation function
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_validate_ray_bundle_variables(*bad_ray_bundle)
|
||||||
|
|
||||||
|
def test_compare_with_pointclouds_renderer(
|
||||||
|
self, batch_size=11, volume_size=(30, 30, 30), image_size=200
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Generate a volume and its corresponding point cloud and check whether
|
||||||
|
PointsRenderer returns the same images as the corresponding VolumeRenderer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# generate NDC camera extrinsics and intrinsics
|
||||||
|
cameras = init_cameras(
|
||||||
|
batch_size, image_size=[image_size, image_size], ndc=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# init the boundary volume
|
||||||
|
for shape in ("sphere", "cube"):
|
||||||
|
|
||||||
|
if not DEBUG and shape == "cube":
|
||||||
|
# do not run numeric checks for the cube as the
|
||||||
|
# differences in rendering equations make the renders incomparable
|
||||||
|
continue
|
||||||
|
|
||||||
|
# get rand offset of the volume
|
||||||
|
volume_translation = torch.randn(batch_size, 3) * 0.1
|
||||||
|
# volume_translation[2] = 0.1
|
||||||
|
volumes = init_boundary_volume(
|
||||||
|
volume_size=volume_size,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shape=shape,
|
||||||
|
volume_translation=volume_translation,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# convert the volumes to a pointcloud
|
||||||
|
points = []
|
||||||
|
points_features = []
|
||||||
|
for densities_one, features_one, grid_one in zip(
|
||||||
|
volumes.densities(),
|
||||||
|
volumes.features(),
|
||||||
|
volumes.get_coord_grid(world_coordinates=True),
|
||||||
|
):
|
||||||
|
opaque = densities_one.view(-1) > 1e-4
|
||||||
|
points.append(grid_one.view(-1, 3)[opaque])
|
||||||
|
points_features.append(features_one.reshape(3, -1).t()[opaque])
|
||||||
|
pointclouds = Pointclouds(points, features=points_features)
|
||||||
|
|
||||||
|
# init the grid raysampler with the ndc grid
|
||||||
|
coord_range = 1.0
|
||||||
|
half_pix_size = coord_range / image_size
|
||||||
|
raysampler = NDCGridRaysampler(
|
||||||
|
image_width=image_size,
|
||||||
|
image_height=image_size,
|
||||||
|
n_pts_per_ray=256,
|
||||||
|
min_depth=0.1,
|
||||||
|
max_depth=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the EA raymarcher
|
||||||
|
raymarcher = EmissionAbsorptionRaymarcher()
|
||||||
|
|
||||||
|
# jitter the camera intrinsics a bit for each render
|
||||||
|
cameras_randomized = cameras.clone()
|
||||||
|
cameras_randomized.principal_point = (
|
||||||
|
torch.randn_like(cameras.principal_point) * 0.3
|
||||||
|
)
|
||||||
|
cameras_randomized.focal_length = (
|
||||||
|
cameras.focal_length + torch.randn_like(cameras.focal_length) * 0.2
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the volumetric render
|
||||||
|
images = VolumeRenderer(
|
||||||
|
raysampler=raysampler, raymarcher=raymarcher, sample_mode="bilinear"
|
||||||
|
)(cameras=cameras_randomized, volumes=volumes)[0][..., :3]
|
||||||
|
|
||||||
|
# instantiate the points renderer
|
||||||
|
point_radius = 6 * half_pix_size
|
||||||
|
points_renderer = PointsRenderer(
|
||||||
|
rasterizer=PointsRasterizer(
|
||||||
|
cameras=cameras_randomized,
|
||||||
|
raster_settings=PointsRasterizationSettings(
|
||||||
|
image_size=image_size, radius=point_radius, points_per_pixel=10
|
||||||
|
),
|
||||||
|
),
|
||||||
|
compositor=AlphaCompositor(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the point render
|
||||||
|
images_pts = points_renderer(pointclouds)
|
||||||
|
|
||||||
|
if shape == "sphere":
|
||||||
|
diff = (images - images_pts).abs().mean(dim=-1)
|
||||||
|
mu_diff = diff.mean(dim=(1, 2))
|
||||||
|
std_diff = diff.std(dim=(1, 2))
|
||||||
|
self.assertClose(mu_diff, torch.zeros_like(mu_diff), atol=3e-2)
|
||||||
|
self.assertClose(std_diff, torch.zeros_like(std_diff), atol=6e-2)
|
||||||
|
|
||||||
|
if DEBUG:
|
||||||
|
outdir = tempfile.gettempdir() + "/test_volume_vs_pts_renderer"
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
|
||||||
|
frames = []
|
||||||
|
for (image, image_pts) in zip(images, images_pts):
|
||||||
|
diff_image = (
|
||||||
|
((image - image_pts) * 0.5 + 0.5)
|
||||||
|
.mean(dim=2, keepdim=True)
|
||||||
|
.repeat(1, 1, 3)
|
||||||
|
)
|
||||||
|
image_pil = Image.fromarray(
|
||||||
|
(
|
||||||
|
torch.cat((image, image_pts, diff_image), dim=1)
|
||||||
|
.detach()
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
* 255.0
|
||||||
|
).astype(np.uint8)
|
||||||
|
)
|
||||||
|
frames.append(image_pil)
|
||||||
|
|
||||||
|
# export gif
|
||||||
|
outfile = os.path.join(outdir, f"volume_vs_pts_render_{shape}.gif")
|
||||||
|
frames[0].save(
|
||||||
|
outfile,
|
||||||
|
save_all=True,
|
||||||
|
append_images=frames[1:],
|
||||||
|
duration=batch_size // 15,
|
||||||
|
loop=0,
|
||||||
|
)
|
||||||
|
print(f"exported {outfile}")
|
||||||
|
|
||||||
|
# export concatenated frames
|
||||||
|
outfile_cat = os.path.join(outdir, f"volume_vs_pts_render_{shape}.png")
|
||||||
|
Image.fromarray(
|
||||||
|
np.concatenate([np.array(f) for f in frames], axis=0)
|
||||||
|
).save(outfile_cat)
|
||||||
|
print(f"exported {outfile_cat}")
|
||||||
|
|
||||||
|
def test_monte_carlo_rendering(
|
||||||
|
self, n_frames=20, volume_size=(30, 30, 30), image_size=(40, 50)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Tests that rendering with the MonteCarloRaysampler matches the
|
||||||
|
rendering with GridRaysampler sampled at the corresponding
|
||||||
|
MonteCarlo locations.
|
||||||
|
"""
|
||||||
|
volumes = init_boundary_volume(
|
||||||
|
volume_size=volume_size, batch_size=n_frames, shape="sphere"
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# generate camera extrinsics and intrinsics
|
||||||
|
cameras = init_cameras(n_frames, image_size=image_size)
|
||||||
|
|
||||||
|
# init the grid raysampler
|
||||||
|
raysampler_grid = GridRaysampler(
|
||||||
|
min_x=0.5,
|
||||||
|
max_x=image_size[1] - 0.5,
|
||||||
|
min_y=0.5,
|
||||||
|
max_y=image_size[0] - 0.5,
|
||||||
|
image_width=image_size[1],
|
||||||
|
image_height=image_size[0],
|
||||||
|
n_pts_per_ray=256,
|
||||||
|
min_depth=0.5,
|
||||||
|
max_depth=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# init the mc raysampler
|
||||||
|
raysampler_mc = MonteCarloRaysampler(
|
||||||
|
min_x=0.5,
|
||||||
|
max_x=image_size[1] - 0.5,
|
||||||
|
min_y=0.5,
|
||||||
|
max_y=image_size[0] - 0.5,
|
||||||
|
n_rays_per_image=3000,
|
||||||
|
n_pts_per_ray=256,
|
||||||
|
min_depth=0.5,
|
||||||
|
max_depth=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the EA raymarcher
|
||||||
|
raymarcher = EmissionAbsorptionRaymarcher()
|
||||||
|
|
||||||
|
# get both mc and grid renders
|
||||||
|
(
|
||||||
|
(images_opacities_mc, ray_bundle_mc),
|
||||||
|
(images_opacities_grid, ray_bundle_grid),
|
||||||
|
) = [
|
||||||
|
VolumeRenderer(
|
||||||
|
raysampler=raysampler_grid,
|
||||||
|
raymarcher=raymarcher,
|
||||||
|
sample_mode="bilinear",
|
||||||
|
)(cameras=cameras, volumes=volumes)
|
||||||
|
for raysampler in (raysampler_mc, raysampler_grid)
|
||||||
|
]
|
||||||
|
|
||||||
|
# convert the mc sampling locations to [-1, 1]
|
||||||
|
sample_loc = ray_bundle_mc.xys.clone()
|
||||||
|
sample_loc[..., 0] = 2 * (sample_loc[..., 0] / image_size[1]) - 1
|
||||||
|
sample_loc[..., 1] = 2 * (sample_loc[..., 1] / image_size[0]) - 1
|
||||||
|
|
||||||
|
# sample the grid render at the mc locations
|
||||||
|
images_opacities_mc_ = torch.nn.functional.grid_sample(
|
||||||
|
images_opacities_grid.permute(0, 3, 1, 2), sample_loc, align_corners=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# check that the samples are the same
|
||||||
|
self.assertClose(
|
||||||
|
images_opacities_mc.permute(0, 3, 1, 2), images_opacities_mc_, atol=1e-4
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_rotating_gif(
|
||||||
|
self, n_frames=50, fps=15, volume_size=(100, 100, 100), image_size=(100, 100)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Render a gif animation of a rotating cube/sphere (runs only if `DEBUG==True`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not DEBUG:
|
||||||
|
# do not run this if debug is False
|
||||||
|
return
|
||||||
|
|
||||||
|
for shape in ("sphere", "cube"):
|
||||||
|
for sample_mode in ("bilinear", "nearest"):
|
||||||
|
|
||||||
|
volumes = init_boundary_volume(
|
||||||
|
volume_size=volume_size, batch_size=n_frames, shape=shape
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# generate camera extrinsics and intrinsics
|
||||||
|
cameras = init_cameras(n_frames, image_size=image_size)
|
||||||
|
|
||||||
|
# init the grid raysampler
|
||||||
|
raysampler = GridRaysampler(
|
||||||
|
min_x=0.5,
|
||||||
|
max_x=image_size[1] - 0.5,
|
||||||
|
min_y=0.5,
|
||||||
|
max_y=image_size[0] - 0.5,
|
||||||
|
image_width=image_size[1],
|
||||||
|
image_height=image_size[0],
|
||||||
|
n_pts_per_ray=256,
|
||||||
|
min_depth=0.5,
|
||||||
|
max_depth=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the EA raymarcher
|
||||||
|
raymarcher = EmissionAbsorptionRaymarcher()
|
||||||
|
|
||||||
|
# intialize the renderer
|
||||||
|
renderer = VolumeRenderer(
|
||||||
|
raysampler=raysampler,
|
||||||
|
raymarcher=raymarcher,
|
||||||
|
sample_mode=sample_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
# run the renderer
|
||||||
|
images_opacities = renderer(cameras=cameras, volumes=volumes)[0]
|
||||||
|
|
||||||
|
# split output to the alpha channel and rendered images
|
||||||
|
images, opacities = images_opacities[..., :3], images_opacities[..., 3]
|
||||||
|
|
||||||
|
# export the gif
|
||||||
|
outdir = tempfile.gettempdir() + "/test_volume_renderer_gifs"
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
frames = []
|
||||||
|
for image, opacity in zip(images, opacities):
|
||||||
|
image_pil = Image.fromarray(
|
||||||
|
(
|
||||||
|
torch.cat(
|
||||||
|
(image, opacity[..., None].repeat(1, 1, 3)), dim=1
|
||||||
|
)
|
||||||
|
.detach()
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
* 255.0
|
||||||
|
).astype(np.uint8)
|
||||||
|
)
|
||||||
|
frames.append(image_pil)
|
||||||
|
outfile = os.path.join(outdir, f"{shape}_{sample_mode}.gif")
|
||||||
|
frames[0].save(
|
||||||
|
outfile,
|
||||||
|
save_all=True,
|
||||||
|
append_images=frames[1:],
|
||||||
|
duration=n_frames // fps,
|
||||||
|
loop=0,
|
||||||
|
)
|
||||||
|
print(f"exported {outfile}")
|
||||||
|
|
||||||
|
def test_rotating_cube_volume_render(self):
|
||||||
|
"""
|
||||||
|
Generates 4 renders of 4 sides of a volume representing a 3D cube.
|
||||||
|
Since each side of the cube is homogenously colored with
|
||||||
|
a different color, this should result in 4 images of homogenous color
|
||||||
|
with the depth of each pixel equal to a constant.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# batch_size = 4 sides of the cube
|
||||||
|
batch_size = 4
|
||||||
|
image_size = (50, 50)
|
||||||
|
|
||||||
|
for volume_size in ([25, 25, 25],):
|
||||||
|
for sample_mode in ("bilinear", "nearest"):
|
||||||
|
|
||||||
|
volume_translation = torch.zeros(4, 3)
|
||||||
|
volume_translation.requires_grad = True
|
||||||
|
volumes, volume_voxel_size, _ = init_boundary_volume(
|
||||||
|
volume_size=volume_size,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shape="cube",
|
||||||
|
volume_translation=volume_translation,
|
||||||
|
)
|
||||||
|
|
||||||
|
# generate camera extrinsics and intrinsics
|
||||||
|
cameras = init_cameras(batch_size, image_size=image_size)
|
||||||
|
|
||||||
|
# enable the gradient caching for the camera variables
|
||||||
|
# the list of differentiable camera vars
|
||||||
|
cam_vars = ("R", "T", "focal_length", "principal_point")
|
||||||
|
for cam_var in cam_vars:
|
||||||
|
getattr(cameras, cam_var).requires_grad = True
|
||||||
|
# enable the grad for volume vars as well
|
||||||
|
volumes.features().requires_grad = True
|
||||||
|
volumes.densities().requires_grad = True
|
||||||
|
|
||||||
|
raysampler = GridRaysampler(
|
||||||
|
min_x=0.5,
|
||||||
|
max_x=image_size[1] - 0.5,
|
||||||
|
min_y=0.5,
|
||||||
|
max_y=image_size[0] - 0.5,
|
||||||
|
image_width=image_size[1],
|
||||||
|
image_height=image_size[0],
|
||||||
|
n_pts_per_ray=128,
|
||||||
|
min_depth=0.01,
|
||||||
|
max_depth=3.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
raymarcher = EmissionAbsorptionRaymarcher()
|
||||||
|
renderer = VolumeRenderer(
|
||||||
|
raysampler=raysampler,
|
||||||
|
raymarcher=raymarcher,
|
||||||
|
sample_mode=sample_mode,
|
||||||
|
)
|
||||||
|
images_opacities = renderer(cameras=cameras, volumes=volumes)[0]
|
||||||
|
images, opacities = images_opacities[..., :3], images_opacities[..., 3]
|
||||||
|
|
||||||
|
# check that the renderer does not erase gradients
|
||||||
|
loss = images_opacities.sum()
|
||||||
|
loss.backward()
|
||||||
|
for check_var in (
|
||||||
|
*[getattr(cameras, cam_var) for cam_var in cam_vars],
|
||||||
|
volumes.features(),
|
||||||
|
volumes.densities(),
|
||||||
|
volume_translation,
|
||||||
|
):
|
||||||
|
self.assertIsNotNone(check_var.grad)
|
||||||
|
|
||||||
|
# ao opacities should be exactly the same as the ea ones
|
||||||
|
# we can further get the ea opacities from a feature-less
|
||||||
|
# version of our volumes
|
||||||
|
raymarcher_ao = AbsorptionOnlyRaymarcher()
|
||||||
|
renderer_ao = VolumeRenderer(
|
||||||
|
raysampler=raysampler,
|
||||||
|
raymarcher=raymarcher_ao,
|
||||||
|
sample_mode=sample_mode,
|
||||||
|
)
|
||||||
|
volumes_featureless = Volumes(
|
||||||
|
densities=volumes.densities(),
|
||||||
|
volume_translation=volume_translation,
|
||||||
|
voxel_size=volume_voxel_size,
|
||||||
|
)
|
||||||
|
opacities_ao = renderer_ao(
|
||||||
|
cameras=cameras, volumes=volumes_featureless
|
||||||
|
)[0][..., 0]
|
||||||
|
self.assertClose(opacities, opacities_ao)
|
||||||
|
|
||||||
|
# colors of the sides of the cube
|
||||||
|
gt_clr_sides = torch.tensor(
|
||||||
|
[
|
||||||
|
[1.0, 0.0, 0.0],
|
||||||
|
[0.0, 1.0, 1.0],
|
||||||
|
[1.0, 1.0, 1.0],
|
||||||
|
[0.0, 1.0, 0.0],
|
||||||
|
],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=images.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
if DEBUG:
|
||||||
|
outdir = tempfile.gettempdir() + "/test_volume_renderer"
|
||||||
|
os.makedirs(outdir, exist_ok=True)
|
||||||
|
for imidx, (image, opacity) in enumerate(zip(images, opacities)):
|
||||||
|
for image_ in (image, opacity):
|
||||||
|
image_pil = Image.fromarray(
|
||||||
|
(image_.detach().cpu().numpy() * 255.0).astype(np.uint8)
|
||||||
|
)
|
||||||
|
outfile = (
|
||||||
|
outdir
|
||||||
|
+ f"/rgb_{sample_mode}"
|
||||||
|
+ f"_{str(volume_size).replace(' ','')}"
|
||||||
|
+ f"_{imidx:003d}"
|
||||||
|
)
|
||||||
|
if image_ is image:
|
||||||
|
outfile += "_rgb.png"
|
||||||
|
else:
|
||||||
|
outfile += "_opacity.png"
|
||||||
|
image_pil.save(outfile)
|
||||||
|
print(f"exported {outfile}")
|
||||||
|
|
||||||
|
border = 10
|
||||||
|
for image, opacity, gt_color in zip(images, opacities, gt_clr_sides):
|
||||||
|
image_crop = image[border:-border, border:-border]
|
||||||
|
opacity_crop = opacity[border:-border, border:-border]
|
||||||
|
|
||||||
|
# check mean and std difference from gt
|
||||||
|
err = (
|
||||||
|
(image_crop - gt_color[None, None].expand_as(image_crop))
|
||||||
|
.abs()
|
||||||
|
.mean(dim=-1)
|
||||||
|
)
|
||||||
|
zero = err.new_zeros(1)[0]
|
||||||
|
self.assertClose(err.mean(), zero, atol=1e-2)
|
||||||
|
self.assertClose(err.std(), zero, atol=1e-2)
|
||||||
|
|
||||||
|
err_opacity = (opacity_crop - 1.0).abs()
|
||||||
|
self.assertClose(err_opacity.mean(), zero, atol=1e-2)
|
||||||
|
self.assertClose(err_opacity.std(), zero, atol=1e-2)
|
Loading…
x
Reference in New Issue
Block a user