Implicit/Volume renderer

Summary: Implements the `ImplicitRenderer` and `VolumeRenderer`.

Reviewed By: gkioxari

Differential Revision: D24418791

fbshipit-source-id: 127f21186d8e210895db1dcd0681f09f230d81a4
David Novotny 2021-01-06 06:21:50 -08:00 committed by Facebook GitHub Bot
parent e6a32bfc37
commit b466c381da
8 changed files with 1575 additions and 3 deletions
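For context, a minimal sketch of how the new `VolumeRenderer` is meant to be driven (the camera and volume setup below is illustrative, not part of this diff):

```python
import torch
from pytorch3d.renderer import (
    EmissionAbsorptionRaymarcher,
    NDCGridRaysampler,
    PerspectiveCameras,
    VolumeRenderer,
)
from pytorch3d.structures import Volumes

# a single camera 1.5 units from the world origin, looking along +z
cameras = PerspectiveCameras(
    R=torch.eye(3)[None], T=torch.tensor([[0.0, 0.0, 1.5]])
)

# a 16^3 volume with random RGB features and an opaque box in the middle
densities = torch.zeros(1, 1, 16, 16, 16)
densities[..., 4:12, 4:12, 4:12] = 1.0
features = torch.rand(1, 3, 16, 16, 16)
volumes = Volumes(densities=densities, features=features, voxel_size=1 / 15)

renderer = VolumeRenderer(
    raysampler=NDCGridRaysampler(
        image_width=64,
        image_height=64,
        n_pts_per_ray=64,
        min_depth=0.1,
        max_depth=2.0,
    ),
    raymarcher=EmissionAbsorptionRaymarcher(),
)
images, ray_bundle = renderer(cameras=cameras, volumes=volumes)
# images: (1, 64, 64, 4) = 3 feature channels + 1 opacity channel
```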

View File

@@ -24,9 +24,12 @@ from .implicit import (
AbsorptionOnlyRaymarcher,
EmissionAbsorptionRaymarcher,
GridRaysampler,
ImplicitRenderer,
MonteCarloRaysampler,
NDCGridRaysampler,
RayBundle,
VolumeRenderer,
VolumeSampler,
ray_bundle_to_ray_points,
ray_bundle_variables_to_ray_points,
)

View File

@@ -2,6 +2,7 @@
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler
from .utils import (
RayBundle,
ray_bundle_to_ray_points,

View File

@@ -0,0 +1,372 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Callable, Tuple
import torch
from ...ops.utils import eyes
from ...structures import Volumes
from ...transforms import Transform3d
from ..cameras import CamerasBase
from .raysampling import RayBundle
from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points
# The implicit renderer class should be initialized with a
# function for raysampling and a function for raymarching.
# During the forward pass:
# 1) The raysampler:
# - samples rays from input cameras
# - transforms the rays to world coordinates
# 2) The volumetric_function (which is a callable argument of the forward pass)
# evaluates ray_densities and ray_features at the sampled ray-points.
# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching
# algorithm to render each ray.
class ImplicitRenderer(torch.nn.Module):
"""
A class for rendering a batch of implicit surfaces. The class should
be initialized with a raysampler and a raymarcher, both of which have
to be `Callable`s.
VOLUMETRIC_FUNCTION
The `forward` function of the renderer accepts as input the rendering cameras as well
as the `volumetric_function` `Callable`, which defines a field of opacity
and feature vectors over the 3D domain of the scene.
A standard `volumetric_function` has the following signature:
```
def volumetric_function(ray_bundle: RayBundle, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]
```
With the following arguments:
`ray_bundle`: A RayBundle object containing the following variables:
`rays_origins`: A tensor of shape `(minibatch, ..., 3)` denoting
the origins of the rendering rays.
`rays_directions`: A tensor of shape `(minibatch, ..., 3)`
containing the direction vectors of rendering rays.
`rays_lengths`: A tensor of shape
`(minibatch, ..., num_points_per_ray)` containing the
lengths at which the ray points are sampled.
Calling `volumetric_function` then returns the following:
`rays_densities`: A tensor of shape
`(minibatch, ..., num_points_per_ray, opacity_dim)` containing
an opacity vector for each ray point.
`rays_features`: A tensor of shape
`(minibatch, ..., num_points_per_ray, feature_dim)` containing
a feature vector for each ray point.
Example:
A simple volumetric function of a 0-centered
RGB sphere with a unit diameter is defined as follows:
```
def volumetric_function(
ray_bundle: RayBundle,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# first convert the ray origins, directions and lengths
# to 3D ray point locations in world coords
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
# set the densities as an inverse sigmoid of the
# ray point distance from the sphere centroid
rays_densities = torch.sigmoid(
-100.0 * rays_points_world.norm(dim=-1, keepdim=True)
)
# set the ray features to RGB colors proportional
# to the 3D location of the projection of ray points
# on the sphere surface
rays_features = torch.nn.functional.normalize(
rays_points_world, dim=-1
) * 0.5 + 0.5
return rays_densities, rays_features
```
"""
def __init__(self, raysampler: Callable, raymarcher: Callable):
"""
Args:
raysampler: A `Callable` that takes as input scene cameras
(an instance of `CamerasBase`) and returns a `RayBundle` that
describes the rays emitted from the cameras.
raymarcher: A `Callable` that receives the response of the
`volumetric_function` (an input to `self.forward`) evaluated
along the sampled rays, and renders the rays with a
ray-marching algorithm.
"""
super().__init__()
if not callable(raysampler):
raise ValueError('"raysampler" has to be a "Callable" object.')
if not callable(raymarcher):
raise ValueError('"raymarcher" has to be a "Callable" object.')
self.raysampler = raysampler
self.raymarcher = raymarcher
def forward(
self, cameras: CamerasBase, volumetric_function: Callable, **kwargs
) -> Tuple[torch.Tensor, RayBundle]:
"""
Render a batch of images using a volumetric function
represented as a callable (e.g. a Pytorch module).
Args:
cameras: A batch of cameras that render the scene. The `self.raysampler`
takes the cameras as input and samples rays that pass through the
domain of the volumetric function.
volumetric_function: A `Callable` that accepts the parametrizations
of the rendering rays and returns the densities and features
at the respective 3D points of the rendering rays. Please refer to
the main class documentation for details.
Returns:
images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
containing the result of the rendering.
ray_bundle: A `RayBundle` containing the parametrizations of the
sampled rendering rays.
"""
if not callable(volumetric_function):
raise ValueError('"volumetric_function" has to be a "Callable" object.')
# first call the ray sampler that returns the RayBundle parametrizing
# the rendering rays.
ray_bundle = self.raysampler(
cameras=cameras, volumetric_function=volumetric_function, **kwargs
)
# ray_bundle.origins - minibatch x ... x 3
# ray_bundle.directions - minibatch x ... x 3
# ray_bundle.lengths - minibatch x ... x n_pts_per_ray
# ray_bundle.xys - minibatch x ... x 2
# given sampled rays, call the volumetric function that
# evaluates the densities and features at the locations of the
# ray points
rays_densities, rays_features = volumetric_function(
ray_bundle=ray_bundle, cameras=cameras, **kwargs
)
# ray_densities - minibatch x ... x n_pts_per_ray x density_dim
# ray_features - minibatch x ... x n_pts_per_ray x feature_dim
# finally, march along the sampled rays to obtain the renders
images = self.raymarcher(
rays_densities=rays_densities,
rays_features=rays_features,
ray_bundle=ray_bundle,
**kwargs
)
# images - minibatch x ... x (feature_dim + opacity_dim)
return images, ray_bundle
# The volume renderer class should be initialized with a
# function for raysampling and a function for raymarching.
# During the forward pass:
# 1) The raysampler:
# - samples rays from input cameras
# - transforms the rays to world coordinates
# 2) The scene volumes (which are an argument of the forward function)
# are then sampled at the locations of the ray-points to generate
# ray_densities and ray_features.
# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching
# algorithm to render each ray.
class VolumeRenderer(torch.nn.Module):
"""
A class for rendering a batch of Volumes. The class should
be initialized with a raysampler and a raymarcher, both of which have
to be `Callable`s.
"""
def __init__(
self, raysampler: Callable, raymarcher: Callable, sample_mode: str = "bilinear"
):
"""
Args:
raysampler: A `Callable` that takes as input scene cameras
(an instance of `CamerasBase`) and returns a `RayBundle` that
describes the rays emitted from the cameras.
raymarcher: A `Callable` that receives the `volumes`
(an instance of `Volumes` input to `self.forward`)
sampled at the ray-points, and renders the rays with a
ray-marching algorithm.
sample_mode: Defines the algorithm used to sample the volumetric
voxel grid. Can be either "bilinear" or "nearest".
"""
super().__init__()
self.renderer = ImplicitRenderer(raysampler, raymarcher)
self._sample_mode = sample_mode
def forward(
self, cameras: CamerasBase, volumes: Volumes, **kwargs
) -> Tuple[torch.Tensor, RayBundle]:
"""
Render a batch of images using raymarching over rays cast through
input `Volumes`.
Args:
cameras: A batch of cameras that render the scene. The `self.raysampler`
takes the cameras as input and samples rays that pass through the
domain of the volumetric function.
volumes: An instance of the `Volumes` class representing a
batch of volumes that are being rendered.
Returns:
images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
containing the result of the rendering.
ray_bundle: A `RayBundle` containing the parametrizations of the
sampled rendering rays.
"""
volumetric_function = VolumeSampler(volumes, sample_mode=self._sample_mode)
return self.renderer(
cameras=cameras, volumetric_function=volumetric_function, **kwargs
)
class VolumeSampler(torch.nn.Module):
"""
A class that allows sampling a batch of `Volumes`
at 3D points sampled along projection rays.
"""
def __init__(self, volumes: Volumes, sample_mode: str = "bilinear"):
"""
Args:
volumes: An instance of the `Volumes` class representing a
batch of volumes that are being rendered.
sample_mode: Defines the algorithm used to sample the volumetric
voxel grid. Can be either "bilinear" or "nearest".
"""
super().__init__()
if not isinstance(volumes, Volumes):
raise ValueError("'volumes' have to be an instance of the 'Volumes' class.")
self._volumes = volumes
self._sample_mode = sample_mode
def _get_ray_directions_transform(self):
"""
Compose the ray-directions transform by removing the translation component
from the volume global-to-local coords transform.
"""
world2local = self._volumes.get_world_to_local_coords_transform().get_matrix()
directions_transform_matrix = eyes(
4,
N=world2local.shape[0],
device=world2local.device,
dtype=world2local.dtype,
)
directions_transform_matrix[:, :3, :3] = world2local[:, :3, :3]
directions_transform = Transform3d(matrix=directions_transform_matrix)
return directions_transform
def forward(
self, ray_bundle: RayBundle, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Given an input ray parametrization, the forward function samples
`self._volumes` at the respective 3D ray-points.
Args:
ray_bundle: A RayBundle object with the following fields:
rays_origins_world: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords.
rays_directions_world: A tensor of shape `(minibatch, ..., 3)`
containing the direction vectors of sampling rays in world coords.
rays_lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
containing the lengths at which the rays are sampled.
Returns:
rays_densities: A tensor of shape
`(minibatch, ..., num_points_per_ray, opacity_dim)` containing the
density vectors sampled from the volume at the locations of
the ray points.
rays_features: A tensor of shape
`(minibatch, ..., num_points_per_ray, feature_dim)` containing the
feature vectors sampled from the volume at the locations of
the ray points.
"""
# take out the interesting parts of ray_bundle
rays_origins_world = ray_bundle.origins
rays_directions_world = ray_bundle.directions
rays_lengths = ray_bundle.lengths
# validate the inputs
_validate_ray_bundle_variables(
rays_origins_world, rays_directions_world, rays_lengths
)
if self._volumes.densities().shape[0] != rays_origins_world.shape[0]:
raise ValueError("Input volumes have to have the same batch size as rays.")
#########################################################
# 1) convert the origins/directions to the local coords #
#########################################################
# origins are mapped with the world_to_local transform of the volumes
rays_origins_local = self._volumes.world_to_local_coords(rays_origins_world)
# obtain the Transform3d object that transforms ray directions to local coords
directions_transform = self._get_ray_directions_transform()
# transform the directions to the local coords
rays_directions_local = directions_transform.transform_points(
rays_directions_world.view(rays_lengths.shape[0], -1, 3)
).view(rays_directions_world.shape)
############################
# 2) obtain the ray points #
############################
# this op produces a fairly big tensor (minibatch, ..., n_samples_per_ray, 3)
rays_points_local = ray_bundle_variables_to_ray_points(
rays_origins_local, rays_directions_local, rays_lengths
)
########################
# 3) sample the volume #
########################
# generate the tensor for sampling
volumes_densities = self._volumes.densities()
dim_density = volumes_densities.shape[1]
volumes_features = self._volumes.features()
# adjust the volumes_features variable in case we have a feature-less volume
if volumes_features is None:
dim_feature = 0
data_to_sample = volumes_densities
else:
dim_feature = volumes_features.shape[1]
data_to_sample = torch.cat((volumes_densities, volumes_features), dim=1)
# reshape to a size which grid_sample likes
rays_points_local_flat = rays_points_local.view(
rays_points_local.shape[0], -1, 1, 1, 3
)
# run the grid sampler
data_sampled = torch.nn.functional.grid_sample(
data_to_sample,
rays_points_local_flat,
align_corners=True,
mode=self._sample_mode,
)
# permute the dimensions & reshape after sampling
data_sampled = data_sampled.permute(0, 2, 3, 4, 1).view(
*rays_points_local.shape[:-1], data_sampled.shape[1]
)
# split back to densities and features
rays_densities, rays_features = data_sampled.split(
[dim_density, dim_feature], dim=-1
)
return rays_densities, rays_features
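A minimal end-to-end sketch of driving `ImplicitRenderer` with the spherical `volumetric_function` from the class docstring above (the camera setup is illustrative):

```python
import torch
from typing import Tuple
from pytorch3d.renderer import (
    EmissionAbsorptionRaymarcher,
    ImplicitRenderer,
    NDCGridRaysampler,
    PerspectiveCameras,
    RayBundle,
    ray_bundle_to_ray_points,
)

def volumetric_function(
    ray_bundle: RayBundle, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
    # density falls off with the distance from the world origin;
    # colors follow the normalized ray-point locations
    rays_points_world = ray_bundle_to_ray_points(ray_bundle)
    rays_densities = torch.sigmoid(
        -100.0 * rays_points_world.norm(dim=-1, keepdim=True)
    )
    rays_features = (
        torch.nn.functional.normalize(rays_points_world, dim=-1) * 0.5 + 0.5
    )
    return rays_densities, rays_features

renderer = ImplicitRenderer(
    raysampler=NDCGridRaysampler(
        image_width=64,
        image_height=64,
        n_pts_per_ray=64,
        min_depth=0.1,
        max_depth=2.0,
    ),
    raymarcher=EmissionAbsorptionRaymarcher(),
)
cameras = PerspectiveCameras(
    R=torch.eye(3)[None], T=torch.tensor([[0.0, 0.0, 1.5]])
)
images, ray_bundle = renderer(
    cameras=cameras, volumetric_function=volumetric_function
)
# images: (1, 64, 64, 4)
```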

View File

@@ -53,12 +53,12 @@ def ray_bundle_variables_to_ray_points(
rays_lengths: torch.Tensor,
) -> torch.Tensor:
"""
Converts rays parametrized with origins and directions
to 3D points by extending each ray according to the corresponding
ray length:
E.g. for 2 dimensional input tensors `rays_origins`, `rays_directions`
and `rays_lengths`, the ray point at position `[i, j]` is:
```
rays_points[i, j, :] = (
rays_origins[i, :]
@@ -80,3 +80,39 @@ def ray_bundle_variables_to_ray_points(
+ rays_lengths[..., :, None] * rays_directions[..., None, :]
)
return rays_points
def _validate_ray_bundle_variables(
rays_origins: torch.Tensor,
rays_directions: torch.Tensor,
rays_lengths: torch.Tensor,
):
"""
Validate the shapes of RayBundle variables
`rays_origins`, `rays_directions`, and `rays_lengths`.
"""
ndim = rays_origins.ndim
if any(r.ndim != ndim for r in (rays_directions, rays_lengths)):
raise ValueError(
"rays_origins, rays_directions and rays_lengths"
+ " have to have the same number of dimensions."
)
if ndim <= 2:
raise ValueError(
"rays_origins, rays_directions and rays_lengths"
+ " have to have at least 3 dimensions."
)
spatial_size = rays_origins.shape[:-1]
if any(spatial_size != r.shape[:-1] for r in (rays_directions, rays_lengths)):
raise ValueError(
"The shapes of rays_origins, rays_directions and rays_lengths"
+ " may differ only in the last dimension."
)
if any(r.shape[-1] != 3 for r in (rays_origins, rays_directions)):
raise ValueError(
"The size of the last dimension of rays_origins/rays_directions"
+ "has to be 3."
)
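To make the broadcasting in `ray_bundle_variables_to_ray_points` concrete, here is a minimal sketch (all shapes chosen purely for illustration):

```python
import torch
from pytorch3d.renderer import ray_bundle_variables_to_ray_points

origins = torch.zeros(2, 4, 3)      # all rays start at the world origin
directions = torch.zeros(2, 4, 3)
directions[..., 2] = 1.0            # all rays point along +z
lengths = torch.linspace(0.1, 2.0, 5).expand(2, 4, 5)

points = ray_bundle_variables_to_ray_points(origins, directions, lengths)
assert points.shape == (2, 4, 5, 3)
# each point is origin + length * direction, so its z equals the length
assert torch.allclose(points[..., 2], lengths)
```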

View File

@@ -0,0 +1,22 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import itertools
from fvcore.common.benchmark import benchmark
from pytorch3d.renderer import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
from test_render_implicit import TestRenderImplicit
def bm_render_implicit() -> None:
case_grid = {
"batch_size": [1, 5],
"raymarcher_type": [EmissionAbsorptionRaymarcher, AbsorptionOnlyRaymarcher],
"n_rays_per_image": [64 ** 2, 256 ** 2],
"n_pts_per_ray": [16, 128],
}
test_cases = itertools.product(*case_grid.values())
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
benchmark(
TestRenderImplicit.renderer, "IMPLICIT_RENDERER", kwargs_list, warmup_iters=1
)

View File

@@ -0,0 +1,24 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import itertools
from fvcore.common.benchmark import benchmark
from pytorch3d.renderer import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
from test_render_volumes import TestRenderVolumes
def bm_render_volumes() -> None:
case_grid = {
"volume_size": [tuple([17] * 3), tuple([129] * 3)],
"batch_size": [1, 5],
"shape": ["sphere", "cube"],
"raymarcher_type": [EmissionAbsorptionRaymarcher, AbsorptionOnlyRaymarcher],
"n_rays_per_image": [64 ** 2, 256 ** 2],
"n_pts_per_ray": [16, 128],
}
test_cases = itertools.product(*case_grid.values())
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
benchmark(
TestRenderVolumes.renderer, "VOLUME_RENDERER", kwargs_list, warmup_iters=1
)

View File

@@ -0,0 +1,403 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer import (
BlendParams,
EmissionAbsorptionRaymarcher,
GridRaysampler,
ImplicitRenderer,
Materials,
MeshRasterizer,
MeshRenderer,
MonteCarloRaysampler,
NDCGridRaysampler,
PointLights,
RasterizationSettings,
RayBundle,
SoftPhongShader,
TexturesVertex,
ray_bundle_to_ray_points,
)
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from test_render_volumes import init_cameras
DEBUG = False
if DEBUG:
import os
import tempfile
from PIL import Image
def spherical_volumetric_function(
ray_bundle: RayBundle,
sphere_centroid: torch.Tensor,
sphere_diameter: float,
**kwargs,
):
"""
Volumetric function of a simple RGB sphere with diameter `sphere_diameter`
and centroid `sphere_centroid`.
"""
# convert the ray bundle to world points
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
batch_size = rays_points_world.shape[0]
# surface_vectors = vectors from the sphere centroid to the ray points
surface_vectors = (
rays_points_world.view(batch_size, -1, 3) - sphere_centroid[:, None]
)
# the squared distance of each ray point to the centroid of the sphere
surface_dist = (
(surface_vectors ** 2)
.sum(-1, keepdim=True)
.view(*rays_points_world.shape[:-1], 1)
)
# set all ray densities within the sphere_diameter distance from the centroid to 1
rays_densities = torch.sigmoid(-100.0 * (surface_dist - sphere_diameter ** 2))
# ray colors are proportional to the normalized surface_vectors
rays_features = (
torch.nn.functional.normalize(
surface_vectors.view(rays_points_world.shape), dim=-1
)
* 0.5
+ 0.5
)
return rays_densities, rays_features
class TestRenderImplicit(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
np.random.seed(42)
@staticmethod
def renderer(
batch_size=10,
raymarcher_type=EmissionAbsorptionRaymarcher,
n_rays_per_image=10,
n_pts_per_ray=10,
sphere_diameter=0.75,
):
# generate NDC camera extrinsics and intrinsics
cameras = init_cameras(batch_size, image_size=None, ndc=True)
# get rand offset of the volume
sphere_centroid = torch.randn(batch_size, 3, device=cameras.device) * 0.1
# init the mc raysampler
raysampler = MonteCarloRaysampler(
min_x=-1.0,
max_x=1.0,
min_y=-1.0,
max_y=1.0,
n_rays_per_image=n_rays_per_image,
n_pts_per_ray=n_pts_per_ray,
min_depth=0.1,
max_depth=2.0,
).to(cameras.device)
# get the raymarcher
raymarcher = raymarcher_type()
# get the implicit renderer
renderer = ImplicitRenderer(raysampler=raysampler, raymarcher=raymarcher)
def run_renderer():
renderer(
cameras=cameras,
volumetric_function=spherical_volumetric_function,
sphere_centroid=sphere_centroid,
sphere_diameter=sphere_diameter,
)
return run_renderer
def test_input_types(self):
"""
Check that ValueErrors are thrown where expected.
"""
# check the constructor
for bad_raysampler in (None, 5, []):
for bad_raymarcher in (None, 5, []):
with self.assertRaises(ValueError):
ImplicitRenderer(
raysampler=bad_raysampler, raymarcher=bad_raymarcher
)
# init a trivial renderer
renderer = ImplicitRenderer(
raysampler=NDCGridRaysampler(
image_width=100,
image_height=100,
n_pts_per_ray=10,
min_depth=0.1,
max_depth=1.0,
),
raymarcher=EmissionAbsorptionRaymarcher(),
)
# get default cameras
cameras = init_cameras()
for bad_volumetric_function in (None, 5, []):
with self.assertRaises(ValueError):
renderer(cameras=cameras, volumetric_function=bad_volumetric_function)
def test_compare_with_meshes_renderer(
self, batch_size=11, image_size=100, sphere_diameter=0.6
):
"""
Generate a spherical RGB volumetric function and its corresponding mesh
and check whether MeshRenderer returns the same images as the
corresponding ImplicitRenderer.
"""
# generate NDC camera extrinsics and intrinsics
cameras = init_cameras(
batch_size, image_size=[image_size, image_size], ndc=True
)
# get rand offset of the volume
sphere_centroid = torch.randn(batch_size, 3, device=cameras.device) * 0.1
sphere_centroid.requires_grad = True
# init the grid raysampler with the ndc grid
raysampler = NDCGridRaysampler(
image_width=image_size,
image_height=image_size,
n_pts_per_ray=256,
min_depth=0.1,
max_depth=2.0,
)
# get the EA raymarcher
raymarcher = EmissionAbsorptionRaymarcher()
# jitter the camera intrinsics a bit for each render
cameras_randomized = cameras.clone()
cameras_randomized.principal_point = (
torch.randn_like(cameras.principal_point) * 0.3
)
cameras_randomized.focal_length = (
cameras.focal_length + torch.randn_like(cameras.focal_length) * 0.2
)
# the list of differentiable camera vars
cam_vars = ("R", "T", "focal_length", "principal_point")
# enable the gradient caching for the camera variables
for cam_var in cam_vars:
getattr(cameras_randomized, cam_var).requires_grad = True
# get the implicit renderer
images_opacities = ImplicitRenderer(
raysampler=raysampler, raymarcher=raymarcher
)(
cameras=cameras_randomized,
volumetric_function=spherical_volumetric_function,
sphere_centroid=sphere_centroid,
sphere_diameter=sphere_diameter,
)[
0
]
# check that the renderer does not erase gradients
loss = images_opacities.sum()
loss.backward()
for check_var in (
*[getattr(cameras_randomized, cam_var) for cam_var in cam_vars],
sphere_centroid,
):
self.assertIsNotNone(check_var.grad)
# instantiate the corresponding spherical mesh
ico = ico_sphere(level=4, device=cameras.device).extend(batch_size)
verts = (
torch.nn.functional.normalize(ico.verts_padded(), dim=-1) * sphere_diameter
+ sphere_centroid[:, None]
)
meshes = Meshes(
verts=verts,
faces=ico.faces_padded(),
textures=TexturesVertex(
verts_features=(
torch.nn.functional.normalize(verts, dim=-1) * 0.5 + 0.5
)
),
)
# instantiate the corresponding mesh renderer
lights = PointLights(device=cameras.device, location=[[0.0, 0.0, 0.0]])
renderer_textured = MeshRenderer(
rasterizer=MeshRasterizer(
cameras=cameras_randomized,
raster_settings=RasterizationSettings(
image_size=image_size, blur_radius=1e-3, faces_per_pixel=10
),
),
shader=SoftPhongShader(
device=cameras.device,
cameras=cameras_randomized,
lights=lights,
materials=Materials(
ambient_color=((2.0, 2.0, 2.0),),
diffuse_color=((0.0, 0.0, 0.0),),
specular_color=((0.0, 0.0, 0.0),),
shininess=64,
device=cameras.device,
),
blend_params=BlendParams(
sigma=1e-3, gamma=1e-4, background_color=(0.0, 0.0, 0.0)
),
),
)
# get the mesh render
images_opacities_meshes = renderer_textured(
meshes, cameras=cameras_randomized, lights=lights
)
if DEBUG:
outdir = tempfile.gettempdir() + "/test_implicit_vs_mesh_renderer"
os.makedirs(outdir, exist_ok=True)
frames = []
for (image_opacity, image_opacity_mesh) in zip(
images_opacities, images_opacities_meshes
):
image, opacity = image_opacity.split([3, 1], dim=-1)
image_mesh, opacity_mesh = image_opacity_mesh.split([3, 1], dim=-1)
diff_image = (
((image - image_mesh) * 0.5 + 0.5)
.mean(dim=2, keepdim=True)
.repeat(1, 1, 3)
)
image_pil = Image.fromarray(
(
torch.cat(
(
image,
image_mesh,
diff_image,
opacity.repeat(1, 1, 3),
opacity_mesh.repeat(1, 1, 3),
),
dim=1,
)
.detach()
.cpu()
.numpy()
* 255.0
).astype(np.uint8)
)
frames.append(image_pil)
# export gif
outfile = os.path.join(outdir, "implicit_vs_mesh_render.gif")
frames[0].save(
outfile,
save_all=True,
append_images=frames[1:],
duration=batch_size // 15,
loop=0,
)
print(f"exported {outfile}")
# export concatenated frames
outfile_cat = os.path.join(outdir, "implicit_vs_mesh_render.png")
Image.fromarray(np.concatenate([np.array(f) for f in frames], axis=0)).save(
outfile_cat
)
print(f"exported {outfile_cat}")
# compare the renders
diff = (images_opacities - images_opacities_meshes).abs().mean(dim=-1)
mu_diff = diff.mean(dim=(1, 2))
std_diff = diff.std(dim=(1, 2))
self.assertClose(mu_diff, torch.zeros_like(mu_diff), atol=5e-2)
self.assertClose(std_diff, torch.zeros_like(std_diff), atol=6e-2)
def test_rotating_gif(
self, n_frames=50, fps=15, image_size=(100, 100), sphere_diameter=0.5
):
"""
Render a gif animation of a rotating sphere (runs only if `DEBUG==True`).
"""
if not DEBUG:
# do not run this if debug is False
return
# generate camera extrinsics and intrinsics
cameras = init_cameras(n_frames, image_size=image_size)
# init the grid raysampler
raysampler = GridRaysampler(
min_x=0.5,
max_x=image_size[1] - 0.5,
min_y=0.5,
max_y=image_size[0] - 0.5,
image_width=image_size[1],
image_height=image_size[0],
n_pts_per_ray=256,
min_depth=0.1,
max_depth=2.0,
)
# get the EA raymarcher
raymarcher = EmissionAbsorptionRaymarcher()
# get the implicit render
renderer = ImplicitRenderer(raysampler=raysampler, raymarcher=raymarcher)
# get the (0) centroid of the sphere
sphere_centroid = torch.zeros(n_frames, 3, device=cameras.device) * 0.1
# run the renderer
images_opacities = renderer(
cameras=cameras,
volumetric_function=spherical_volumetric_function,
sphere_centroid=sphere_centroid,
sphere_diameter=sphere_diameter,
)[0]
# split output to the alpha channel and rendered images
images, opacities = images_opacities[..., :3], images_opacities[..., 3]
# export the gif
outdir = tempfile.gettempdir() + "/test_implicit_renderer_gifs"
os.makedirs(outdir, exist_ok=True)
frames = []
for image, opacity in zip(images, opacities):
image_pil = Image.fromarray(
(
torch.cat(
(image, opacity[..., None].clamp(0.0, 1.0).repeat(1, 1, 3)),
dim=1,
)
.detach()
.cpu()
.numpy()
* 255.0
).astype(np.uint8)
)
frames.append(image_pil)
outfile = os.path.join(outdir, "rotating_sphere.gif")
frames[0].save(
outfile,
save_all=True,
append_images=frames[1:],
duration=n_frames // fps,
loop=0,
)
print(f"exported {outfile}")

View File

@@ -0,0 +1,711 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from typing import Optional, Tuple
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import knn_points
from pytorch3d.renderer import (
AbsorptionOnlyRaymarcher,
AlphaCompositor,
EmissionAbsorptionRaymarcher,
GridRaysampler,
MonteCarloRaysampler,
NDCGridRaysampler,
PerspectiveCameras,
PointsRasterizationSettings,
PointsRasterizer,
PointsRenderer,
RayBundle,
VolumeRenderer,
VolumeSampler,
)
from pytorch3d.renderer.implicit.utils import _validate_ray_bundle_variables
from pytorch3d.structures import Pointclouds, Volumes
from test_points_to_volumes import init_uniform_y_rotations
DEBUG = False
if DEBUG:
import os
import tempfile
from PIL import Image
ZERO_TRANSLATION = torch.zeros(1, 3)
def init_boundary_volume(
batch_size: int,
volume_size: Tuple[int, int, int],
border_offset: int = 2,
shape: str = "cube",
volume_translation: torch.Tensor = ZERO_TRANSLATION,
):
"""
Generate a volume with sides colored with distinct colors.
"""
device = torch.device("cuda")
# first center the volume for the purpose of generating the canonical shape
volume_translation_tmp = (0.0, 0.0, 0.0)
# set the voxel size to 1 / (volume_size-1)
volume_voxel_size = 1 / (volume_size[0] - 1.0)
# colors of the sides of the cube
clr_sides = torch.tensor(
[
[1.0, 1.0, 1.0],
[1.0, 0.0, 0.0],
[1.0, 0.0, 1.0],
[1.0, 1.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 1.0, 1.0],
],
dtype=torch.float32,
device=device,
)
# get the coord grid of the volume
coord_grid = Volumes(
densities=torch.zeros(1, 1, *volume_size, device=device),
voxel_size=volume_voxel_size,
volume_translation=volume_translation_tmp,
).get_coord_grid()[0]
# extract the boundary points and their colors of the cube
if shape == "cube":
boundary_points, boundary_colors = [], []
for side, clr_side in enumerate(clr_sides):
first = side % 2
dim = side // 2
slices = [slice(border_offset, -border_offset, 1)] * 3
slices[dim] = int(border_offset * (2 * first - 1))
slices.append(slice(0, 3, 1))
boundary_points_ = coord_grid[slices].reshape(-1, 3)
boundary_points.append(boundary_points_)
boundary_colors.append(clr_side[None].expand_as(boundary_points_))
# set the internal part of the volume to be completely opaque
volume_densities = torch.zeros(*volume_size, device=device)
volume_densities[[slice(border_offset, -border_offset, 1)] * 3] = 1.0
boundary_points, boundary_colors = [
torch.cat(p, dim=0) for p in [boundary_points, boundary_colors]
]
# color the volume voxels with the nearest boundary points' color
_, idx, _ = knn_points(
coord_grid.view(1, -1, 3), boundary_points.view(1, -1, 3)
)
volume_colors = (
boundary_colors[idx.view(-1)].view(*volume_size, 3).permute(3, 0, 1, 2)
)
elif shape == "sphere":
# set all voxels within a certain distance from the origin to be opaque
volume_densities = (
coord_grid.norm(dim=-1)
<= 0.5 * volume_voxel_size * (volume_size[0] - border_offset)
).float()
# color each voxel with the standard spherical color
volume_colors = (
(torch.nn.functional.normalize(coord_grid, dim=-1) + 1.0) * 0.5
).permute(3, 0, 1, 2)
else:
raise ValueError(shape)
volume_voxel_size = torch.ones((batch_size, 1), device=device) * volume_voxel_size
volume_translation = volume_translation.expand(batch_size, 3)
volumes = Volumes(
densities=volume_densities[None, None].expand(batch_size, 1, *volume_size),
features=volume_colors[None].expand(batch_size, 3, *volume_size),
voxel_size=volume_voxel_size,
volume_translation=volume_translation,
)
return volumes, volume_voxel_size, volume_translation
def init_cameras(
batch_size: int = 10,
image_size: Optional[Tuple[int, int]] = (50, 50),
ndc: bool = False,
):
"""
Initialize a batch of cameras whose extrinsics rotate the cameras around
the world's y axis.
Depending on whether we want an NDC-space (`ndc==True`) or a screen-space camera,
the camera's focal length and principal point are initialized accordingly:
For `ndc==False`, p0 = image_size / 2 and focal_length = image_size[0].
For `ndc==True`, focal_length = 1.0 and p0 = 0.0.
The z-coordinate of the translation vector of each camera is fixed to 1.5.
"""
device = torch.device("cuda:0")
# trivial rotations
R = init_uniform_y_rotations(batch_size).to(device)
# move camera 1.5 m away from the scene center
T = torch.zeros((batch_size, 3), device=device)
T[:, 2] = 1.5
if ndc:
p0 = torch.zeros(batch_size, 2, device=device)
focal = torch.ones(batch_size, device=device)
else:
p0 = torch.ones(batch_size, 2, device=device)
p0[:, 0] *= image_size[1] * 0.5
p0[:, 1] *= image_size[0] * 0.5
focal = image_size[0] * torch.ones(batch_size, device=device)
# convert to a Camera object
cameras = PerspectiveCameras(focal, p0, R=R, T=T, device=device)
return cameras
class TestRenderVolumes(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
np.random.seed(42)
@staticmethod
def renderer(
volume_size=(25, 25, 25),
batch_size=10,
shape="sphere",
raymarcher_type=EmissionAbsorptionRaymarcher,
n_rays_per_image=10,
n_pts_per_ray=10,
):
# get the volumes
volumes = init_boundary_volume(
volume_size=volume_size, batch_size=batch_size, shape=shape
)[0]
# init the mc raysampler
raysampler = MonteCarloRaysampler(
min_x=-1.0,
max_x=1.0,
min_y=-1.0,
max_y=1.0,
n_rays_per_image=n_rays_per_image,
n_pts_per_ray=n_pts_per_ray,
min_depth=0.1,
max_depth=2.0,
).to(volumes.device)
# get the raymarcher
raymarcher = raymarcher_type()
renderer = VolumeRenderer(
raysampler=raysampler, raymarcher=raymarcher, sample_mode="bilinear"
)
# generate NDC camera extrinsics and intrinsics
cameras = init_cameras(batch_size, image_size=None, ndc=True)
def run_renderer():
renderer(cameras=cameras, volumes=volumes)
return run_renderer
def test_input_types(self, batch_size: int = 10):
"""
Check that ValueErrors are thrown where expected.
"""
# check the constructor
for bad_raysampler in (None, 5, []):
for bad_raymarcher in (None, 5, []):
with self.assertRaises(ValueError):
VolumeRenderer(raysampler=bad_raysampler, raymarcher=bad_raymarcher)
raysampler = NDCGridRaysampler(
image_width=100,
image_height=100,
n_pts_per_ray=10,
min_depth=0.1,
max_depth=1.0,
)
# init a trivial renderer
renderer = VolumeRenderer(
raysampler=raysampler, raymarcher=EmissionAbsorptionRaymarcher()
)
# get cameras
cameras = init_cameras(batch_size=batch_size)
# get volumes
volumes = init_boundary_volume(volume_size=(10, 10, 10), batch_size=batch_size)[
0
]
# different batch sizes for cameras / volumes
with self.assertRaises(ValueError):
renderer(cameras=cameras, volumes=volumes[:-1])
# ray checks for VolumeSampler
volume_sampler = VolumeSampler(volumes=volumes)
n_rays = 100
for bad_ray_bundle in (
(
torch.rand(batch_size, n_rays, 3),
torch.rand(batch_size, n_rays + 1, 3),
torch.rand(batch_size, n_rays, 10),
),
(
torch.rand(batch_size + 1, n_rays, 3),
torch.rand(batch_size, n_rays, 3),
torch.rand(batch_size, n_rays, 10),
),
(
torch.rand(batch_size, n_rays, 3),
torch.rand(batch_size, n_rays, 2),
torch.rand(batch_size, n_rays, 10),
),
(
torch.rand(batch_size, n_rays, 3),
torch.rand(batch_size, n_rays, 3),
torch.rand(batch_size, n_rays),
),
):
ray_bundle = RayBundle(
**dict(
zip(
("origins", "directions", "lengths"),
[r.to(cameras.device) for r in bad_ray_bundle],
)
),
xys=None,
)
with self.assertRaises(ValueError):
volume_sampler(ray_bundle)
# check also explicitly the ray bundle validation function
with self.assertRaises(ValueError):
_validate_ray_bundle_variables(*bad_ray_bundle)
def test_compare_with_pointclouds_renderer(
self, batch_size=11, volume_size=(30, 30, 30), image_size=200
):
"""
Generate a volume and its corresponding point cloud and check whether
PointsRenderer returns the same images as the corresponding VolumeRenderer.
"""
# generate NDC camera extrinsics and intrinsics
cameras = init_cameras(
batch_size, image_size=[image_size, image_size], ndc=True
)
# init the boundary volume
for shape in ("sphere", "cube"):
if not DEBUG and shape == "cube":
# do not run numeric checks for the cube as the
# differences in rendering equations make the renders incomparable
continue
# get rand offset of the volume
volume_translation = torch.randn(batch_size, 3) * 0.1
# volume_translation[2] = 0.1
volumes = init_boundary_volume(
volume_size=volume_size,
batch_size=batch_size,
shape=shape,
volume_translation=volume_translation,
)[0]
# convert the volumes to a pointcloud
points = []
points_features = []
for densities_one, features_one, grid_one in zip(
volumes.densities(),
volumes.features(),
volumes.get_coord_grid(world_coordinates=True),
):
opaque = densities_one.view(-1) > 1e-4
points.append(grid_one.view(-1, 3)[opaque])
points_features.append(features_one.reshape(3, -1).t()[opaque])
pointclouds = Pointclouds(points, features=points_features)
# init the grid raysampler with the ndc grid
coord_range = 1.0
half_pix_size = coord_range / image_size
raysampler = NDCGridRaysampler(
image_width=image_size,
image_height=image_size,
n_pts_per_ray=256,
min_depth=0.1,
max_depth=2.0,
)
# get the EA raymarcher
raymarcher = EmissionAbsorptionRaymarcher()
# jitter the camera intrinsics a bit for each render
cameras_randomized = cameras.clone()
cameras_randomized.principal_point = (
torch.randn_like(cameras.principal_point) * 0.3
)
cameras_randomized.focal_length = (
cameras.focal_length + torch.randn_like(cameras.focal_length) * 0.2
)
# get the volumetric render
images = VolumeRenderer(
raysampler=raysampler, raymarcher=raymarcher, sample_mode="bilinear"
)(cameras=cameras_randomized, volumes=volumes)[0][..., :3]
# instantiate the points renderer
point_radius = 6 * half_pix_size
points_renderer = PointsRenderer(
rasterizer=PointsRasterizer(
cameras=cameras_randomized,
raster_settings=PointsRasterizationSettings(
image_size=image_size, radius=point_radius, points_per_pixel=10
),
),
compositor=AlphaCompositor(),
)
# get the point render
images_pts = points_renderer(pointclouds)
if shape == "sphere":
diff = (images - images_pts).abs().mean(dim=-1)
mu_diff = diff.mean(dim=(1, 2))
std_diff = diff.std(dim=(1, 2))
self.assertClose(mu_diff, torch.zeros_like(mu_diff), atol=3e-2)
self.assertClose(std_diff, torch.zeros_like(std_diff), atol=6e-2)
if DEBUG:
outdir = tempfile.gettempdir() + "/test_volume_vs_pts_renderer"
os.makedirs(outdir, exist_ok=True)
frames = []
for (image, image_pts) in zip(images, images_pts):
diff_image = (
((image - image_pts) * 0.5 + 0.5)
.mean(dim=2, keepdim=True)
.repeat(1, 1, 3)
)
image_pil = Image.fromarray(
(
torch.cat((image, image_pts, diff_image), dim=1)
.detach()
.cpu()
.numpy()
* 255.0
).astype(np.uint8)
)
frames.append(image_pil)
# export gif
outfile = os.path.join(outdir, f"volume_vs_pts_render_{shape}.gif")
frames[0].save(
outfile,
save_all=True,
append_images=frames[1:],
duration=batch_size // 15,
loop=0,
)
print(f"exported {outfile}")
# export concatenated frames
outfile_cat = os.path.join(outdir, f"volume_vs_pts_render_{shape}.png")
Image.fromarray(
np.concatenate([np.array(f) for f in frames], axis=0)
).save(outfile_cat)
print(f"exported {outfile_cat}")
def test_monte_carlo_rendering(
self, n_frames=20, volume_size=(30, 30, 30), image_size=(40, 50)
):
"""
Tests that rendering with the MonteCarloRaysampler matches the
rendering with GridRaysampler sampled at the corresponding
MonteCarlo locations.
"""
volumes = init_boundary_volume(
volume_size=volume_size, batch_size=n_frames, shape="sphere"
)[0]
# generate camera extrinsics and intrinsics
cameras = init_cameras(n_frames, image_size=image_size)
# init the grid raysampler
raysampler_grid = GridRaysampler(
min_x=0.5,
max_x=image_size[1] - 0.5,
min_y=0.5,
max_y=image_size[0] - 0.5,
image_width=image_size[1],
image_height=image_size[0],
n_pts_per_ray=256,
min_depth=0.5,
max_depth=2.0,
)
# init the mc raysampler
raysampler_mc = MonteCarloRaysampler(
min_x=0.5,
max_x=image_size[1] - 0.5,
min_y=0.5,
max_y=image_size[0] - 0.5,
n_rays_per_image=3000,
n_pts_per_ray=256,
min_depth=0.5,
max_depth=2.0,
)
# get the EA raymarcher
raymarcher = EmissionAbsorptionRaymarcher()
# get both mc and grid renders
(
(images_opacities_mc, ray_bundle_mc),
(images_opacities_grid, ray_bundle_grid),
) = [
VolumeRenderer(
raysampler=raysampler,
raymarcher=raymarcher,
sample_mode="bilinear",
)(cameras=cameras, volumes=volumes)
for raysampler in (raysampler_mc, raysampler_grid)
]
# convert the mc sampling locations to [-1, 1]
sample_loc = ray_bundle_mc.xys.clone()
sample_loc[..., 0] = 2 * (sample_loc[..., 0] / image_size[1]) - 1
sample_loc[..., 1] = 2 * (sample_loc[..., 1] / image_size[0]) - 1
# sample the grid render at the mc locations
images_opacities_mc_ = torch.nn.functional.grid_sample(
images_opacities_grid.permute(0, 3, 1, 2), sample_loc, align_corners=False
)
# check that the samples are the same
self.assertClose(
images_opacities_mc.permute(0, 3, 1, 2), images_opacities_mc_, atol=1e-4
)
def test_rotating_gif(
self, n_frames=50, fps=15, volume_size=(100, 100, 100), image_size=(100, 100)
):
"""
Render a gif animation of a rotating cube/sphere (runs only if `DEBUG==True`).
"""
if not DEBUG:
# do not run this if debug is False
return
for shape in ("sphere", "cube"):
for sample_mode in ("bilinear", "nearest"):
volumes = init_boundary_volume(
volume_size=volume_size, batch_size=n_frames, shape=shape
)[0]
# generate camera extrinsics and intrinsics
cameras = init_cameras(n_frames, image_size=image_size)
# init the grid raysampler
raysampler = GridRaysampler(
min_x=0.5,
max_x=image_size[1] - 0.5,
min_y=0.5,
max_y=image_size[0] - 0.5,
image_width=image_size[1],
image_height=image_size[0],
n_pts_per_ray=256,
min_depth=0.5,
max_depth=2.0,
)
# get the EA raymarcher
raymarcher = EmissionAbsorptionRaymarcher()
# initialize the renderer
renderer = VolumeRenderer(
raysampler=raysampler,
raymarcher=raymarcher,
sample_mode=sample_mode,
)
# run the renderer
images_opacities = renderer(cameras=cameras, volumes=volumes)[0]
# split output to the alpha channel and rendered images
images, opacities = images_opacities[..., :3], images_opacities[..., 3]
# export the gif
outdir = tempfile.gettempdir() + "/test_volume_renderer_gifs"
os.makedirs(outdir, exist_ok=True)
frames = []
for image, opacity in zip(images, opacities):
image_pil = Image.fromarray(
(
torch.cat(
(image, opacity[..., None].repeat(1, 1, 3)), dim=1
)
.detach()
.cpu()
.numpy()
* 255.0
).astype(np.uint8)
)
frames.append(image_pil)
outfile = os.path.join(outdir, f"{shape}_{sample_mode}.gif")
frames[0].save(
outfile,
save_all=True,
append_images=frames[1:],
duration=n_frames // fps,
loop=0,
)
print(f"exported {outfile}")
def test_rotating_cube_volume_render(self):
"""
Generates 4 renders of 4 sides of a volume representing a 3D cube.
Since each side of the cube is homogeneously colored with
a different color, this should result in 4 images of homogeneous color
with the depth of each pixel equal to a constant.
"""
# batch_size = 4 sides of the cube
batch_size = 4
image_size = (50, 50)
for volume_size in ([25, 25, 25],):
for sample_mode in ("bilinear", "nearest"):
volume_translation = torch.zeros(4, 3)
volume_translation.requires_grad = True
volumes, volume_voxel_size, _ = init_boundary_volume(
volume_size=volume_size,
batch_size=batch_size,
shape="cube",
volume_translation=volume_translation,
)
# generate camera extrinsics and intrinsics
cameras = init_cameras(batch_size, image_size=image_size)
# enable the gradient caching for the camera variables
# the list of differentiable camera vars
cam_vars = ("R", "T", "focal_length", "principal_point")
for cam_var in cam_vars:
getattr(cameras, cam_var).requires_grad = True
# enable the grad for volume vars as well
volumes.features().requires_grad = True
volumes.densities().requires_grad = True
raysampler = GridRaysampler(
min_x=0.5,
max_x=image_size[1] - 0.5,
min_y=0.5,
max_y=image_size[0] - 0.5,
image_width=image_size[1],
image_height=image_size[0],
n_pts_per_ray=128,
min_depth=0.01,
max_depth=3.0,
)
raymarcher = EmissionAbsorptionRaymarcher()
renderer = VolumeRenderer(
raysampler=raysampler,
raymarcher=raymarcher,
sample_mode=sample_mode,
)
images_opacities = renderer(cameras=cameras, volumes=volumes)[0]
images, opacities = images_opacities[..., :3], images_opacities[..., 3]
# check that the renderer does not erase gradients
loss = images_opacities.sum()
loss.backward()
for check_var in (
*[getattr(cameras, cam_var) for cam_var in cam_vars],
volumes.features(),
volumes.densities(),
volume_translation,
):
self.assertIsNotNone(check_var.grad)
# ao opacities should be exactly the same as the ea ones
# we can further get the ao opacities from a feature-less
# version of our volumes
raymarcher_ao = AbsorptionOnlyRaymarcher()
renderer_ao = VolumeRenderer(
raysampler=raysampler,
raymarcher=raymarcher_ao,
sample_mode=sample_mode,
)
volumes_featureless = Volumes(
densities=volumes.densities(),
volume_translation=volume_translation,
voxel_size=volume_voxel_size,
)
opacities_ao = renderer_ao(
cameras=cameras, volumes=volumes_featureless
)[0][..., 0]
self.assertClose(opacities, opacities_ao)
# colors of the sides of the cube
gt_clr_sides = torch.tensor(
[
[1.0, 0.0, 0.0],
[0.0, 1.0, 1.0],
[1.0, 1.0, 1.0],
[0.0, 1.0, 0.0],
],
dtype=torch.float32,
device=images.device,
)
if DEBUG:
outdir = tempfile.gettempdir() + "/test_volume_renderer"
os.makedirs(outdir, exist_ok=True)
for imidx, (image, opacity) in enumerate(zip(images, opacities)):
for image_ in (image, opacity):
image_pil = Image.fromarray(
(image_.detach().cpu().numpy() * 255.0).astype(np.uint8)
)
outfile = (
outdir
+ f"/rgb_{sample_mode}"
+ f"_{str(volume_size).replace(' ','')}"
+ f"_{imidx:003d}"
)
if image_ is image:
outfile += "_rgb.png"
else:
outfile += "_opacity.png"
image_pil.save(outfile)
print(f"exported {outfile}")
border = 10
for image, opacity, gt_color in zip(images, opacities, gt_clr_sides):
image_crop = image[border:-border, border:-border]
opacity_crop = opacity[border:-border, border:-border]
# check mean and std difference from gt
err = (
(image_crop - gt_color[None, None].expand_as(image_crop))
.abs()
.mean(dim=-1)
)
zero = err.new_zeros(1)[0]
self.assertClose(err.mean(), zero, atol=1e-2)
self.assertClose(err.std(), zero, atol=1e-2)
err_opacity = (opacity_crop - 1.0).abs()
self.assertClose(err_opacity.mean(), zero, atol=1e-2)
self.assertClose(err_opacity.std(), zero, atol=1e-2)
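Finally, `VolumeSampler` can also be exercised directly with a hand-built `RayBundle`, mirroring the validation tests above (shapes chosen for illustration):

```python
import torch
from pytorch3d.renderer import RayBundle, VolumeSampler
from pytorch3d.structures import Volumes

batch_size, n_rays, n_pts = 2, 100, 10
volumes = Volumes(
    densities=torch.rand(batch_size, 1, 8, 8, 8),
    features=torch.rand(batch_size, 3, 8, 8, 8),
    voxel_size=1 / 7,
)
sampler = VolumeSampler(volumes=volumes, sample_mode="bilinear")
ray_bundle = RayBundle(
    origins=torch.zeros(batch_size, n_rays, 3),
    directions=torch.nn.functional.normalize(
        torch.randn(batch_size, n_rays, 3), dim=-1
    ),
    lengths=torch.linspace(0.1, 1.0, n_pts).expand(batch_size, n_rays, n_pts),
    xys=None,
)
rays_densities, rays_features = sampler(ray_bundle)
assert rays_densities.shape == (batch_size, n_rays, n_pts, 1)
assert rays_features.shape == (batch_size, n_rays, n_pts, 3)
```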