From e6bc960fb565cd3a8bbc26edcd1109be4b0856a2 Mon Sep 17 00:00:00 2001 From: David Novotny Date: Wed, 6 Jan 2021 03:59:56 -0800 Subject: [PATCH] Raysampling Summary: Implements 3 basic raysamplers. Reviewed By: nikhilaravi Differential Revision: D24110643 fbshipit-source-id: eb67d0e56773c7871ebdcb23e7e520302dc1b3c9 --- pytorch3d/renderer/__init__.py | 11 +- pytorch3d/renderer/implicit/__init__.py | 6 + pytorch3d/renderer/implicit/raysampling.py | 320 ++++++++++++++++ pytorch3d/renderer/implicit/utils.py | 82 ++++ tests/bm_raysampling.py | 39 ++ tests/test_raysampling.py | 423 +++++++++++++++++++++ 6 files changed, 880 insertions(+), 1 deletion(-) create mode 100644 pytorch3d/renderer/implicit/raysampling.py create mode 100644 pytorch3d/renderer/implicit/utils.py create mode 100644 tests/bm_raysampling.py create mode 100644 tests/test_raysampling.py diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index 9ae62aa0..8a3eb1ee 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -20,7 +20,16 @@ from .cameras import ( look_at_rotation, look_at_view_transform, ) -from .implicit import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher +from .implicit import ( + AbsorptionOnlyRaymarcher, + EmissionAbsorptionRaymarcher, + GridRaysampler, + MonteCarloRaysampler, + NDCGridRaysampler, + RayBundle, + ray_bundle_to_ray_points, + ray_bundle_variables_to_ray_points, +) from .lighting import DirectionalLights, PointLights, diffuse, specular from .materials import Materials from .mesh import ( diff --git a/pytorch3d/renderer/implicit/__init__.py b/pytorch3d/renderer/implicit/__init__.py index f5da3e2f..ec245b01 100644 --- a/pytorch3d/renderer/implicit/__init__.py +++ b/pytorch3d/renderer/implicit/__init__.py @@ -1,6 +1,12 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher +from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler +from .utils import ( + RayBundle, + ray_bundle_to_ray_points, + ray_bundle_variables_to_ray_points, +) __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py new file mode 100644 index 00000000..9536a42c --- /dev/null +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -0,0 +1,320 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import torch + +from ..cameras import CamerasBase +from .utils import RayBundle + + +""" +This file defines three raysampling techniques: + - GridRaysampler which can be used to sample rays from pixels of an image grid + - NDCGridRaysampler which can be used to sample rays from pixels of an image grid, + which follows the pytorch3d convention for image grid coordinates + - MonteCarloRaysampler which randomly selects image pixels and emits rays from them +""" + + +class GridRaysampler(torch.nn.Module): + """ + Samples a fixed number of points along rays which are regulary distributed + in a batch of rectangular image grids. Points along each ray + have uniformly-spaced z-coordinates between a predefined + minimum and maximum depth. + + The raysampler first generates a 3D coordinate grid of the following form: + ``` + / min_x, min_y, max_depth -------------- / max_x, min_y, max_depth + / /| + / / | ^ + / min_depth min_depth / | | + min_x ----------------------------- max_x | | image + min_y min_y | | height + | | | | + | | | v + | | | + | | / max_x, max_y, ^ + | | / max_depth / + min_x max_y / / n_pts_per_ray + max_y ----------------------------- max_x/ min_depth v + < --- image_width --- > + ``` + + In order to generate ray points, `GridRaysampler` takes each 3D point of + the grid (with coordinates `[x, y, depth]`) and unprojects it + with `cameras.unproject_points([x, y, depth])`, where `cameras` are an + additional input to the `forward` function. + + Note that this is a generic implementation that can support any image grid + coordinate convention. For a raysampler which follows the PyTorch3D + coordinate conventions please refer to `NDCGridRaysampler`. + As such, `NDCGridRaysampler` is a special case of `GridRaysampler`. + """ + + def __init__( + self, + min_x: float, + max_x: float, + min_y: float, + max_y: float, + image_width: int, + image_height: int, + n_pts_per_ray: int, + min_depth: float, + max_depth: float, + ): + """ + Args: + min_x: The leftmost x-coordinate of each ray's source pixel's center. + max_x: The rightmost x-coordinate of each ray's source pixel's center. + min_y: The topmost y-coordinate of each ray's source pixel's center. + max_y: The bottommost y-coordinate of each ray's source pixel's center. + image_width: The horizontal size of the image grid. + image_height: The vertical size of the image grid. + n_pts_per_ray: The number of points sampled along each ray. + min_depth: The minimum depth of a ray-point. + max_depth: The maximum depth of a ray-point. + """ + super().__init__() + self._n_pts_per_ray = n_pts_per_ray + self._min_depth = min_depth + self._max_depth = max_depth + + # get the initial grid of image xy coords + _xy_grid = torch.stack( + tuple( + reversed( + torch.meshgrid( + torch.linspace(min_y, max_y, image_height, dtype=torch.float32), + torch.linspace(min_x, max_x, image_width, dtype=torch.float32), + ) + ) + ), + dim=-1, + ) + self.register_buffer("_xy_grid", _xy_grid) + + def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle: + """ + Args: + cameras: A batch of `batch_size` cameras from which the rays are emitted. + Returns: + A named tuple RayBundle with the following fields: + origins: A tensor of shape + `(batch_size, image_height, image_width, 3)` + denoting the locations of ray origins in the world coordinates. + directions: A tensor of shape + `(batch_size, image_height, image_width, 3)` + denoting the directions of each ray in the world coordinates. + lengths: A tensor of shape + `(batch_size, image_height, image_width, n_pts_per_ray)` + containing the z-coordinate (=depth) of each ray in world units. + xys: A tensor of shape + `(batch_size, image_height, image_width, 2)` + containing the 2D image coordinates of each ray. + """ + + batch_size = cameras.R.shape[0] # pyre-ignore + + device = cameras.device + + # expand the (H, W, 2) grid batch_size-times to (B, H, W, 2) + xy_grid = self._xy_grid.to(device)[None].expand( # pyre-ignore + batch_size, *self._xy_grid.shape + ) + + return _xy_to_ray_bundle( + cameras, xy_grid, self._min_depth, self._max_depth, self._n_pts_per_ray + ) + + +class NDCGridRaysampler(GridRaysampler): + """ + Samples a fixed number of points along rays which are regulary distributed + in a batch of rectangular image grids. Points along each ray + have uniformly-spaced z-coordinates between a predefined minimum and maximum depth. + + `NDCGridRaysampler` follows the screen conventions of the `Meshes` and `Pointclouds` + renderers. I.e. the border of the leftmost / rightmost / topmost / bottommost pixel + has coordinates 1.0 / -1.0 / 1.0 / -1.0 respectively. + """ + + def __init__( + self, + image_width: int, + image_height: int, + n_pts_per_ray: int, + min_depth: float, + max_depth: float, + ): + """ + Args: + image_width: The horizontal size of the image grid. + image_height: The vertical size of the image grid. + n_pts_per_ray: The number of points sampled along each ray. + min_depth: The minimum depth of a ray-point. + max_depth: The maximum depth of a ray-point. + """ + half_pix_width = 1.0 / image_width + half_pix_height = 1.0 / image_height + super().__init__( + min_x=1.0 - half_pix_width, + max_x=-1.0 + half_pix_width, + min_y=1.0 - half_pix_height, + max_y=-1.0 + half_pix_height, + image_width=image_width, + image_height=image_height, + n_pts_per_ray=n_pts_per_ray, + min_depth=min_depth, + max_depth=max_depth, + ) + + +class MonteCarloRaysampler(torch.nn.Module): + """ + Samples a fixed number of pixels within denoted xy bounds uniformly at random. + For each pixel, a fixed number of points is sampled along its ray at uniformly-spaced + z-coordinates such that the z-coordinates range between a predefined minimum + and maximum depth. + """ + + def __init__( + self, + min_x: float, + max_x: float, + min_y: float, + max_y: float, + n_rays_per_image: int, + n_pts_per_ray: int, + min_depth: float, + max_depth: float, + ): + """ + Args: + min_x: The smallest x-coordinate of each ray's source pixel. + max_x: The largest x-coordinate of each ray's source pixel. + min_y: The smallest y-coordinate of each ray's source pixel. + max_y: The largest y-coordinate of each ray's source pixel. + n_rays_per_image: The number of rays randomly sampled in each camera. + n_pts_per_ray: The number of points sampled along each ray. + min_depth: The minimum depth of each ray-point. + max_depth: The maximum depth of each ray-point. + """ + super().__init__() + self._min_x = min_x + self._max_x = max_x + self._min_y = min_y + self._max_y = max_y + self._n_rays_per_image = n_rays_per_image + self._n_pts_per_ray = n_pts_per_ray + self._min_depth = min_depth + self._max_depth = max_depth + + def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle: + """ + Args: + cameras: A batch of `batch_size` cameras from which the rays are emitted. + Returns: + A named tuple RayBundle with the following fields: + origins: A tensor of shape + `(batch_size, n_rays_per_image, 3)` + denoting the locations of ray origins in the world coordinates. + directions: A tensor of shape + `(batch_size, n_rays_per_image, 3)` + denoting the directions of each ray in the world coordinates. + lengths: A tensor of shape + `(batch_size, n_rays_per_image, n_pts_per_ray)` + containing the z-coordinate (=depth) of each ray in world units. + xys: A tensor of shape + `(batch_size, n_rays_per_image, 2)` + containing the 2D image coordinates of each ray. + """ + + batch_size = cameras.R.shape[0] # pyre-ignore + + device = cameras.device + + # get the initial grid of image xy coords + # of shape (batch_size, n_rays_per_image, 2) + rays_xy = torch.cat( + [ + torch.rand( + size=(batch_size, self._n_rays_per_image, 1), + dtype=torch.float32, + device=device, + ) + * (high - low) + + low + for low, high in ( + (self._min_x, self._max_x), + (self._min_y, self._max_y), + ) + ], + dim=2, + ) + + return _xy_to_ray_bundle( + cameras, rays_xy, self._min_depth, self._max_depth, self._n_pts_per_ray + ) + + +def _xy_to_ray_bundle( + cameras: CamerasBase, + xy_grid: torch.Tensor, + min_depth: float, + max_depth: float, + n_pts_per_ray: int, +) -> RayBundle: + """ + Extends the `xy_grid` input of shape `(batch_size, ..., 2)` to rays. + This adds to each xy location in the grid a vector of `n_pts_per_ray` depths + uniformly spaced between `min_depth` and `max_depth`. + + The extended grid is then unprojected with `cameras` to yield + ray origins, directions and depths. + """ + batch_size = xy_grid.shape[0] + spatial_size = xy_grid.shape[1:-1] + n_rays_per_image = spatial_size.numel() # pyre-ignore + + # ray z-coords + depths = torch.linspace( + min_depth, max_depth, n_pts_per_ray, dtype=xy_grid.dtype, device=xy_grid.device + ) + rays_zs = depths[None, None].expand(batch_size, n_rays_per_image, n_pts_per_ray) + + # make two sets of points at a constant depth=1 and 2 + to_unproject = torch.cat( + ( + xy_grid.view(batch_size, 1, n_rays_per_image, 2) + .expand(batch_size, 2, n_rays_per_image, 2) + .reshape(batch_size, n_rays_per_image * 2, 2), + torch.cat( + ( + xy_grid.new_ones(batch_size, n_rays_per_image, 1), # pyre-ignore + 2.0 * xy_grid.new_ones(batch_size, n_rays_per_image, 1), + ), + dim=1, + ), + ), + dim=-1, + ) + + # unproject the points + unprojected = cameras.unproject_points(to_unproject) # pyre-ignore + + # split the two planes back + rays_plane_1_world = unprojected[:, :n_rays_per_image] + rays_plane_2_world = unprojected[:, n_rays_per_image:] + + # directions are the differences between the two planes of points + rays_directions_world = rays_plane_2_world - rays_plane_1_world + + # origins are given by subtracting the ray directions from the first plane + rays_origins_world = rays_plane_1_world - rays_directions_world + + return RayBundle( + rays_origins_world.view(batch_size, *spatial_size, 3), + rays_directions_world.view(batch_size, *spatial_size, 3), + rays_zs.view(batch_size, *spatial_size, n_pts_per_ray), + xy_grid, + ) diff --git a/pytorch3d/renderer/implicit/utils.py b/pytorch3d/renderer/implicit/utils.py new file mode 100644 index 00000000..b5f01812 --- /dev/null +++ b/pytorch3d/renderer/implicit/utils.py @@ -0,0 +1,82 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from typing import NamedTuple + +import torch + + +class RayBundle(NamedTuple): + """ + RayBundle parametrizes points along projection rays by storing ray `origins`, + `directions` vectors and `lengths` at which the ray-points are sampled. + Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well. + """ + + origins: torch.Tensor + directions: torch.Tensor + lengths: torch.Tensor + xys: torch.Tensor + + +def ray_bundle_to_ray_points(ray_bundle: RayBundle) -> torch.Tensor: + """ + Converts rays parametrized with a `ray_bundle` (an instance of the `RayBundle` + named tuple) to 3D points by extending each ray according to the corresponding + length. + + E.g. for 2 dimensional tensors `ray_bundle.origins`, `ray_bundle.directions` + and `ray_bundle.lengths`, the ray point at position `[i, j]` is: + ``` + ray_bundle.points[i, j, :] = ( + ray_bundle.origins[i, :] + + ray_bundle.directions[i, :] * ray_bundle.lengths[i, j] + ) + ``` + + Args: + ray_bundle: A `RayBundle` object with fields: + origins: A tensor of shape `(..., 3)` + directions: A tensor of shape `(..., 3)` + lengths: A tensor of shape `(..., num_points_per_ray)` + + Returns: + rays_points: A tensor of shape `(..., num_points_per_ray, 3)` + containing the points sampled along each ray. + """ + return ray_bundle_variables_to_ray_points( + ray_bundle.origins, ray_bundle.directions, ray_bundle.lengths + ) + + +def ray_bundle_variables_to_ray_points( + rays_origins: torch.Tensor, + rays_directions: torch.Tensor, + rays_lengths: torch.Tensor, +) -> torch.Tensor: + """ + Converts rays parametrized with origins, directions + to 3D points by extending each ray according to the corresponding + ray_length: + + E.g. for 2 dimensional input tensors `rays_origins`, `rays_directions` + and `rays_lengths`, the ray point at position `[i, j]` is: + ``` + rays_points[i, j, :] = ( + rays_origins[i, :] + + rays_directions[i, :] * rays_lengths[i, j] + ) + ``` + + Args: + rays_origins: A tensor of shape `(..., 3)` + rays_directions: A tensor of shape `(..., 3)` + rays_lengths: A tensor of shape `(..., num_points_per_ray)` + + Returns: + rays_points: A tensor of shape `(..., num_points_per_ray, 3)` + containing the points sampled along each ray. + """ + rays_points = ( + rays_origins[..., None, :] + + rays_lengths[..., :, None] * rays_directions[..., None, :] + ) + return rays_points diff --git a/tests/bm_raysampling.py b/tests/bm_raysampling.py new file mode 100644 index 00000000..b971158c --- /dev/null +++ b/tests/bm_raysampling.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import itertools + +from fvcore.common.benchmark import benchmark +from pytorch3d.renderer import ( + GridRaysampler, + MonteCarloRaysampler, + NDCGridRaysampler, + FoVOrthographicCameras, + FoVPerspectiveCameras, + OrthographicCameras, + PerspectiveCameras, +) +from test_raysampling import TestRaysampling + + +def bm_raysampling() -> None: + case_grid = { + "raysampler_type": [GridRaysampler, NDCGridRaysampler, MonteCarloRaysampler], + "camera_type": [ + PerspectiveCameras, + OrthographicCameras, + FoVPerspectiveCameras, + FoVOrthographicCameras, + ], + "batch_size": [1, 10], + "n_pts_per_ray": [10, 1000, 10000], + "image_width": [10, 300], + "image_height": [10, 300], + } + test_cases = itertools.product(*case_grid.values()) + kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases] + + benchmark(TestRaysampling.raysampler, "RAYSAMPLER", kwargs_list, warmup_iters=1) + + +if __name__ == "__main__": + bm_raysampling() diff --git a/tests/test_raysampling.py b/tests/test_raysampling.py new file mode 100644 index 00000000..dffa7510 --- /dev/null +++ b/tests/test_raysampling.py @@ -0,0 +1,423 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import unittest + +import torch +from common_testing import TestCaseMixin +from pytorch3d.ops import eyes +from pytorch3d.renderer import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler +from pytorch3d.renderer.cameras import ( + FoVOrthographicCameras, + FoVPerspectiveCameras, + OrthographicCameras, + PerspectiveCameras, +) +from pytorch3d.renderer.implicit.utils import ( + ray_bundle_to_ray_points, + ray_bundle_variables_to_ray_points, +) +from pytorch3d.transforms import Rotate +from test_cameras import init_random_cameras + + +class TestRaysampling(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(42) + + @staticmethod + def raysampler( + raysampler_type=GridRaysampler, + camera_type=PerspectiveCameras, + n_pts_per_ray=10, + batch_size=1, + image_width=10, + image_height=20, + ): + + device = torch.device("cuda") + + # init raysamplers + raysampler = TestRaysampling.init_raysampler( + raysampler_type=raysampler_type, + min_x=-1.0, + max_x=1.0, + min_y=-1.0, + max_y=1.0, + image_width=image_width, + image_height=image_height, + min_depth=1.0, + max_depth=10.0, + n_pts_per_ray=n_pts_per_ray, + ).to(device) + + # init a batch of random cameras + cameras = init_random_cameras(camera_type, batch_size, random_z=True).to(device) + + def run_raysampler(): + raysampler(cameras=cameras) + torch.cuda.synchronize() + + return run_raysampler + + @staticmethod + def init_raysampler( + raysampler_type=GridRaysampler, + min_x=-1.0, + max_x=1.0, + min_y=-1.0, + max_y=1.0, + image_width=10, + image_height=20, + min_depth=1.0, + max_depth=10.0, + n_pts_per_ray=10, + ): + raysampler_params = { + "min_x": min_x, + "max_x": max_x, + "min_y": min_y, + "max_y": max_y, + "n_pts_per_ray": n_pts_per_ray, + "min_depth": min_depth, + "max_depth": max_depth, + } + + if issubclass(raysampler_type, GridRaysampler): + raysampler_params.update( + {"image_width": image_width, "image_height": image_height} + ) + elif issubclass(raysampler_type, MonteCarloRaysampler): + raysampler_params["n_rays_per_image"] = image_width * image_height + else: + raise ValueError(str(raysampler_type)) + + if issubclass(raysampler_type, NDCGridRaysampler): + # NDCGridRaysampler does not use min/max_x/y + for k in ("min_x", "max_x", "min_y", "max_y"): + del raysampler_params[k] + + # instantiate the raysampler + raysampler = raysampler_type(**raysampler_params) + + return raysampler + + def test_raysamplers( + self, + batch_size=25, + min_x=-1.0, + max_x=1.0, + min_y=-1.0, + max_y=1.0, + image_width=10, + image_height=20, + min_depth=1.0, + max_depth=10.0, + ): + """ + Tests the shapes and outputs of MC and GridRaysamplers for randomly + generated cameras and different numbers of points per ray. + """ + + device = torch.device("cuda") + + for n_pts_per_ray in (100, 1): + + for raysampler_type in ( + MonteCarloRaysampler, + GridRaysampler, + NDCGridRaysampler, + ): + + raysampler = TestRaysampling.init_raysampler( + raysampler_type=raysampler_type, + min_x=min_x, + max_x=max_x, + min_y=min_y, + max_y=max_y, + image_width=image_width, + image_height=image_height, + min_depth=min_depth, + max_depth=max_depth, + n_pts_per_ray=n_pts_per_ray, + ) + + if issubclass(raysampler_type, NDCGridRaysampler): + # adjust the gt bounds for NDCGridRaysampler + half_pix_width = 1.0 / image_width + half_pix_height = 1.0 / image_height + min_x_ = 1.0 - half_pix_width + max_x_ = -1.0 + half_pix_width + min_y_ = 1.0 - half_pix_height + max_y_ = -1.0 + half_pix_height + else: + min_x_ = min_x + max_x_ = max_x + min_y_ = min_y + max_y_ = max_y + + # carry out the test over several camera types + for cam_type in ( + FoVPerspectiveCameras, + FoVOrthographicCameras, + OrthographicCameras, + PerspectiveCameras, + ): + + # init a batch of random cameras + cameras = init_random_cameras( + cam_type, batch_size, random_z=True + ).to(device) + + # call the raysampler + ray_bundle = raysampler(cameras=cameras) + + # check the shapes of the raysampler outputs + self._check_raysampler_output_shapes( + raysampler, + ray_bundle, + batch_size, + image_width, + image_height, + n_pts_per_ray, + ) + + # check the points sampled along each ray + self._check_raysampler_ray_points( + raysampler, + cameras, + ray_bundle, + min_x_, + max_x_, + min_y_, + max_y_, + image_width, + image_height, + min_depth, + max_depth, + ) + + # check the output direction vectors + self._check_raysampler_ray_directions( + cameras, raysampler, ray_bundle + ) + + def _check_grid_shape(self, grid, batch_size, spatial_size, n_pts_per_ray, dim): + """ + A helper for checking the desired size of a variable output by a raysampler. + """ + tgt_shape = [ + x for x in [batch_size, *spatial_size, n_pts_per_ray, dim] if x > 0 + ] + self.assertTrue(all(sz1 == sz2 for sz1, sz2 in zip(grid.shape, tgt_shape))) + + def _check_raysampler_output_shapes( + self, + raysampler, + ray_bundle, + batch_size, + image_width, + image_height, + n_pts_per_ray, + ): + """ + Checks the shapes of raysampler outputs. + """ + + if isinstance(raysampler, GridRaysampler): + spatial_size = [image_height, image_width] + elif isinstance(raysampler, MonteCarloRaysampler): + spatial_size = [image_height * image_width] + else: + raise ValueError(str(type(raysampler))) + + self._check_grid_shape(ray_bundle.xys, batch_size, spatial_size, 0, 2) + self._check_grid_shape(ray_bundle.origins, batch_size, spatial_size, 0, 3) + self._check_grid_shape(ray_bundle.directions, batch_size, spatial_size, 0, 3) + self._check_grid_shape( + ray_bundle.lengths, batch_size, spatial_size, n_pts_per_ray, 0 + ) + + def _check_raysampler_ray_points( + self, + raysampler, + cameras, + ray_bundle, + min_x, + max_x, + min_y, + max_y, + image_width, + image_height, + min_depth, + max_depth, + ): + """ + Check rays_points_world and rays_zs outputs of raysamplers. + """ + + batch_size = cameras.R.shape[0] + + # convert to ray points + rays_points_world = ray_bundle_variables_to_ray_points( + ray_bundle.origins, ray_bundle.directions, ray_bundle.lengths + ) + n_pts_per_ray = rays_points_world.shape[-2] + + # check that the outputs if ray_bundle_variables_to_ray_points and + # ray_bundle_to_ray_points match + rays_points_world_ = ray_bundle_to_ray_points(ray_bundle) + self.assertClose(rays_points_world, rays_points_world_) + + # check that the depth of each ray point in camera coords + # matches the expected linearly-spaced depth + depth_expected = torch.linspace( + min_depth, + max_depth, + n_pts_per_ray, + dtype=torch.float32, + device=rays_points_world.device, + ) + ray_points_camera = ( + cameras.get_world_to_view_transform() + .transform_points(rays_points_world.view(batch_size, -1, 3)) + .view(batch_size, -1, n_pts_per_ray, 3) + ) + self.assertClose( + ray_points_camera[..., 2], + depth_expected[None, None, :].expand_as(ray_points_camera[..., 2]), + atol=1e-4, + ) + + # check also that rays_zs is consistent with depth_expected + self.assertClose( + ray_bundle.lengths.view(batch_size, -1, n_pts_per_ray), + depth_expected[None, None, :].expand_as(ray_points_camera[..., 2]), + atol=1e-6, + ) + + # project the world ray points back to screen space + ray_points_projected = cameras.transform_points( + rays_points_world.view(batch_size, -1, 3) + ).view(rays_points_world.shape) + + # check that ray_xy matches rays_points_projected xy + rays_xy_projected = ray_points_projected[..., :2].view( + batch_size, -1, n_pts_per_ray, 2 + ) + self.assertClose( + ray_bundle.xys.view(batch_size, -1, 1, 2).expand_as(rays_xy_projected), + rays_xy_projected, + atol=1e-4, + ) + + # check that projected world points' xy coordinates + # range correctly between [minx/y, max/y] + if isinstance(raysampler, GridRaysampler): + # get the expected coordinates along each grid axis + ys, xs = [ + torch.linspace( + low, high, sz, dtype=torch.float32, device=rays_points_world.device + ) + for low, high, sz in ( + (min_y, max_y, image_height), + (min_x, max_x, image_width), + ) + ] + # compare expected xy with the output xy + for dim, gt_axis in zip( + (0, 1), (xs[None, None, :, None], ys[None, :, None, None]) + ): + self.assertClose( + ray_points_projected[..., dim], + gt_axis.expand_as(ray_points_projected[..., dim]), + atol=1e-4, + ) + + elif isinstance(raysampler, MonteCarloRaysampler): + # check that the randomly sampled locations + # are within the allowed bounds for both x and y axes + for dim, axis_bounds in zip((0, 1), ((min_x, max_x), (min_y, max_y))): + self.assertTrue( + ( + (ray_points_projected[..., dim] <= axis_bounds[1]) + & (ray_points_projected[..., dim] >= axis_bounds[0]) + ).all() + ) + + # also check that x,y along each ray is constant + if n_pts_per_ray > 1: + self.assertClose( + ray_points_projected[..., :2].std(dim=-2), + torch.zeros_like(ray_points_projected[..., 0, :2]), + atol=1e-5, + ) + + else: + raise ValueError(str(type(raysampler))) + + def _check_raysampler_ray_directions(self, cameras, raysampler, ray_bundle): + """ + Check the rays_directions_world output of raysamplers. + """ + + batch_size = cameras.R.shape[0] + n_pts_per_ray = ray_bundle.lengths.shape[-1] + spatial_size = ray_bundle.xys.shape[1:-1] + n_rays_per_image = spatial_size.numel() + + # obtain the ray points in world coords + rays_points_world = cameras.unproject_points( + torch.cat( + ( + ray_bundle.xys.view(batch_size, n_rays_per_image, 1, 2).expand( + batch_size, n_rays_per_image, n_pts_per_ray, 2 + ), + ray_bundle.lengths.view( + batch_size, n_rays_per_image, n_pts_per_ray, 1 + ), + ), + dim=-1, + ).view(batch_size, -1, 3) + ).view(batch_size, -1, n_pts_per_ray, 3) + + # reshape to common testing size + rays_directions_world_normed = torch.nn.functional.normalize( + ray_bundle.directions.view(batch_size, -1, 3), dim=-1 + ) + + # check that the l2-normed difference of all consecutive planes + # of points in world coords matches ray_directions_world + rays_directions_world_ = torch.nn.functional.normalize( + rays_points_world[:, :, -1:] - rays_points_world[:, :, :-1], dim=-1 + ) + self.assertClose( + rays_directions_world_normed[:, :, None].expand_as(rays_directions_world_), + rays_directions_world_, + atol=1e-4, + ) + + # check the ray directions rotated using camera rotation matrix + # match the ray directions of a camera with trivial extrinsics + cameras_trivial_extrinsic = cameras.clone() + cameras_trivial_extrinsic.R = eyes( + N=batch_size, dim=3, dtype=cameras.R.dtype, device=cameras.device + ) + cameras_trivial_extrinsic.T = torch.zeros_like(cameras.T) + + # make sure we get the same random rays in case we call the + # MonteCarloRaysampler twice below + with torch.random.fork_rng(devices=range(torch.cuda.device_count())): + torch.random.manual_seed(42) + ray_bundle_world_fix_seed = raysampler(cameras=cameras) + torch.random.manual_seed(42) + ray_bundle_camera_fix_seed = raysampler(cameras=cameras_trivial_extrinsic) + + rays_directions_camera_fix_seed_ = Rotate( + cameras.R, device=cameras.R.device + ).transform_points(ray_bundle_world_fix_seed.directions.view(batch_size, -1, 3)) + + self.assertClose( + rays_directions_camera_fix_seed_, + ray_bundle_camera_fix_seed.directions.view(batch_size, -1, 3), + atol=1e-5, + )