diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index 8a3eb1ee..f1f09057 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -24,9 +24,12 @@ from .implicit import ( AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher, GridRaysampler, + ImplicitRenderer, MonteCarloRaysampler, NDCGridRaysampler, RayBundle, + VolumeRenderer, + VolumeSampler, ray_bundle_to_ray_points, ray_bundle_variables_to_ray_points, ) diff --git a/pytorch3d/renderer/implicit/__init__.py b/pytorch3d/renderer/implicit/__init__.py index ec245b01..634e7046 100644 --- a/pytorch3d/renderer/implicit/__init__.py +++ b/pytorch3d/renderer/implicit/__init__.py @@ -2,6 +2,7 @@ from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler +from .renderer import ImplicitRenderer, VolumeRenderer, VolumeSampler from .utils import ( RayBundle, ray_bundle_to_ray_points, diff --git a/pytorch3d/renderer/implicit/renderer.py b/pytorch3d/renderer/implicit/renderer.py new file mode 100644 index 00000000..e529be6d --- /dev/null +++ b/pytorch3d/renderer/implicit/renderer.py @@ -0,0 +1,372 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from typing import Callable, Tuple + +import torch + +from ...ops.utils import eyes +from ...structures import Volumes +from ...transforms import Transform3d +from ..cameras import CamerasBase +from .raysampling import RayBundle +from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points + + +# The implicit renderer class should be initialized with a +# function for raysampling and a function for raymarching. + +# During the forward pass: +# 1) The raysampler: +# - samples rays from input cameras +# - transforms the rays to world coordinates +# 2) The volumetric_function (which is a callable argument of the forward pass) +# evaluates ray_densities and ray_features at the sampled ray-points. +# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching +# algorithm to render each ray. + + +class ImplicitRenderer(torch.nn.Module): + """ + A class for rendering a batch of implicit surfaces. The class should + be initialized with a raysampler and a raymarcher, which both have + to be `Callable` objects. + + VOLUMETRIC_FUNCTION + + The `forward` function of the renderer accepts as input the rendering cameras as well + as the `volumetric_function` `Callable`, which defines a field of opacity + and feature vectors over the 3D domain of the scene. + + A standard `volumetric_function` has the following signature: + ``` + def volumetric_function(ray_bundle: RayBundle) -> Tuple[torch.Tensor, torch.Tensor] + ``` + With the following arguments: + `ray_bundle`: A RayBundle object containing the following variables: + `rays_origins`: A tensor of shape `(minibatch, ..., 3)` denoting + the origins of the rendering rays. + `rays_directions`: A tensor of shape `(minibatch, ..., 3)` + containing the direction vectors of rendering rays. + `rays_lengths`: A tensor of shape + `(minibatch, ..., num_points_per_ray)` containing the + lengths at which the ray points are sampled. + Calling `volumetric_function` then returns the following: + `rays_densities`: A tensor of shape + `(minibatch, ..., num_points_per_ray, opacity_dim)` containing + an opacity vector for each ray point.
+ `rays_features`: A tensor of shape + `(minibatch, ..., num_points_per_ray, feature_dim)` containing + a feature vector for each ray point. + + Example: + A simple volumetric function of a 0-centered + RGB sphere with a unit diameter is defined as follows: + ``` + def volumetric_function( + ray_bundle: RayBundle, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + # first convert the ray origins, directions and lengths + # to 3D ray point locations in world coords + rays_points_world = ray_bundle_to_ray_points(ray_bundle) + + # set the densities as a sigmoid of the negated + # ray point distance from the sphere centroid + rays_densities = torch.sigmoid( + -100.0 * rays_points_world.norm(dim=-1, keepdim=True) + ) + + # set the ray features to RGB colors proportional + # to the 3D location of the projection of ray points + # on the sphere surface + rays_features = torch.nn.functional.normalize( + rays_points_world, dim=-1 + ) * 0.5 + 0.5 + + return rays_densities, rays_features + ``` + """ + + def __init__(self, raysampler: Callable, raymarcher: Callable): + """ + Args: + raysampler: A `Callable` that takes as input scene cameras + (an instance of `CamerasBase`) and returns a `RayBundle` that + describes the rays emitted from the cameras. + raymarcher: A `Callable` that receives the response of the + `volumetric_function` (an input to `self.forward`) evaluated + along the sampled rays, and renders the rays with a + ray-marching algorithm. + """ + super().__init__() + + if not callable(raysampler): + raise ValueError('"raysampler" has to be a "Callable" object.') + if not callable(raymarcher): + raise ValueError('"raymarcher" has to be a "Callable" object.') + + self.raysampler = raysampler + self.raymarcher = raymarcher + + def forward( + self, cameras: CamerasBase, volumetric_function: Callable, **kwargs + ) -> Tuple[torch.Tensor, RayBundle]: + """ + Render a batch of images using a volumetric function + represented as a callable (e.g. a PyTorch module). + + Args: + cameras: A batch of cameras that render the scene. A `self.raysampler` + takes the cameras as input and samples rays that pass through the + domain of the volumetric function. + volumetric_function: A `Callable` that accepts the parametrizations + of the rendering rays and returns the densities and features + at the respective 3D points of the rendering rays. Please refer to + the main class documentation for details. + + Returns: + images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)` + containing the result of the rendering. + ray_bundle: A `RayBundle` containing the parametrizations of the + sampled rendering rays. + """ + + if not callable(volumetric_function): + raise ValueError('"volumetric_function" has to be a "Callable" object.') + + # first call the ray sampler that returns the RayBundle parametrizing + # the rendering rays. + ray_bundle = self.raysampler( + cameras=cameras, volumetric_function=volumetric_function, **kwargs + ) + # ray_bundle.origins - minibatch x ... x 3 + # ray_bundle.directions - minibatch x ... x 3 + # ray_bundle.lengths - minibatch x ... x n_pts_per_ray + # ray_bundle.xys - minibatch x ... x 2 + + # given sampled rays, call the volumetric function that + # evaluates the densities and features at the locations of the + # ray points + rays_densities, rays_features = volumetric_function( + ray_bundle=ray_bundle, cameras=cameras, **kwargs + ) + # ray_densities - minibatch x ... x n_pts_per_ray x density_dim + # ray_features - minibatch x ...
x n_pts_per_ray x feature_dim + + # finally, march along the sampled rays to obtain the renders + images = self.raymarcher( + rays_densities=rays_densities, + rays_features=rays_features, + ray_bundle=ray_bundle, + **kwargs + ) + # images - minibatch x ... x (feature_dim + opacity_dim) + + return images, ray_bundle + + +# The volume renderer class should be initialized with a +# function for raysampling and a function for raymarching. + +# During the forward pass: +# 1) The raysampler: +# - samples rays from input cameras +# - transforms the rays to world coordinates +# 2) The scene volumes (which are an argument of the forward function) +# are then sampled at the locations of the ray-points to generate +# ray_densities and ray_features. +# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching +# algorithm to render each ray. + + +class VolumeRenderer(torch.nn.Module): + """ + A class for rendering a batch of Volumes. The class should + be initialized with a raysampler and a raymarcher, which both have + to be `Callable` objects. + """ + + def __init__( + self, raysampler: Callable, raymarcher: Callable, sample_mode: str = "bilinear" + ): + """ + Args: + raysampler: A `Callable` that takes as input scene cameras + (an instance of `CamerasBase`) and returns a `RayBundle` that + describes the rays emitted from the cameras. + raymarcher: A `Callable` that receives the `volumes` + (an instance of `Volumes` input to `self.forward`) + sampled at the ray-points, and renders the rays with a + ray-marching algorithm. + sample_mode: Defines the algorithm used to sample the volumetric + voxel grid. Can be either "bilinear" or "nearest". + """ + super().__init__() + + self.renderer = ImplicitRenderer(raysampler, raymarcher) + self._sample_mode = sample_mode + + def forward( + self, cameras: CamerasBase, volumes: Volumes, **kwargs + ) -> Tuple[torch.Tensor, RayBundle]: + """ + Render a batch of images using raymarching over rays cast through + input `Volumes`. + + Args: + cameras: A batch of cameras that render the scene. A `self.raysampler` + takes the cameras as input and samples rays that pass through the + domain of the volumetric function. + volumes: An instance of the `Volumes` class representing a + batch of volumes that are being rendered. + + Returns: + images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)` + containing the result of the rendering. + ray_bundle: A `RayBundle` containing the parametrizations of the + sampled rendering rays. + """ + volumetric_function = VolumeSampler(volumes, sample_mode=self._sample_mode) + return self.renderer( + cameras=cameras, volumetric_function=volumetric_function, **kwargs + ) + + +class VolumeSampler(torch.nn.Module): + """ + A class that allows sampling a batch of `Volumes` + at 3D points sampled along projection rays. + """ + + def __init__(self, volumes: Volumes, sample_mode: str = "bilinear"): + """ + Args: + volumes: An instance of the `Volumes` class representing a + batch of volumes that are being rendered. + sample_mode: Defines the algorithm used to sample the volumetric + voxel grid. Can be either "bilinear" or "nearest".
+ """ + super().__init__() + if not isinstance(volumes, Volumes): + raise ValueError("'volumes' have to be an instance of the 'Volumes' class.") + self._volumes = volumes + self._sample_mode = sample_mode + + def _get_ray_directions_transform(self): + """ + Compose the ray-directions transform by removing the translation component + from the volume global-to-local coords transform. + """ + world2local = self._volumes.get_world_to_local_coords_transform().get_matrix() + directions_transform_matrix = eyes( + 4, + N=world2local.shape[0], + device=world2local.device, + dtype=world2local.dtype, + ) + directions_transform_matrix[:, :3, :3] = world2local[:, :3, :3] + directions_transform = Transform3d(matrix=directions_transform_matrix) + return directions_transform + + def forward( + self, ray_bundle: RayBundle, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Given an input ray parametrization, the forward function samples + `self._volumes` at the respective 3D ray-points. + + Args: + ray_bundle: A RayBundle object with the following fields: + rays_origins_world: A tensor of shape `(minibatch, ..., 3)` denoting the + origins of the sampling rays in world coords. + rays_directions_world: A tensor of shape `(minibatch, ..., 3)` + containing the direction vectors of sampling rays in world coords. + rays_lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)` + containing the lengths at which the rays are sampled. + + Returns: + rays_densities: A tensor of shape + `(minibatch, ..., num_points_per_ray, opacity_dim)` containing the + densitity vectors sampled from the volume at the locations of + the ray points. + rays_features: A tensor of shape + `(minibatch, ..., num_points_per_ray, feature_dim)` containing the + feature vectors sampled from the volume at the locations of + the ray points. 
+ """ + + # take out the interesting parts of ray_bundle + rays_origins_world = ray_bundle.origins + rays_directions_world = ray_bundle.directions + rays_lengths = ray_bundle.lengths + + # validate the inputs + _validate_ray_bundle_variables( + rays_origins_world, rays_directions_world, rays_lengths + ) + if self._volumes.densities().shape[0] != rays_origins_world.shape[0]: + raise ValueError("Input volumes have to have the same batch size as rays.") + + ######################################################### + # 1) convert the origins/directions to the local coords # + ######################################################### + + # origins are mapped with the world_to_local transform of the volumes + rays_origins_local = self._volumes.world_to_local_coords(rays_origins_world) + + # obtain the Transform3d object that transforms ray directions to local coords + directions_transform = self._get_ray_directions_transform() + + # transform the directions to the local coords + rays_directions_local = directions_transform.transform_points( + rays_directions_world.view(rays_lengths.shape[0], -1, 3) + ).view(rays_directions_world.shape) + + ############################ + # 2) obtain the ray points # + ############################ + + # this op produces a fairly big tensor (minibatch, ..., n_samples_per_ray, 3) + rays_points_local = ray_bundle_variables_to_ray_points( + rays_origins_local, rays_directions_local, rays_lengths + ) + + ######################## + # 3) sample the volume # + ######################## + + # generate the tensor for sampling + volumes_densities = self._volumes.densities() + dim_density = volumes_densities.shape[1] + volumes_features = self._volumes.features() + # adjust the volumes_features variable in case we have a feature-less volume + if volumes_features is None: + dim_feature = 0 + data_to_sample = volumes_densities + else: + dim_feature = volumes_features.shape[1] + data_to_sample = torch.cat((volumes_densities, volumes_features), dim=1) + + # reshape to a size which grid_sample likes + rays_points_local_flat = rays_points_local.view( + rays_points_local.shape[0], -1, 1, 1, 3 + ) + + # run the grid sampler + data_sampled = torch.nn.functional.grid_sample( + data_to_sample, + rays_points_local_flat, + align_corners=True, + mode=self._sample_mode, + ) + + # permute the dimensions & reshape after sampling + data_sampled = data_sampled.permute(0, 2, 3, 4, 1).view( + *rays_points_local.shape[:-1], data_sampled.shape[1] + ) + + # split back to densities and features + rays_densities, rays_features = data_sampled.split( + [dim_density, dim_feature], dim=-1 + ) + + return rays_densities, rays_features diff --git a/pytorch3d/renderer/implicit/utils.py b/pytorch3d/renderer/implicit/utils.py index b5f01812..5a6a4265 100644 --- a/pytorch3d/renderer/implicit/utils.py +++ b/pytorch3d/renderer/implicit/utils.py @@ -53,12 +53,12 @@ def ray_bundle_variables_to_ray_points( rays_lengths: torch.Tensor, ) -> torch.Tensor: """ - Converts rays parametrized with origins, directions + Converts rays parametrized with origins and directions to 3D points by extending each ray according to the corresponding - ray_length: + ray length: E.g. 
for 2 dimensional input tensors `rays_origins`, `rays_directions` - and `rays_lengths`, the ray point at position `[i, j]` is: + and `rays_lengths`, the ray point at position `[i, j]` is: ``` rays_points[i, j, :] = ( rays_origins[i, :] @@ -80,3 +80,39 @@ + rays_lengths[..., :, None] * rays_directions[..., None, :] ) return rays_points + + +def _validate_ray_bundle_variables( + rays_origins: torch.Tensor, + rays_directions: torch.Tensor, + rays_lengths: torch.Tensor, +): + """ + Validate the shapes of RayBundle variables + `rays_origins`, `rays_directions`, and `rays_lengths`. + """ + ndim = rays_origins.ndim + if any(r.ndim != ndim for r in (rays_directions, rays_lengths)): + raise ValueError( + "rays_origins, rays_directions and rays_lengths" + + " have to have the same number of dimensions." + ) + + if ndim <= 2: + raise ValueError( + "rays_origins, rays_directions and rays_lengths" + + " have to have at least 3 dimensions." + ) + + spatial_size = rays_origins.shape[:-1] + if any(spatial_size != r.shape[:-1] for r in (rays_directions, rays_lengths)): + raise ValueError( + "The shapes of rays_origins, rays_directions and rays_lengths" + + " may differ only in the last dimension." + ) + + if any(r.shape[-1] != 3 for r in (rays_origins, rays_directions)): + raise ValueError( + "The size of the last dimension of rays_origins/rays_directions" + + " has to be 3." + ) diff --git a/tests/bm_render_implicit.py b/tests/bm_render_implicit.py new file mode 100644 index 00000000..7063d334 --- /dev/null +++ b/tests/bm_render_implicit.py @@ -0,0 +1,22 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import itertools + +from fvcore.common.benchmark import benchmark +from pytorch3d.renderer import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher +from test_render_implicit import TestRenderImplicit + + +def bm_render_implicit() -> None: + case_grid = { + "batch_size": [1, 5], + "raymarcher_type": [EmissionAbsorptionRaymarcher, AbsorptionOnlyRaymarcher], + "n_rays_per_image": [64 ** 2, 256 ** 2], + "n_pts_per_ray": [16, 128], + } + test_cases = itertools.product(*case_grid.values()) + kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases] + + benchmark( + TestRenderImplicit.renderer, "IMPLICIT_RENDERER", kwargs_list, warmup_iters=1 + ) diff --git a/tests/bm_render_volumes.py b/tests/bm_render_volumes.py new file mode 100644 index 00000000..c8fe17cd --- /dev/null +++ b/tests/bm_render_volumes.py @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
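+# Benchmarks of the VolumeRenderer forward pass over the case grid defined +# below; executed through fvcore's benchmark harness, mirroring +# bm_render_implicit.py above.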
+ +import itertools + +from fvcore.common.benchmark import benchmark +from pytorch3d.renderer import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher +from test_render_volumes import TestRenderVolumes + + +def bm_render_volumes() -> None: + case_grid = { + "volume_size": [tuple([17] * 3), tuple([129] * 3)], + "batch_size": [1, 5], + "shape": ["sphere", "cube"], + "raymarcher_type": [EmissionAbsorptionRaymarcher, AbsorptionOnlyRaymarcher], + "n_rays_per_image": [64 ** 2, 256 ** 2], + "n_pts_per_ray": [16, 128], + } + test_cases = itertools.product(*case_grid.values()) + kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases] + + benchmark( + TestRenderVolumes.renderer, "VOLUME_RENDERER", kwargs_list, warmup_iters=1 + ) diff --git a/tests/test_render_implicit.py b/tests/test_render_implicit.py new file mode 100644 index 00000000..0884a818 --- /dev/null +++ b/tests/test_render_implicit.py @@ -0,0 +1,403 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import unittest + +import numpy as np +import torch +from common_testing import TestCaseMixin +from pytorch3d.renderer import ( + BlendParams, + EmissionAbsorptionRaymarcher, + GridRaysampler, + ImplicitRenderer, + Materials, + MeshRasterizer, + MeshRenderer, + MonteCarloRaysampler, + NDCGridRaysampler, + PointLights, + RasterizationSettings, + RayBundle, + SoftPhongShader, + TexturesVertex, + ray_bundle_to_ray_points, +) +from pytorch3d.structures import Meshes +from pytorch3d.utils import ico_sphere +from test_render_volumes import init_cameras + + +DEBUG = False +if DEBUG: + import os + import tempfile + + from PIL import Image + + +def spherical_volumetric_function( + ray_bundle: RayBundle, + sphere_centroid: torch.Tensor, + sphere_diameter: float, + **kwargs, +): + """ + Volumetric function of a simple RGB sphere with diameter `sphere_diameter` + and centroid `sphere_centroid`. 
+ """ + # convert the ray bundle to world points + rays_points_world = ray_bundle_to_ray_points(ray_bundle) + batch_size = rays_points_world.shape[0] + + # surface_vectors = vectors from world coords towards the sphere centroid + surface_vectors = ( + rays_points_world.view(batch_size, -1, 3) - sphere_centroid[:, None] + ) + + # the squared distance of each ray point to the centroid of the sphere + surface_dist = ( + (surface_vectors ** 2) + .sum(-1, keepdim=True) + .view(*rays_points_world.shape[:-1], 1) + ) + + # set all ray densities within the sphere_diameter distance from the centroid to 1 + rays_densities = torch.sigmoid(-100.0 * (surface_dist - sphere_diameter ** 2)) + + # ray colors are proportional to the normalized surface_vectors + rays_features = ( + torch.nn.functional.normalize( + surface_vectors.view(rays_points_world.shape), dim=-1 + ) + * 0.5 + + 0.5 + ) + + return rays_densities, rays_features + + +class TestRenderImplicit(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + np.random.seed(42) + + @staticmethod + def renderer( + batch_size=10, + raymarcher_type=EmissionAbsorptionRaymarcher, + n_rays_per_image=10, + n_pts_per_ray=10, + sphere_diameter=0.75, + ): + # generate NDC camera extrinsics and intrinsics + cameras = init_cameras(batch_size, image_size=None, ndc=True) + + # get rand offset of the volume + sphere_centroid = torch.randn(batch_size, 3, device=cameras.device) * 0.1 + + # init the mc raysampler + raysampler = MonteCarloRaysampler( + min_x=-1.0, + max_x=1.0, + min_y=-1.0, + max_y=1.0, + n_rays_per_image=n_rays_per_image, + n_pts_per_ray=n_pts_per_ray, + min_depth=0.1, + max_depth=2.0, + ).to(cameras.device) + + # get the raymarcher + raymarcher = raymarcher_type() + + # get the implicit renderer + renderer = ImplicitRenderer(raysampler=raysampler, raymarcher=raymarcher) + + def run_renderer(): + renderer( + cameras=cameras, + volumetric_function=spherical_volumetric_function, + sphere_centroid=sphere_centroid, + sphere_diameter=sphere_diameter, + ) + + return run_renderer + + def test_input_types(self): + """ + Check that ValueErrors are thrown where expected. + """ + # check the constructor + for bad_raysampler in (None, 5, []): + for bad_raymarcher in (None, 5, []): + with self.assertRaises(ValueError): + ImplicitRenderer( + raysampler=bad_raysampler, raymarcher=bad_raymarcher + ) + + # init a trivial renderer + renderer = ImplicitRenderer( + raysampler=NDCGridRaysampler( + image_width=100, + image_height=100, + n_pts_per_ray=10, + min_depth=0.1, + max_depth=1.0, + ), + raymarcher=EmissionAbsorptionRaymarcher(), + ) + + # get default cameras + cameras = init_cameras() + + for bad_volumetric_function in (None, 5, []): + with self.assertRaises(ValueError): + renderer(cameras=cameras, volumetric_function=bad_volumetric_function) + + def test_compare_with_meshes_renderer( + self, batch_size=11, image_size=100, sphere_diameter=0.6 + ): + """ + Generate a spherical RGB volumetric function and its corresponding mesh + and check whether MeshesRenderer returns the same images as the + corresponding ImplicitRenderer. 
+ """ + + # generate NDC camera extrinsics and intrinsics + cameras = init_cameras( + batch_size, image_size=[image_size, image_size], ndc=True + ) + + # get rand offset of the volume + sphere_centroid = torch.randn(batch_size, 3, device=cameras.device) * 0.1 + sphere_centroid.requires_grad = True + + # init the grid raysampler with the ndc grid + raysampler = NDCGridRaysampler( + image_width=image_size, + image_height=image_size, + n_pts_per_ray=256, + min_depth=0.1, + max_depth=2.0, + ) + + # get the EA raymarcher + raymarcher = EmissionAbsorptionRaymarcher() + + # jitter the camera intrinsics a bit for each render + cameras_randomized = cameras.clone() + cameras_randomized.principal_point = ( + torch.randn_like(cameras.principal_point) * 0.3 + ) + cameras_randomized.focal_length = ( + cameras.focal_length + torch.randn_like(cameras.focal_length) * 0.2 + ) + + # the list of differentiable camera vars + cam_vars = ("R", "T", "focal_length", "principal_point") + # enable the gradient caching for the camera variables + for cam_var in cam_vars: + getattr(cameras_randomized, cam_var).requires_grad = True + + # get the implicit renderer + images_opacities = ImplicitRenderer( + raysampler=raysampler, raymarcher=raymarcher + )( + cameras=cameras_randomized, + volumetric_function=spherical_volumetric_function, + sphere_centroid=sphere_centroid, + sphere_diameter=sphere_diameter, + )[ + 0 + ] + + # check that the renderer does not erase gradients + loss = images_opacities.sum() + loss.backward() + for check_var in ( + *[getattr(cameras_randomized, cam_var) for cam_var in cam_vars], + sphere_centroid, + ): + self.assertIsNotNone(check_var.grad) + + # instantiate the corresponding spherical mesh + ico = ico_sphere(level=4, device=cameras.device).extend(batch_size) + verts = ( + torch.nn.functional.normalize(ico.verts_padded(), dim=-1) * sphere_diameter + + sphere_centroid[:, None] + ) + meshes = Meshes( + verts=verts, + faces=ico.faces_padded(), + textures=TexturesVertex( + verts_features=( + torch.nn.functional.normalize(verts, dim=-1) * 0.5 + 0.5 + ) + ), + ) + + # instantiate the corresponding mesh renderer + lights = PointLights(device=cameras.device, location=[[0.0, 0.0, 0.0]]) + renderer_textured = MeshRenderer( + rasterizer=MeshRasterizer( + cameras=cameras_randomized, + raster_settings=RasterizationSettings( + image_size=image_size, blur_radius=1e-3, faces_per_pixel=10 + ), + ), + shader=SoftPhongShader( + device=cameras.device, + cameras=cameras_randomized, + lights=lights, + materials=Materials( + ambient_color=((2.0, 2.0, 2.0),), + diffuse_color=((0.0, 0.0, 0.0),), + specular_color=((0.0, 0.0, 0.0),), + shininess=64, + device=cameras.device, + ), + blend_params=BlendParams( + sigma=1e-3, gamma=1e-4, background_color=(0.0, 0.0, 0.0) + ), + ), + ) + + # get the mesh render + images_opacities_meshes = renderer_textured( + meshes, cameras=cameras_randomized, lights=lights + ) + + if DEBUG: + outdir = tempfile.gettempdir() + "/test_implicit_vs_mesh_renderer" + os.makedirs(outdir, exist_ok=True) + + frames = [] + for (image_opacity, image_opacity_mesh) in zip( + images_opacities, images_opacities_meshes + ): + image, opacity = image_opacity.split([3, 1], dim=-1) + image_mesh, opacity_mesh = image_opacity_mesh.split([3, 1], dim=-1) + diff_image = ( + ((image - image_mesh) * 0.5 + 0.5) + .mean(dim=2, keepdim=True) + .repeat(1, 1, 3) + ) + image_pil = Image.fromarray( + ( + torch.cat( + ( + image, + image_mesh, + diff_image, + opacity.repeat(1, 1, 3), + opacity_mesh.repeat(1, 1, 3), + ), + dim=1, 
+ ) + .detach() + .cpu() + .numpy() + * 255.0 + ).astype(np.uint8) + ) + frames.append(image_pil) + + # export gif + outfile = os.path.join(outdir, "implicit_vs_mesh_render.gif") + frames[0].save( + outfile, + save_all=True, + append_images=frames[1:], + duration=batch_size // 15, + loop=0, + ) + print(f"exported {outfile}") + + # export concatenated frames + outfile_cat = os.path.join(outdir, "implicit_vs_mesh_render.png") + Image.fromarray(np.concatenate([np.array(f) for f in frames], axis=0)).save( + outfile_cat + ) + print(f"exported {outfile_cat}") + + # compare the renders + diff = (images_opacities - images_opacities_meshes).abs().mean(dim=-1) + mu_diff = diff.mean(dim=(1, 2)) + std_diff = diff.std(dim=(1, 2)) + self.assertClose(mu_diff, torch.zeros_like(mu_diff), atol=5e-2) + self.assertClose(std_diff, torch.zeros_like(std_diff), atol=6e-2) + + def test_rotating_gif( + self, n_frames=50, fps=15, image_size=(100, 100), sphere_diameter=0.5 + ): + """ + Render a gif animation of a rotating sphere (runs only if `DEBUG==True`). + """ + + if not DEBUG: + # do not run this if debug is False + return + + # generate camera extrinsics and intrinsics + cameras = init_cameras(n_frames, image_size=image_size) + + # init the grid raysampler + raysampler = GridRaysampler( + min_x=0.5, + max_x=image_size[1] - 0.5, + min_y=0.5, + max_y=image_size[0] - 0.5, + image_width=image_size[1], + image_height=image_size[0], + n_pts_per_ray=256, + min_depth=0.1, + max_depth=2.0, + ) + + # get the EA raymarcher + raymarcher = EmissionAbsorptionRaymarcher() + + # get the implicit renderer + renderer = ImplicitRenderer(raysampler=raysampler, raymarcher=raymarcher) + + # get the (0, 0, 0) centroid of the sphere + sphere_centroid = torch.zeros(n_frames, 3, device=cameras.device) * 0.1 + + # run the renderer + images_opacities = renderer( + cameras=cameras, + volumetric_function=spherical_volumetric_function, + sphere_centroid=sphere_centroid, + sphere_diameter=sphere_diameter, + )[0] + + # split output into the alpha channel and rendered images + images, opacities = images_opacities[..., :3], images_opacities[..., 3] + + # export the gif + outdir = tempfile.gettempdir() + "/test_implicit_renderer_gifs" + os.makedirs(outdir, exist_ok=True) + frames = [] + for image, opacity in zip(images, opacities): + image_pil = Image.fromarray( + ( + torch.cat( + (image, opacity[..., None].clamp(0.0, 1.0).repeat(1, 1, 3)), + dim=1, + ) + .detach() + .cpu() + .numpy() + * 255.0 + ).astype(np.uint8) + ) + frames.append(image_pil) + outfile = os.path.join(outdir, "rotating_sphere.gif") + frames[0].save( + outfile, + save_all=True, + append_images=frames[1:], + duration=n_frames // fps, + loop=0, + ) + print(f"exported {outfile}") diff --git a/tests/test_render_volumes.py b/tests/test_render_volumes.py new file mode 100644 index 00000000..5ff22fc7 --- /dev/null +++ b/tests/test_render_volumes.py @@ -0,0 +1,711 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
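+# Tests for the new VolumeRenderer and VolumeSampler: input validation, +# comparison against a PointsRenderer reference, Monte Carlo vs. grid +# raysampling consistency, and gradient flow through cameras and volumes.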
+import unittest +from typing import Optional, Tuple + +import numpy as np +import torch +from common_testing import TestCaseMixin +from pytorch3d.ops import knn_points +from pytorch3d.renderer import ( + AbsorptionOnlyRaymarcher, + AlphaCompositor, + EmissionAbsorptionRaymarcher, + GridRaysampler, + MonteCarloRaysampler, + NDCGridRaysampler, + PerspectiveCameras, + PointsRasterizationSettings, + PointsRasterizer, + PointsRenderer, + RayBundle, + VolumeRenderer, + VolumeSampler, +) +from pytorch3d.renderer.implicit.utils import _validate_ray_bundle_variables +from pytorch3d.structures import Pointclouds, Volumes +from test_points_to_volumes import init_uniform_y_rotations + + +DEBUG = False +if DEBUG: + import os + import tempfile + + from PIL import Image + + +ZERO_TRANSLATION = torch.zeros(1, 3) + + +def init_boundary_volume( + batch_size: int, + volume_size: Tuple[int, int, int], + border_offset: int = 2, + shape: str = "cube", + volume_translation: torch.Tensor = ZERO_TRANSLATION, +): + """ + Generate a volume with sides colored with distinct colors. + """ + + device = torch.device("cuda") + + # first center the volume for the purpose of generating the canonical shape + volume_translation_tmp = (0.0, 0.0, 0.0) + + # set the voxel size to 1 / (volume_size-1) + volume_voxel_size = 1 / (volume_size[0] - 1.0) + + # colors of the sides of the cube + clr_sides = torch.tensor( + [ + [1.0, 1.0, 1.0], + [1.0, 0.0, 0.0], + [1.0, 0.0, 1.0], + [1.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 1.0, 1.0], + ], + dtype=torch.float32, + device=device, + ) + + # get the coord grid of the volume + coord_grid = Volumes( + densities=torch.zeros(1, 1, *volume_size, device=device), + voxel_size=volume_voxel_size, + volume_translation=volume_translation_tmp, + ).get_coord_grid()[0] + + # extract the boundary points and their colors of the cube + if shape == "cube": + boundary_points, boundary_colors = [], [] + for side, clr_side in enumerate(clr_sides): + first = side % 2 + dim = side // 2 + slices = [slice(border_offset, -border_offset, 1)] * 3 + slices[dim] = int(border_offset * (2 * first - 1)) + slices.append(slice(0, 3, 1)) + boundary_points_ = coord_grid[slices].reshape(-1, 3) + boundary_points.append(boundary_points_) + boundary_colors.append(clr_side[None].expand_as(boundary_points_)) + # set the internal part of the volume to be completely opaque + volume_densities = torch.zeros(*volume_size, device=device) + volume_densities[[slice(border_offset, -border_offset, 1)] * 3] = 1.0 + boundary_points, boundary_colors = [ + torch.cat(p, dim=0) for p in [boundary_points, boundary_colors] + ] + # color the volume voxels with the nearest boundary points' color + _, idx, _ = knn_points( + coord_grid.view(1, -1, 3), boundary_points.view(1, -1, 3) + ) + volume_colors = ( + boundary_colors[idx.view(-1)].view(*volume_size, 3).permute(3, 0, 1, 2) + ) + + elif shape == "sphere": + # set all voxels within a certain distance from the origin to be opaque + volume_densities = ( + coord_grid.norm(dim=-1) + <= 0.5 * volume_voxel_size * (volume_size[0] - border_offset) + ).float() + # color each voxel with the standard spherical color + volume_colors = ( + (torch.nn.functional.normalize(coord_grid, dim=-1) + 1.0) * 0.5 + ).permute(3, 0, 1, 2) + + else: + raise ValueError(shape) + + volume_voxel_size = torch.ones((batch_size, 1), device=device) * volume_voxel_size + volume_translation = volume_translation.expand(batch_size, 3) + volumes = Volumes( + densities=volume_densities[None, None].expand(batch_size, 1, *volume_size), +
features=volume_colors[None].expand(batch_size, 3, *volume_size), + voxel_size=volume_voxel_size, + volume_translation=volume_translation, + ) + + return volumes, volume_voxel_size, volume_translation + + +def init_cameras( + batch_size: int = 10, + image_size: Optional[Tuple[int, int]] = (50, 50), + ndc: bool = False, +): + """ + Initialize a batch of cameras whose extrinsics rotate the cameras around + the world's y axis. + Depending on whether we want an NDC-space (`ndc==True`) or a screen-space camera, + the camera's focal length and principal point are initialized accordingly: + For `ndc==False`, p0=focal_length=image_size/2. + For `ndc==True`, focal_length=1.0, p0=0.0. + The z-coordinate of the translation vector of each camera is fixed to 1.5. + """ + device = torch.device("cuda:0") + + # trivial rotations + R = init_uniform_y_rotations(batch_size).to(device) + + # move camera 1.5 m away from the scene center + T = torch.zeros((batch_size, 3), device=device) + T[:, 2] = 1.5 + + if ndc: + p0 = torch.zeros(batch_size, 2, device=device) + focal = torch.ones(batch_size, device=device) + else: + p0 = torch.ones(batch_size, 2, device=device) + p0[:, 0] *= image_size[1] * 0.5 + p0[:, 1] *= image_size[0] * 0.5 + focal = image_size[0] * torch.ones(batch_size, device=device) + + # convert to a Camera object + cameras = PerspectiveCameras(focal, p0, R=R, T=T, device=device) + return cameras + + +class TestRenderVolumes(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(42) + np.random.seed(42) + + @staticmethod + def renderer( + volume_size=(25, 25, 25), + batch_size=10, + shape="sphere", + raymarcher_type=EmissionAbsorptionRaymarcher, + n_rays_per_image=10, + n_pts_per_ray=10, + ): + # get the volumes + volumes = init_boundary_volume( + volume_size=volume_size, batch_size=batch_size, shape=shape + )[0] + + # init the mc raysampler + raysampler = MonteCarloRaysampler( + min_x=-1.0, + max_x=1.0, + min_y=-1.0, + max_y=1.0, + n_rays_per_image=n_rays_per_image, + n_pts_per_ray=n_pts_per_ray, + min_depth=0.1, + max_depth=2.0, + ).to(volumes.device) + + # get the raymarcher + raymarcher = raymarcher_type() + + renderer = VolumeRenderer( + raysampler=raysampler, raymarcher=raymarcher, sample_mode="bilinear" + ) + + # generate NDC camera extrinsics and intrinsics + cameras = init_cameras(batch_size, image_size=None, ndc=True) + + def run_renderer(): + renderer(cameras=cameras, volumes=volumes) + + return run_renderer + + def test_input_types(self, batch_size: int = 10): + """ + Check that ValueErrors are thrown where expected.
+ """ + # check the constructor + for bad_raysampler in (None, 5, []): + for bad_raymarcher in (None, 5, []): + with self.assertRaises(ValueError): + VolumeRenderer(raysampler=bad_raysampler, raymarcher=bad_raymarcher) + + raysampler = NDCGridRaysampler( + image_width=100, + image_height=100, + n_pts_per_ray=10, + min_depth=0.1, + max_depth=1.0, + ) + + # init a trivial renderer + renderer = VolumeRenderer( + raysampler=raysampler, raymarcher=EmissionAbsorptionRaymarcher() + ) + + # get cameras + cameras = init_cameras(batch_size=batch_size) + + # get volumes + volumes = init_boundary_volume(volume_size=(10, 10, 10), batch_size=batch_size)[ + 0 + ] + + # different batch sizes for cameras / volumes + with self.assertRaises(ValueError): + renderer(cameras=cameras, volumes=volumes[:-1]) + + # ray checks for VolumeSampler + volume_sampler = VolumeSampler(volumes=volumes) + n_rays = 100 + for bad_ray_bundle in ( + ( + torch.rand(batch_size, n_rays, 3), + torch.rand(batch_size, n_rays + 1, 3), + torch.rand(batch_size, n_rays, 10), + ), + ( + torch.rand(batch_size + 1, n_rays, 3), + torch.rand(batch_size, n_rays, 3), + torch.rand(batch_size, n_rays, 10), + ), + ( + torch.rand(batch_size, n_rays, 3), + torch.rand(batch_size, n_rays, 2), + torch.rand(batch_size, n_rays, 10), + ), + ( + torch.rand(batch_size, n_rays, 3), + torch.rand(batch_size, n_rays, 3), + torch.rand(batch_size, n_rays), + ), + ): + ray_bundle = RayBundle( + **dict( + zip( + ("origins", "directions", "lengths"), + [r.to(cameras.device) for r in bad_ray_bundle], + ) + ), + xys=None, + ) + with self.assertRaises(ValueError): + volume_sampler(ray_bundle) + + # check also explicitly the ray bundle validation function + with self.assertRaises(ValueError): + _validate_ray_bundle_variables(*bad_ray_bundle) + + def test_compare_with_pointclouds_renderer( + self, batch_size=11, volume_size=(30, 30, 30), image_size=200 + ): + """ + Generate a volume and its corresponding point cloud and check whether + PointsRenderer returns the same images as the corresponding VolumeRenderer. 
+ """ + + # generate NDC camera extrinsics and intrinsics + cameras = init_cameras( + batch_size, image_size=[image_size, image_size], ndc=True + ) + + # init the boundary volume + for shape in ("sphere", "cube"): + + if not DEBUG and shape == "cube": + # do not run numeric checks for the cube as the + # differences in rendering equations make the renders incomparable + continue + + # get rand offset of the volume + volume_translation = torch.randn(batch_size, 3) * 0.1 + # volume_translation[2] = 0.1 + volumes = init_boundary_volume( + volume_size=volume_size, + batch_size=batch_size, + shape=shape, + volume_translation=volume_translation, + )[0] + + # convert the volumes to a pointcloud + points = [] + points_features = [] + for densities_one, features_one, grid_one in zip( + volumes.densities(), + volumes.features(), + volumes.get_coord_grid(world_coordinates=True), + ): + opaque = densities_one.view(-1) > 1e-4 + points.append(grid_one.view(-1, 3)[opaque]) + points_features.append(features_one.reshape(3, -1).t()[opaque]) + pointclouds = Pointclouds(points, features=points_features) + + # init the grid raysampler with the ndc grid + coord_range = 1.0 + half_pix_size = coord_range / image_size + raysampler = NDCGridRaysampler( + image_width=image_size, + image_height=image_size, + n_pts_per_ray=256, + min_depth=0.1, + max_depth=2.0, + ) + + # get the EA raymarcher + raymarcher = EmissionAbsorptionRaymarcher() + + # jitter the camera intrinsics a bit for each render + cameras_randomized = cameras.clone() + cameras_randomized.principal_point = ( + torch.randn_like(cameras.principal_point) * 0.3 + ) + cameras_randomized.focal_length = ( + cameras.focal_length + torch.randn_like(cameras.focal_length) * 0.2 + ) + + # get the volumetric render + images = VolumeRenderer( + raysampler=raysampler, raymarcher=raymarcher, sample_mode="bilinear" + )(cameras=cameras_randomized, volumes=volumes)[0][..., :3] + + # instantiate the points renderer + point_radius = 6 * half_pix_size + points_renderer = PointsRenderer( + rasterizer=PointsRasterizer( + cameras=cameras_randomized, + raster_settings=PointsRasterizationSettings( + image_size=image_size, radius=point_radius, points_per_pixel=10 + ), + ), + compositor=AlphaCompositor(), + ) + + # get the point render + images_pts = points_renderer(pointclouds) + + if shape == "sphere": + diff = (images - images_pts).abs().mean(dim=-1) + mu_diff = diff.mean(dim=(1, 2)) + std_diff = diff.std(dim=(1, 2)) + self.assertClose(mu_diff, torch.zeros_like(mu_diff), atol=3e-2) + self.assertClose(std_diff, torch.zeros_like(std_diff), atol=6e-2) + + if DEBUG: + outdir = tempfile.gettempdir() + "/test_volume_vs_pts_renderer" + os.makedirs(outdir, exist_ok=True) + + frames = [] + for (image, image_pts) in zip(images, images_pts): + diff_image = ( + ((image - image_pts) * 0.5 + 0.5) + .mean(dim=2, keepdim=True) + .repeat(1, 1, 3) + ) + image_pil = Image.fromarray( + ( + torch.cat((image, image_pts, diff_image), dim=1) + .detach() + .cpu() + .numpy() + * 255.0 + ).astype(np.uint8) + ) + frames.append(image_pil) + + # export gif + outfile = os.path.join(outdir, f"volume_vs_pts_render_{shape}.gif") + frames[0].save( + outfile, + save_all=True, + append_images=frames[1:], + duration=batch_size // 15, + loop=0, + ) + print(f"exported {outfile}") + + # export concatenated frames + outfile_cat = os.path.join(outdir, f"volume_vs_pts_render_{shape}.png") + Image.fromarray( + np.concatenate([np.array(f) for f in frames], axis=0) + ).save(outfile_cat) + print(f"exported {outfile_cat}") + + 
def test_monte_carlo_rendering( + self, n_frames=20, volume_size=(30, 30, 30), image_size=(40, 50) + ): + """ + Tests that rendering with the MonteCarloRaysampler matches the + rendering with GridRaysampler sampled at the corresponding + Monte Carlo locations. + """ + volumes = init_boundary_volume( + volume_size=volume_size, batch_size=n_frames, shape="sphere" + )[0] + + # generate camera extrinsics and intrinsics + cameras = init_cameras(n_frames, image_size=image_size) + + # init the grid raysampler + raysampler_grid = GridRaysampler( + min_x=0.5, + max_x=image_size[1] - 0.5, + min_y=0.5, + max_y=image_size[0] - 0.5, + image_width=image_size[1], + image_height=image_size[0], + n_pts_per_ray=256, + min_depth=0.5, + max_depth=2.0, + ) + + # init the mc raysampler + raysampler_mc = MonteCarloRaysampler( + min_x=0.5, + max_x=image_size[1] - 0.5, + min_y=0.5, + max_y=image_size[0] - 0.5, + n_rays_per_image=3000, + n_pts_per_ray=256, + min_depth=0.5, + max_depth=2.0, + ) + + # get the EA raymarcher + raymarcher = EmissionAbsorptionRaymarcher() + + # get both mc and grid renders + ( + (images_opacities_mc, ray_bundle_mc), + (images_opacities_grid, ray_bundle_grid), + ) = [ + VolumeRenderer( + raysampler=raysampler, + raymarcher=raymarcher, + sample_mode="bilinear", + )(cameras=cameras, volumes=volumes) + for raysampler in (raysampler_mc, raysampler_grid) + ] + + # convert the mc sampling locations to [-1, 1] + sample_loc = ray_bundle_mc.xys.clone() + sample_loc[..., 0] = 2 * (sample_loc[..., 0] / image_size[1]) - 1 + sample_loc[..., 1] = 2 * (sample_loc[..., 1] / image_size[0]) - 1 + + # sample the grid render at the mc locations + images_opacities_mc_ = torch.nn.functional.grid_sample( + images_opacities_grid.permute(0, 3, 1, 2), sample_loc, align_corners=False + ) + + # check that the samples are the same + self.assertClose( + images_opacities_mc.permute(0, 3, 1, 2), images_opacities_mc_, atol=1e-4 + ) + + def test_rotating_gif( + self, n_frames=50, fps=15, volume_size=(100, 100, 100), image_size=(100, 100) + ): + """ + Render a gif animation of a rotating cube/sphere (runs only if `DEBUG==True`).
+ """ + + if not DEBUG: + # do not run this if debug is False + return + + for shape in ("sphere", "cube"): + for sample_mode in ("bilinear", "nearest"): + + volumes = init_boundary_volume( + volume_size=volume_size, batch_size=n_frames, shape=shape + )[0] + + # generate camera extrinsics and intrinsics + cameras = init_cameras(n_frames, image_size=image_size) + + # init the grid raysampler + raysampler = GridRaysampler( + min_x=0.5, + max_x=image_size[1] - 0.5, + min_y=0.5, + max_y=image_size[0] - 0.5, + image_width=image_size[1], + image_height=image_size[0], + n_pts_per_ray=256, + min_depth=0.5, + max_depth=2.0, + ) + + # get the EA raymarcher + raymarcher = EmissionAbsorptionRaymarcher() + + # intialize the renderer + renderer = VolumeRenderer( + raysampler=raysampler, + raymarcher=raymarcher, + sample_mode=sample_mode, + ) + + # run the renderer + images_opacities = renderer(cameras=cameras, volumes=volumes)[0] + + # split output to the alpha channel and rendered images + images, opacities = images_opacities[..., :3], images_opacities[..., 3] + + # export the gif + outdir = tempfile.gettempdir() + "/test_volume_renderer_gifs" + os.makedirs(outdir, exist_ok=True) + frames = [] + for image, opacity in zip(images, opacities): + image_pil = Image.fromarray( + ( + torch.cat( + (image, opacity[..., None].repeat(1, 1, 3)), dim=1 + ) + .detach() + .cpu() + .numpy() + * 255.0 + ).astype(np.uint8) + ) + frames.append(image_pil) + outfile = os.path.join(outdir, f"{shape}_{sample_mode}.gif") + frames[0].save( + outfile, + save_all=True, + append_images=frames[1:], + duration=n_frames // fps, + loop=0, + ) + print(f"exported {outfile}") + + def test_rotating_cube_volume_render(self): + """ + Generates 4 renders of 4 sides of a volume representing a 3D cube. + Since each side of the cube is homogenously colored with + a different color, this should result in 4 images of homogenous color + with the depth of each pixel equal to a constant. 
+ """ + + # batch_size = 4 sides of the cube + batch_size = 4 + image_size = (50, 50) + + for volume_size in ([25, 25, 25],): + for sample_mode in ("bilinear", "nearest"): + + volume_translation = torch.zeros(4, 3) + volume_translation.requires_grad = True + volumes, volume_voxel_size, _ = init_boundary_volume( + volume_size=volume_size, + batch_size=batch_size, + shape="cube", + volume_translation=volume_translation, + ) + + # generate camera extrinsics and intrinsics + cameras = init_cameras(batch_size, image_size=image_size) + + # enable the gradient caching for the camera variables + # the list of differentiable camera vars + cam_vars = ("R", "T", "focal_length", "principal_point") + for cam_var in cam_vars: + getattr(cameras, cam_var).requires_grad = True + # enable the grad for volume vars as well + volumes.features().requires_grad = True + volumes.densities().requires_grad = True + + raysampler = GridRaysampler( + min_x=0.5, + max_x=image_size[1] - 0.5, + min_y=0.5, + max_y=image_size[0] - 0.5, + image_width=image_size[1], + image_height=image_size[0], + n_pts_per_ray=128, + min_depth=0.01, + max_depth=3.0, + ) + + raymarcher = EmissionAbsorptionRaymarcher() + renderer = VolumeRenderer( + raysampler=raysampler, + raymarcher=raymarcher, + sample_mode=sample_mode, + ) + images_opacities = renderer(cameras=cameras, volumes=volumes)[0] + images, opacities = images_opacities[..., :3], images_opacities[..., 3] + + # check that the renderer does not erase gradients + loss = images_opacities.sum() + loss.backward() + for check_var in ( + *[getattr(cameras, cam_var) for cam_var in cam_vars], + volumes.features(), + volumes.densities(), + volume_translation, + ): + self.assertIsNotNone(check_var.grad) + + # ao opacities should be exactly the same as the ea ones + # we can further get the ea opacities from a feature-less + # version of our volumes + raymarcher_ao = AbsorptionOnlyRaymarcher() + renderer_ao = VolumeRenderer( + raysampler=raysampler, + raymarcher=raymarcher_ao, + sample_mode=sample_mode, + ) + volumes_featureless = Volumes( + densities=volumes.densities(), + volume_translation=volume_translation, + voxel_size=volume_voxel_size, + ) + opacities_ao = renderer_ao( + cameras=cameras, volumes=volumes_featureless + )[0][..., 0] + self.assertClose(opacities, opacities_ao) + + # colors of the sides of the cube + gt_clr_sides = torch.tensor( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [0.0, 1.0, 0.0], + ], + dtype=torch.float32, + device=images.device, + ) + + if DEBUG: + outdir = tempfile.gettempdir() + "/test_volume_renderer" + os.makedirs(outdir, exist_ok=True) + for imidx, (image, opacity) in enumerate(zip(images, opacities)): + for image_ in (image, opacity): + image_pil = Image.fromarray( + (image_.detach().cpu().numpy() * 255.0).astype(np.uint8) + ) + outfile = ( + outdir + + f"/rgb_{sample_mode}" + + f"_{str(volume_size).replace(' ','')}" + + f"_{imidx:003d}" + ) + if image_ is image: + outfile += "_rgb.png" + else: + outfile += "_opacity.png" + image_pil.save(outfile) + print(f"exported {outfile}") + + border = 10 + for image, opacity, gt_color in zip(images, opacities, gt_clr_sides): + image_crop = image[border:-border, border:-border] + opacity_crop = opacity[border:-border, border:-border] + + # check mean and std difference from gt + err = ( + (image_crop - gt_color[None, None].expand_as(image_crop)) + .abs() + .mean(dim=-1) + ) + zero = err.new_zeros(1)[0] + self.assertClose(err.mean(), zero, atol=1e-2) + self.assertClose(err.std(), zero, atol=1e-2) + + 
err_opacity = (opacity_crop - 1.0).abs() + self.assertClose(err_opacity.mean(), zero, atol=1e-2) + self.assertClose(err_opacity.std(), zero, atol=1e-2)
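As a usage reference, here is a minimal sketch of how the classes added in this diff compose end to end, built only from the constructors and call signatures exercised in the tests above. The camera pose, volume contents, and resolutions are illustrative assumptions, not values from the patch:

```
import torch
from pytorch3d.renderer import (
    EmissionAbsorptionRaymarcher,
    NDCGridRaysampler,
    PerspectiveCameras,
    VolumeRenderer,
)
from pytorch3d.structures import Volumes

# a single camera 2 units away from the origin (default NDC intrinsics)
cameras = PerspectiveCameras(R=torch.eye(3)[None], T=torch.tensor([[0.0, 0.0, 2.0]]))

# a 32**3 volume: a box of constant density filled with constant gray features
side = 32
densities = torch.zeros(1, 1, side, side, side)
densities[..., 8:-8, 8:-8, 8:-8] = 0.1
features = 0.7 * torch.ones(1, 3, side, side, side)
volumes = Volumes(densities=densities, features=features, voxel_size=1.0 / (side - 1))

# compose the renderer from a raysampler and a raymarcher
renderer = VolumeRenderer(
    raysampler=NDCGridRaysampler(
        image_width=64,
        image_height=64,
        n_pts_per_ray=64,
        min_depth=0.5,
        max_depth=3.5,
    ),
    raymarcher=EmissionAbsorptionRaymarcher(),
)

# images has shape (1, 64, 64, 4): 3 feature channels + 1 opacity channel
images, ray_bundle = renderer(cameras=cameras, volumes=volumes)
rgb, alpha = images[..., :3], images[..., 3:]
```

Internally this is exactly the path tested above: `VolumeRenderer` wraps the volumes in a `VolumeSampler` and delegates to `ImplicitRenderer`, so the same raysampler/raymarcher pair can be reused with a custom `volumetric_function` instead of a `Volumes` batch.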