diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index f383f6f7..9ae62aa0 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -20,6 +20,7 @@ from .cameras import ( look_at_rotation, look_at_view_transform, ) +from .implicit import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher from .lighting import DirectionalLights, PointLights, diffuse, specular from .materials import Materials from .mesh import ( diff --git a/pytorch3d/renderer/implicit/__init__.py b/pytorch3d/renderer/implicit/__init__.py new file mode 100644 index 00000000..f5da3e2f --- /dev/null +++ b/pytorch3d/renderer/implicit/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher + + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/renderer/implicit/raymarching.py b/pytorch3d/renderer/implicit/raymarching.py new file mode 100644 index 00000000..44308226 --- /dev/null +++ b/pytorch3d/renderer/implicit/raymarching.py @@ -0,0 +1,223 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import warnings +from typing import Optional, Tuple, Union + +import torch + + +class EmissionAbsorptionRaymarcher(torch.nn.Module): + """ + Raymarch using the Emission-Absorption (EA) algorithm. + + The algorithm independently renders each ray by analyzing density and + feature values sampled at (typically uniformly) spaced 3D locations along + each ray. The density values `rays_densities` are of shape + `(..., n_points_per_ray)`, their values should range between [0, 1], and + represent the opaqueness of each point (the higher the less transparent). + The feature values `rays_features` of shape + `(..., n_points_per_ray, feature_dim)` represent the content of the + point that is supposed to be rendered in case the given point is opaque + (i.e. its density -> 1.0). + + EA first utilizes `rays_densities` to compute the absorption function + along each ray as follows: + ``` + absorption = cumprod(1 - rays_densities, dim=-1) + ``` + The value of absorption at position `absorption[..., k]` specifies + how much light has reached `k`-th point along a ray since starting + its trajectory at `k=0`-th point. + + Each ray is then rendered into a tensor `features` of shape `(..., feature_dim)` + by taking a weighed combination of per-ray features `rays_features` as follows: + ``` + weights = absorption * rays_densities + features = (rays_features * weights).sum(dim=-2) + ``` + Where `weights` denote a function that has a strong peak around the location + of the first surface point that a given ray passes through. + + Note that for a perfectly bounded volume (with a strictly binary density), + the `weights = cumprod(1 - rays_densities, dim=-1) * rays_densities` + function would yield 0 everywhere. In order to prevent this, + the result of the cumulative product is shifted `self.surface_thickness` + elements along the ray direction. + """ + + def __init__(self, surface_thickness: int = 1): + """ + Args: + surface_thickness: Denotes the overlap between the absorption + function and the density function. + """ + super().__init__() + self.surface_thickness = surface_thickness + + def forward( + self, + rays_densities: torch.Tensor, + rays_features: torch.Tensor, + eps: float = 1e-10, + **kwargs, + ) -> torch.Tensor: + """ + Args: + rays_densities: Per-ray density values represented with a tensor + of shape `(..., n_points_per_ray, 1)` whose values range in [0, 1]. + rays_features: Per-ray feature values represented with a tensor + of shape `(..., n_points_per_ray, feature_dim)`. + eps: A lower bound added to `rays_densities` before computing + the absorbtion function (cumprod of `1-rays_densities` along + each ray). This prevents the cumprod to yield exact 0 + which would inhibit any gradient-based learning. + + Returns: + features_opacities: A tensor of shape `(..., feature_dim+1)` + that concatenates two tensors alonng the last dimension: + 1) features: A tensor of per-ray renders + of shape `(..., feature_dim)`. + 2) opacities: A tensor of per-ray opacity values + of shape `(..., 1)`. Its values range between [0, 1] and + denote the total amount of light that has been absorbed + for each ray. E.g. a value of 0 corresponds to the ray + completely passing through a volume. Please refer to the + `AbsorptionOnlyRaymarcher` documentation for the + explanation of the algorithm that computes `opacities`. + """ + _check_raymarcher_inputs( + rays_densities, + rays_features, + None, + z_can_be_none=True, + features_can_be_none=False, + density_1d=True, + ) + _check_density_bounds(rays_densities) + rays_densities = rays_densities[..., 0] + absorption = _shifted_cumprod( + (1.0 + eps) - rays_densities, shift=self.surface_thickness + ) + weights = rays_densities * absorption + features = (weights[..., None] * rays_features).sum(dim=-2) + opacities = 1.0 - torch.prod(1.0 - rays_densities, dim=-1, keepdim=True) + + return torch.cat((features, opacities), dim=-1) + + +class AbsorptionOnlyRaymarcher(torch.nn.Module): + """ + Raymarch using the Absorption-Only (AO) algorithm. + + The algorithm independently renders each ray by analyzing density and + feature values sampled at (typically uniformly) spaced 3D locations along + each ray. The density values `rays_densities` are of shape + `(..., n_points_per_ray, 1)`, their values should range between [0, 1], and + represent the opaqueness of each point (the higher the less transparent). + The algorithm only measures the total amount of light absorbed along each ray + and, besides outputting per-ray `opacity` values of shape `(...,)`, + does not produce any feature renderings. + + The algorithm simply computes `total_transmission = prod(1 - rays_densities)` + of shape `(..., 1)` which, for each ray, measures the total amount of light + that passed through the volume. + It then returns `opacities = 1 - total_transmission`. + """ + + def __init__(self): + super().__init__() + + def forward( + self, rays_densities: torch.Tensor, **kwargs + ) -> Union[None, torch.Tensor]: + """ + Args: + rays_densities: Per-ray density values represented with a tensor + of shape `(..., n_points_per_ray)` whose values range in [0, 1]. + + Returns: + opacities: A tensor of per-ray opacity values of shape `(..., 1)`. + Its values range between [0, 1] and denote the total amount + of light that has been absorbed for each ray. E.g. a value + of 0 corresponds to the ray completely passing through a volume. + """ + + _check_raymarcher_inputs( + rays_densities, + None, + None, + features_can_be_none=True, + z_can_be_none=True, + density_1d=True, + ) + rays_densities = rays_densities[..., 0] + _check_density_bounds(rays_densities) + total_transmission = torch.prod(1 - rays_densities, dim=-1, keepdim=True) + opacities = 1.0 - total_transmission + return opacities + + +def _shifted_cumprod(x, shift=1): + """ + Computes `torch.cumprod(x, dim=-1)` and prepends `shift` number of + ones and removes `shift` trailing elements to/from the last dimension + of the result. + """ + x_cumprod = torch.cumprod(x, dim=-1) + x_cumprod_shift = torch.cat( + [torch.ones_like(x_cumprod[..., :shift]), x_cumprod[..., :-shift]], dim=-1 + ) + return x_cumprod_shift + + +def _check_density_bounds( + rays_densities: torch.Tensor, bounds: Tuple[float, float] = (0.0, 1.0) +): + """ + Checks whether the elements of `rays_densities` range within `bounds`. + If not issues a warning. + """ + if ((rays_densities > bounds[1]) | (rays_densities < bounds[0])).any(): + warnings.warn( + "One or more elements of rays_densities are outside of valid" + + f"range {str(bounds)}" + ) + + +def _check_raymarcher_inputs( + rays_densities: torch.Tensor, + rays_features: Optional[torch.Tensor], + rays_z: Optional[torch.Tensor], + features_can_be_none: bool = False, + z_can_be_none: bool = False, + density_1d: bool = True, +): + """ + Checks the validity of the inputs to raymarching algorithms. + """ + if not torch.is_tensor(rays_densities): + raise ValueError("rays_densities has to be an instance of torch.Tensor.") + + if not z_can_be_none and not torch.is_tensor(rays_z): + raise ValueError("rays_z has to be an instance of torch.Tensor.") + + if not features_can_be_none and not torch.is_tensor(rays_features): + raise ValueError("rays_features has to be an instance of torch.Tensor.") + + if rays_densities.ndim < 1: + raise ValueError("rays_densities have to have at least one dimension.") + + if density_1d and rays_densities.shape[-1] != 1: + raise ValueError( + "The size of the last dimension of rays_densities has to be one." + ) + + rays_shape = rays_densities.shape[:-1] + + if not z_can_be_none and rays_z.shape != rays_shape: + raise ValueError("rays_z have to be of the same shape as rays_densities.") + + if not features_can_be_none and rays_features.shape[:-1] != rays_shape: + raise ValueError( + "The first to previous to last dimensions of rays_features" + " have to be the same as all dimensions of rays_densities." + ) diff --git a/tests/bm_raymarching.py b/tests/bm_raymarching.py new file mode 100644 index 00000000..7e6ad923 --- /dev/null +++ b/tests/bm_raymarching.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import itertools + +from fvcore.common.benchmark import benchmark +from pytorch3d.renderer import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher +from test_raymarching import TestRaymarching + + +def bm_raymarching() -> None: + case_grid = { + "raymarcher_type": [EmissionAbsorptionRaymarcher, AbsorptionOnlyRaymarcher], + "n_rays": [10, 1000, 10000], + "n_pts_per_ray": [10, 1000, 10000], + } + test_cases = itertools.product(*case_grid.values()) + kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases] + + benchmark(TestRaymarching.raymarcher, "RAYMARCHER", kwargs_list, warmup_iters=1) diff --git a/tests/test_raymarching.py b/tests/test_raymarching.py new file mode 100644 index 00000000..9e3a64bc --- /dev/null +++ b/tests/test_raymarching.py @@ -0,0 +1,196 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import unittest + +import torch +from common_testing import TestCaseMixin +from pytorch3d.renderer import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher + + +class TestRaymarching(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(42) + + @staticmethod + def _init_random_rays( + n_rays=10, n_pts_per_ray=9, device="cuda", dtype=torch.float32 + ): + """ + Generate a batch of ray points with features, densities, and z-coodinates + such that their EmissionAbsorption renderring results in + feature renders `features_gt`, depth renders `depths_gt`, + and opacity renders `opacities_gt`. + """ + + # generate trivial ray z-coordinates of sampled points coinciding with + # each point's order along a ray. + rays_z = torch.arange(n_pts_per_ray, dtype=dtype, device=device)[None].repeat( + n_rays, 1 + ) + + # generate ground truth depth values of the underlying surface. + depths_gt = torch.randint( + low=1, high=n_pts_per_ray + 2, size=(n_rays,) + ).type_as(rays_z) + + # compute ideal densities that are 0 before the surface and 1 after + # the corresponding ground truth depth value + rays_densities = (rays_z >= depths_gt[..., None]).type_as(rays_z)[..., None] + opacities_gt = (depths_gt < n_pts_per_ray).type_as(rays_z) + + # generate random per-ray features + rays_features = torch.rand( + (n_rays, n_pts_per_ray, 3), device=rays_z.device, dtype=rays_z.dtype + ) + + # infer the expected feature render "features_gt" + gt_surface = ((rays_z - depths_gt[..., None]).abs() <= 1e-4).type_as(rays_z) + features_gt = (rays_features * gt_surface[..., None]).sum(dim=-2) + + return ( + rays_z, + rays_densities, + rays_features, + depths_gt, + features_gt, + opacities_gt, + ) + + @staticmethod + def raymarcher( + raymarcher_type=EmissionAbsorptionRaymarcher, n_rays=10, n_pts_per_ray=10 + ): + ( + rays_z, + rays_densities, + rays_features, + depths_gt, + features_gt, + opacities_gt, + ) = TestRaymarching._init_random_rays( + n_rays=n_rays, n_pts_per_ray=n_pts_per_ray + ) + + raymarcher = raymarcher_type() + + def run_raymarcher(): + raymarcher( + rays_densities=rays_densities, + rays_features=rays_features, + rays_z=rays_z, + ) + torch.cuda.synchronize() + + return run_raymarcher + + def test_emission_absorption_inputs(self): + """ + Test the checks of validity of the inputs to `EmissionAbsorptionRaymarcher`. + """ + + # init the EA raymarcher + raymarcher_ea = EmissionAbsorptionRaymarcher() + + # bad ways of passing densities and features + # [rays_densities, rays_features, rays_z] + bad_inputs = [ + [torch.rand(10, 5, 4), None], + [torch.Tensor(3)[0], torch.rand(10, 5, 4)], + [1.0, torch.rand(10, 5, 4)], + [torch.rand(10, 5, 4), 1.0], + [torch.rand(10, 5, 4), None], + [torch.rand(10, 5, 4), torch.rand(10, 5, 4)], + [torch.rand(10, 5, 4), torch.rand(10, 5, 4, 3)], + [torch.rand(10, 5, 4, 3), torch.rand(10, 5, 4, 3)], + ] + + for bad_input in bad_inputs: + with self.assertRaises(ValueError): + raymarcher_ea(*bad_input) + + def test_absorption_only_inputs(self): + """ + Test the checks of validity of the inputs to `AbsorptionOnlyRaymarcher`. + """ + + # init the AO raymarcher + raymarcher_ao = AbsorptionOnlyRaymarcher() + + # bad ways of passing densities and features + # [rays_densities, rays_features, rays_z] + bad_inputs = [[torch.Tensor(3)[0]]] + + for bad_input in bad_inputs: + with self.assertRaises(ValueError): + raymarcher_ao(*bad_input) + + def test_emission_absorption(self): + """ + Test the EA raymarching algorithm. + """ + ( + rays_z, + rays_densities, + rays_features, + depths_gt, + features_gt, + opacities_gt, + ) = TestRaymarching._init_random_rays( + n_rays=1000, n_pts_per_ray=9, device=None, dtype=torch.float32 + ) + + # init the EA raymarcher + raymarcher_ea = EmissionAbsorptionRaymarcher() + + # allow gradients for a differentiability check + rays_densities.requires_grad = True + rays_features.requires_grad = True + + # render the features first and check with gt + data_render = raymarcher_ea(rays_densities, rays_features) + features_render, opacities_render = data_render[..., :-1], data_render[..., -1] + self.assertClose(opacities_render, opacities_gt) + self.assertClose( + features_render * opacities_render[..., None], + features_gt * opacities_gt[..., None], + ) + + # get the depth map by rendering the ray z components and check with gt + depths_render = raymarcher_ea(rays_densities, rays_z[..., None])[..., 0] + self.assertClose(depths_render * opacities_render, depths_gt * opacities_gt) + + # check differentiability + loss = features_render.mean() + loss.backward() + for field in (rays_densities, rays_features): + self.assertTrue(field.grad.data.isfinite().all()) + + def test_absorption_only(self): + """ + Test the AO raymarching algorithm. + """ + ( + rays_z, + rays_densities, + rays_features, + depths_gt, + features_gt, + opacities_gt, + ) = TestRaymarching._init_random_rays( + n_rays=1000, n_pts_per_ray=9, device=None, dtype=torch.float32 + ) + + # init the AO raymarcher + raymarcher_ao = AbsorptionOnlyRaymarcher() + + # allow gradients for a differentiability check + rays_densities.requires_grad = True + + # render opacities, check with gt and check that returned features are None + opacities_render = raymarcher_ao(rays_densities)[..., 0] + self.assertClose(opacities_render, opacities_gt) + + # check differentiability + loss = opacities_render.mean() + loss.backward() + self.assertTrue(rays_densities.grad.data.isfinite().all())