Raysampling

Summary: Implements 3 basic raysamplers.

Reviewed By: nikhilaravi

Differential Revision: D24110643

fbshipit-source-id: eb67d0e56773c7871ebdcb23e7e520302dc1b3c9
This commit is contained in:
David Novotny 2021-01-06 03:59:56 -08:00 committed by Facebook GitHub Bot
parent 1f9cf91e1b
commit e6bc960fb5
6 changed files with 880 additions and 1 deletions

View File

@ -20,7 +20,16 @@ from .cameras import (
look_at_rotation,
look_at_view_transform,
)
from .implicit import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
from .implicit import (
AbsorptionOnlyRaymarcher,
EmissionAbsorptionRaymarcher,
GridRaysampler,
MonteCarloRaysampler,
NDCGridRaysampler,
RayBundle,
ray_bundle_to_ray_points,
ray_bundle_variables_to_ray_points,
)
from .lighting import DirectionalLights, PointLights, diffuse, specular
from .materials import Materials
from .mesh import (

View File

@ -1,6 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
from .utils import (
RayBundle,
ray_bundle_to_ray_points,
ray_bundle_variables_to_ray_points,
)
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -0,0 +1,320 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
from ..cameras import CamerasBase
from .utils import RayBundle
"""
This file defines three raysampling techniques:
- GridRaysampler which can be used to sample rays from pixels of an image grid
- NDCGridRaysampler which can be used to sample rays from pixels of an image grid,
which follows the pytorch3d convention for image grid coordinates
- MonteCarloRaysampler which randomly selects image pixels and emits rays from them
"""
class GridRaysampler(torch.nn.Module):
"""
Samples a fixed number of points along rays which are regulary distributed
in a batch of rectangular image grids. Points along each ray
have uniformly-spaced z-coordinates between a predefined
minimum and maximum depth.
The raysampler first generates a 3D coordinate grid of the following form:
```
/ min_x, min_y, max_depth -------------- / max_x, min_y, max_depth
/ /|
/ / | ^
/ min_depth min_depth / | |
min_x ----------------------------- max_x | | image
min_y min_y | | height
| | | |
| | | v
| | |
| | / max_x, max_y, ^
| | / max_depth /
min_x max_y / / n_pts_per_ray
max_y ----------------------------- max_x/ min_depth v
< --- image_width --- >
```
In order to generate ray points, `GridRaysampler` takes each 3D point of
the grid (with coordinates `[x, y, depth]`) and unprojects it
with `cameras.unproject_points([x, y, depth])`, where `cameras` are an
additional input to the `forward` function.
Note that this is a generic implementation that can support any image grid
coordinate convention. For a raysampler which follows the PyTorch3D
coordinate conventions please refer to `NDCGridRaysampler`.
As such, `NDCGridRaysampler` is a special case of `GridRaysampler`.
"""
def __init__(
self,
min_x: float,
max_x: float,
min_y: float,
max_y: float,
image_width: int,
image_height: int,
n_pts_per_ray: int,
min_depth: float,
max_depth: float,
):
"""
Args:
min_x: The leftmost x-coordinate of each ray's source pixel's center.
max_x: The rightmost x-coordinate of each ray's source pixel's center.
min_y: The topmost y-coordinate of each ray's source pixel's center.
max_y: The bottommost y-coordinate of each ray's source pixel's center.
image_width: The horizontal size of the image grid.
image_height: The vertical size of the image grid.
n_pts_per_ray: The number of points sampled along each ray.
min_depth: The minimum depth of a ray-point.
max_depth: The maximum depth of a ray-point.
"""
super().__init__()
self._n_pts_per_ray = n_pts_per_ray
self._min_depth = min_depth
self._max_depth = max_depth
# get the initial grid of image xy coords
_xy_grid = torch.stack(
tuple(
reversed(
torch.meshgrid(
torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
)
)
),
dim=-1,
)
self.register_buffer("_xy_grid", _xy_grid)
def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
"""
Args:
cameras: A batch of `batch_size` cameras from which the rays are emitted.
Returns:
A named tuple RayBundle with the following fields:
origins: A tensor of shape
`(batch_size, image_height, image_width, 3)`
denoting the locations of ray origins in the world coordinates.
directions: A tensor of shape
`(batch_size, image_height, image_width, 3)`
denoting the directions of each ray in the world coordinates.
lengths: A tensor of shape
`(batch_size, image_height, image_width, n_pts_per_ray)`
containing the z-coordinate (=depth) of each ray in world units.
xys: A tensor of shape
`(batch_size, image_height, image_width, 2)`
containing the 2D image coordinates of each ray.
"""
batch_size = cameras.R.shape[0] # pyre-ignore
device = cameras.device
# expand the (H, W, 2) grid batch_size-times to (B, H, W, 2)
xy_grid = self._xy_grid.to(device)[None].expand( # pyre-ignore
batch_size, *self._xy_grid.shape
)
return _xy_to_ray_bundle(
cameras, xy_grid, self._min_depth, self._max_depth, self._n_pts_per_ray
)
class NDCGridRaysampler(GridRaysampler):
"""
Samples a fixed number of points along rays which are regulary distributed
in a batch of rectangular image grids. Points along each ray
have uniformly-spaced z-coordinates between a predefined minimum and maximum depth.
`NDCGridRaysampler` follows the screen conventions of the `Meshes` and `Pointclouds`
renderers. I.e. the border of the leftmost / rightmost / topmost / bottommost pixel
has coordinates 1.0 / -1.0 / 1.0 / -1.0 respectively.
"""
def __init__(
self,
image_width: int,
image_height: int,
n_pts_per_ray: int,
min_depth: float,
max_depth: float,
):
"""
Args:
image_width: The horizontal size of the image grid.
image_height: The vertical size of the image grid.
n_pts_per_ray: The number of points sampled along each ray.
min_depth: The minimum depth of a ray-point.
max_depth: The maximum depth of a ray-point.
"""
half_pix_width = 1.0 / image_width
half_pix_height = 1.0 / image_height
super().__init__(
min_x=1.0 - half_pix_width,
max_x=-1.0 + half_pix_width,
min_y=1.0 - half_pix_height,
max_y=-1.0 + half_pix_height,
image_width=image_width,
image_height=image_height,
n_pts_per_ray=n_pts_per_ray,
min_depth=min_depth,
max_depth=max_depth,
)
class MonteCarloRaysampler(torch.nn.Module):
"""
Samples a fixed number of pixels within denoted xy bounds uniformly at random.
For each pixel, a fixed number of points is sampled along its ray at uniformly-spaced
z-coordinates such that the z-coordinates range between a predefined minimum
and maximum depth.
"""
def __init__(
self,
min_x: float,
max_x: float,
min_y: float,
max_y: float,
n_rays_per_image: int,
n_pts_per_ray: int,
min_depth: float,
max_depth: float,
):
"""
Args:
min_x: The smallest x-coordinate of each ray's source pixel.
max_x: The largest x-coordinate of each ray's source pixel.
min_y: The smallest y-coordinate of each ray's source pixel.
max_y: The largest y-coordinate of each ray's source pixel.
n_rays_per_image: The number of rays randomly sampled in each camera.
n_pts_per_ray: The number of points sampled along each ray.
min_depth: The minimum depth of each ray-point.
max_depth: The maximum depth of each ray-point.
"""
super().__init__()
self._min_x = min_x
self._max_x = max_x
self._min_y = min_y
self._max_y = max_y
self._n_rays_per_image = n_rays_per_image
self._n_pts_per_ray = n_pts_per_ray
self._min_depth = min_depth
self._max_depth = max_depth
def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
"""
Args:
cameras: A batch of `batch_size` cameras from which the rays are emitted.
Returns:
A named tuple RayBundle with the following fields:
origins: A tensor of shape
`(batch_size, n_rays_per_image, 3)`
denoting the locations of ray origins in the world coordinates.
directions: A tensor of shape
`(batch_size, n_rays_per_image, 3)`
denoting the directions of each ray in the world coordinates.
lengths: A tensor of shape
`(batch_size, n_rays_per_image, n_pts_per_ray)`
containing the z-coordinate (=depth) of each ray in world units.
xys: A tensor of shape
`(batch_size, n_rays_per_image, 2)`
containing the 2D image coordinates of each ray.
"""
batch_size = cameras.R.shape[0] # pyre-ignore
device = cameras.device
# get the initial grid of image xy coords
# of shape (batch_size, n_rays_per_image, 2)
rays_xy = torch.cat(
[
torch.rand(
size=(batch_size, self._n_rays_per_image, 1),
dtype=torch.float32,
device=device,
)
* (high - low)
+ low
for low, high in (
(self._min_x, self._max_x),
(self._min_y, self._max_y),
)
],
dim=2,
)
return _xy_to_ray_bundle(
cameras, rays_xy, self._min_depth, self._max_depth, self._n_pts_per_ray
)
def _xy_to_ray_bundle(
cameras: CamerasBase,
xy_grid: torch.Tensor,
min_depth: float,
max_depth: float,
n_pts_per_ray: int,
) -> RayBundle:
"""
Extends the `xy_grid` input of shape `(batch_size, ..., 2)` to rays.
This adds to each xy location in the grid a vector of `n_pts_per_ray` depths
uniformly spaced between `min_depth` and `max_depth`.
The extended grid is then unprojected with `cameras` to yield
ray origins, directions and depths.
"""
batch_size = xy_grid.shape[0]
spatial_size = xy_grid.shape[1:-1]
n_rays_per_image = spatial_size.numel() # pyre-ignore
# ray z-coords
depths = torch.linspace(
min_depth, max_depth, n_pts_per_ray, dtype=xy_grid.dtype, device=xy_grid.device
)
rays_zs = depths[None, None].expand(batch_size, n_rays_per_image, n_pts_per_ray)
# make two sets of points at a constant depth=1 and 2
to_unproject = torch.cat(
(
xy_grid.view(batch_size, 1, n_rays_per_image, 2)
.expand(batch_size, 2, n_rays_per_image, 2)
.reshape(batch_size, n_rays_per_image * 2, 2),
torch.cat(
(
xy_grid.new_ones(batch_size, n_rays_per_image, 1), # pyre-ignore
2.0 * xy_grid.new_ones(batch_size, n_rays_per_image, 1),
),
dim=1,
),
),
dim=-1,
)
# unproject the points
unprojected = cameras.unproject_points(to_unproject) # pyre-ignore
# split the two planes back
rays_plane_1_world = unprojected[:, :n_rays_per_image]
rays_plane_2_world = unprojected[:, n_rays_per_image:]
# directions are the differences between the two planes of points
rays_directions_world = rays_plane_2_world - rays_plane_1_world
# origins are given by subtracting the ray directions from the first plane
rays_origins_world = rays_plane_1_world - rays_directions_world
return RayBundle(
rays_origins_world.view(batch_size, *spatial_size, 3),
rays_directions_world.view(batch_size, *spatial_size, 3),
rays_zs.view(batch_size, *spatial_size, n_pts_per_ray),
xy_grid,
)

View File

@ -0,0 +1,82 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import NamedTuple
import torch
class RayBundle(NamedTuple):
"""
RayBundle parametrizes points along projection rays by storing ray `origins`,
`directions` vectors and `lengths` at which the ray-points are sampled.
Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well.
"""
origins: torch.Tensor
directions: torch.Tensor
lengths: torch.Tensor
xys: torch.Tensor
def ray_bundle_to_ray_points(ray_bundle: RayBundle) -> torch.Tensor:
"""
Converts rays parametrized with a `ray_bundle` (an instance of the `RayBundle`
named tuple) to 3D points by extending each ray according to the corresponding
length.
E.g. for 2 dimensional tensors `ray_bundle.origins`, `ray_bundle.directions`
and `ray_bundle.lengths`, the ray point at position `[i, j]` is:
```
ray_bundle.points[i, j, :] = (
ray_bundle.origins[i, :]
+ ray_bundle.directions[i, :] * ray_bundle.lengths[i, j]
)
```
Args:
ray_bundle: A `RayBundle` object with fields:
origins: A tensor of shape `(..., 3)`
directions: A tensor of shape `(..., 3)`
lengths: A tensor of shape `(..., num_points_per_ray)`
Returns:
rays_points: A tensor of shape `(..., num_points_per_ray, 3)`
containing the points sampled along each ray.
"""
return ray_bundle_variables_to_ray_points(
ray_bundle.origins, ray_bundle.directions, ray_bundle.lengths
)
def ray_bundle_variables_to_ray_points(
rays_origins: torch.Tensor,
rays_directions: torch.Tensor,
rays_lengths: torch.Tensor,
) -> torch.Tensor:
"""
Converts rays parametrized with origins, directions
to 3D points by extending each ray according to the corresponding
ray_length:
E.g. for 2 dimensional input tensors `rays_origins`, `rays_directions`
and `rays_lengths`, the ray point at position `[i, j]` is:
```
rays_points[i, j, :] = (
rays_origins[i, :]
+ rays_directions[i, :] * rays_lengths[i, j]
)
```
Args:
rays_origins: A tensor of shape `(..., 3)`
rays_directions: A tensor of shape `(..., 3)`
rays_lengths: A tensor of shape `(..., num_points_per_ray)`
Returns:
rays_points: A tensor of shape `(..., num_points_per_ray, 3)`
containing the points sampled along each ray.
"""
rays_points = (
rays_origins[..., None, :]
+ rays_lengths[..., :, None] * rays_directions[..., None, :]
)
return rays_points

39
tests/bm_raysampling.py Normal file
View File

@ -0,0 +1,39 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import itertools
from fvcore.common.benchmark import benchmark
from pytorch3d.renderer import (
GridRaysampler,
MonteCarloRaysampler,
NDCGridRaysampler,
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
)
from test_raysampling import TestRaysampling
def bm_raysampling() -> None:
case_grid = {
"raysampler_type": [GridRaysampler, NDCGridRaysampler, MonteCarloRaysampler],
"camera_type": [
PerspectiveCameras,
OrthographicCameras,
FoVPerspectiveCameras,
FoVOrthographicCameras,
],
"batch_size": [1, 10],
"n_pts_per_ray": [10, 1000, 10000],
"image_width": [10, 300],
"image_height": [10, 300],
}
test_cases = itertools.product(*case_grid.values())
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
benchmark(TestRaysampling.raysampler, "RAYSAMPLER", kwargs_list, warmup_iters=1)
if __name__ == "__main__":
bm_raysampling()

423
tests/test_raysampling.py Normal file
View File

@ -0,0 +1,423 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import eyes
from pytorch3d.renderer import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
from pytorch3d.renderer.cameras import (
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
)
from pytorch3d.renderer.implicit.utils import (
ray_bundle_to_ray_points,
ray_bundle_variables_to_ray_points,
)
from pytorch3d.transforms import Rotate
from test_cameras import init_random_cameras
class TestRaysampling(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
torch.manual_seed(42)
@staticmethod
def raysampler(
raysampler_type=GridRaysampler,
camera_type=PerspectiveCameras,
n_pts_per_ray=10,
batch_size=1,
image_width=10,
image_height=20,
):
device = torch.device("cuda")
# init raysamplers
raysampler = TestRaysampling.init_raysampler(
raysampler_type=raysampler_type,
min_x=-1.0,
max_x=1.0,
min_y=-1.0,
max_y=1.0,
image_width=image_width,
image_height=image_height,
min_depth=1.0,
max_depth=10.0,
n_pts_per_ray=n_pts_per_ray,
).to(device)
# init a batch of random cameras
cameras = init_random_cameras(camera_type, batch_size, random_z=True).to(device)
def run_raysampler():
raysampler(cameras=cameras)
torch.cuda.synchronize()
return run_raysampler
@staticmethod
def init_raysampler(
raysampler_type=GridRaysampler,
min_x=-1.0,
max_x=1.0,
min_y=-1.0,
max_y=1.0,
image_width=10,
image_height=20,
min_depth=1.0,
max_depth=10.0,
n_pts_per_ray=10,
):
raysampler_params = {
"min_x": min_x,
"max_x": max_x,
"min_y": min_y,
"max_y": max_y,
"n_pts_per_ray": n_pts_per_ray,
"min_depth": min_depth,
"max_depth": max_depth,
}
if issubclass(raysampler_type, GridRaysampler):
raysampler_params.update(
{"image_width": image_width, "image_height": image_height}
)
elif issubclass(raysampler_type, MonteCarloRaysampler):
raysampler_params["n_rays_per_image"] = image_width * image_height
else:
raise ValueError(str(raysampler_type))
if issubclass(raysampler_type, NDCGridRaysampler):
# NDCGridRaysampler does not use min/max_x/y
for k in ("min_x", "max_x", "min_y", "max_y"):
del raysampler_params[k]
# instantiate the raysampler
raysampler = raysampler_type(**raysampler_params)
return raysampler
def test_raysamplers(
self,
batch_size=25,
min_x=-1.0,
max_x=1.0,
min_y=-1.0,
max_y=1.0,
image_width=10,
image_height=20,
min_depth=1.0,
max_depth=10.0,
):
"""
Tests the shapes and outputs of MC and GridRaysamplers for randomly
generated cameras and different numbers of points per ray.
"""
device = torch.device("cuda")
for n_pts_per_ray in (100, 1):
for raysampler_type in (
MonteCarloRaysampler,
GridRaysampler,
NDCGridRaysampler,
):
raysampler = TestRaysampling.init_raysampler(
raysampler_type=raysampler_type,
min_x=min_x,
max_x=max_x,
min_y=min_y,
max_y=max_y,
image_width=image_width,
image_height=image_height,
min_depth=min_depth,
max_depth=max_depth,
n_pts_per_ray=n_pts_per_ray,
)
if issubclass(raysampler_type, NDCGridRaysampler):
# adjust the gt bounds for NDCGridRaysampler
half_pix_width = 1.0 / image_width
half_pix_height = 1.0 / image_height
min_x_ = 1.0 - half_pix_width
max_x_ = -1.0 + half_pix_width
min_y_ = 1.0 - half_pix_height
max_y_ = -1.0 + half_pix_height
else:
min_x_ = min_x
max_x_ = max_x
min_y_ = min_y
max_y_ = max_y
# carry out the test over several camera types
for cam_type in (
FoVPerspectiveCameras,
FoVOrthographicCameras,
OrthographicCameras,
PerspectiveCameras,
):
# init a batch of random cameras
cameras = init_random_cameras(
cam_type, batch_size, random_z=True
).to(device)
# call the raysampler
ray_bundle = raysampler(cameras=cameras)
# check the shapes of the raysampler outputs
self._check_raysampler_output_shapes(
raysampler,
ray_bundle,
batch_size,
image_width,
image_height,
n_pts_per_ray,
)
# check the points sampled along each ray
self._check_raysampler_ray_points(
raysampler,
cameras,
ray_bundle,
min_x_,
max_x_,
min_y_,
max_y_,
image_width,
image_height,
min_depth,
max_depth,
)
# check the output direction vectors
self._check_raysampler_ray_directions(
cameras, raysampler, ray_bundle
)
def _check_grid_shape(self, grid, batch_size, spatial_size, n_pts_per_ray, dim):
"""
A helper for checking the desired size of a variable output by a raysampler.
"""
tgt_shape = [
x for x in [batch_size, *spatial_size, n_pts_per_ray, dim] if x > 0
]
self.assertTrue(all(sz1 == sz2 for sz1, sz2 in zip(grid.shape, tgt_shape)))
def _check_raysampler_output_shapes(
self,
raysampler,
ray_bundle,
batch_size,
image_width,
image_height,
n_pts_per_ray,
):
"""
Checks the shapes of raysampler outputs.
"""
if isinstance(raysampler, GridRaysampler):
spatial_size = [image_height, image_width]
elif isinstance(raysampler, MonteCarloRaysampler):
spatial_size = [image_height * image_width]
else:
raise ValueError(str(type(raysampler)))
self._check_grid_shape(ray_bundle.xys, batch_size, spatial_size, 0, 2)
self._check_grid_shape(ray_bundle.origins, batch_size, spatial_size, 0, 3)
self._check_grid_shape(ray_bundle.directions, batch_size, spatial_size, 0, 3)
self._check_grid_shape(
ray_bundle.lengths, batch_size, spatial_size, n_pts_per_ray, 0
)
def _check_raysampler_ray_points(
self,
raysampler,
cameras,
ray_bundle,
min_x,
max_x,
min_y,
max_y,
image_width,
image_height,
min_depth,
max_depth,
):
"""
Check rays_points_world and rays_zs outputs of raysamplers.
"""
batch_size = cameras.R.shape[0]
# convert to ray points
rays_points_world = ray_bundle_variables_to_ray_points(
ray_bundle.origins, ray_bundle.directions, ray_bundle.lengths
)
n_pts_per_ray = rays_points_world.shape[-2]
# check that the outputs if ray_bundle_variables_to_ray_points and
# ray_bundle_to_ray_points match
rays_points_world_ = ray_bundle_to_ray_points(ray_bundle)
self.assertClose(rays_points_world, rays_points_world_)
# check that the depth of each ray point in camera coords
# matches the expected linearly-spaced depth
depth_expected = torch.linspace(
min_depth,
max_depth,
n_pts_per_ray,
dtype=torch.float32,
device=rays_points_world.device,
)
ray_points_camera = (
cameras.get_world_to_view_transform()
.transform_points(rays_points_world.view(batch_size, -1, 3))
.view(batch_size, -1, n_pts_per_ray, 3)
)
self.assertClose(
ray_points_camera[..., 2],
depth_expected[None, None, :].expand_as(ray_points_camera[..., 2]),
atol=1e-4,
)
# check also that rays_zs is consistent with depth_expected
self.assertClose(
ray_bundle.lengths.view(batch_size, -1, n_pts_per_ray),
depth_expected[None, None, :].expand_as(ray_points_camera[..., 2]),
atol=1e-6,
)
# project the world ray points back to screen space
ray_points_projected = cameras.transform_points(
rays_points_world.view(batch_size, -1, 3)
).view(rays_points_world.shape)
# check that ray_xy matches rays_points_projected xy
rays_xy_projected = ray_points_projected[..., :2].view(
batch_size, -1, n_pts_per_ray, 2
)
self.assertClose(
ray_bundle.xys.view(batch_size, -1, 1, 2).expand_as(rays_xy_projected),
rays_xy_projected,
atol=1e-4,
)
# check that projected world points' xy coordinates
# range correctly between [minx/y, max/y]
if isinstance(raysampler, GridRaysampler):
# get the expected coordinates along each grid axis
ys, xs = [
torch.linspace(
low, high, sz, dtype=torch.float32, device=rays_points_world.device
)
for low, high, sz in (
(min_y, max_y, image_height),
(min_x, max_x, image_width),
)
]
# compare expected xy with the output xy
for dim, gt_axis in zip(
(0, 1), (xs[None, None, :, None], ys[None, :, None, None])
):
self.assertClose(
ray_points_projected[..., dim],
gt_axis.expand_as(ray_points_projected[..., dim]),
atol=1e-4,
)
elif isinstance(raysampler, MonteCarloRaysampler):
# check that the randomly sampled locations
# are within the allowed bounds for both x and y axes
for dim, axis_bounds in zip((0, 1), ((min_x, max_x), (min_y, max_y))):
self.assertTrue(
(
(ray_points_projected[..., dim] <= axis_bounds[1])
& (ray_points_projected[..., dim] >= axis_bounds[0])
).all()
)
# also check that x,y along each ray is constant
if n_pts_per_ray > 1:
self.assertClose(
ray_points_projected[..., :2].std(dim=-2),
torch.zeros_like(ray_points_projected[..., 0, :2]),
atol=1e-5,
)
else:
raise ValueError(str(type(raysampler)))
def _check_raysampler_ray_directions(self, cameras, raysampler, ray_bundle):
"""
Check the rays_directions_world output of raysamplers.
"""
batch_size = cameras.R.shape[0]
n_pts_per_ray = ray_bundle.lengths.shape[-1]
spatial_size = ray_bundle.xys.shape[1:-1]
n_rays_per_image = spatial_size.numel()
# obtain the ray points in world coords
rays_points_world = cameras.unproject_points(
torch.cat(
(
ray_bundle.xys.view(batch_size, n_rays_per_image, 1, 2).expand(
batch_size, n_rays_per_image, n_pts_per_ray, 2
),
ray_bundle.lengths.view(
batch_size, n_rays_per_image, n_pts_per_ray, 1
),
),
dim=-1,
).view(batch_size, -1, 3)
).view(batch_size, -1, n_pts_per_ray, 3)
# reshape to common testing size
rays_directions_world_normed = torch.nn.functional.normalize(
ray_bundle.directions.view(batch_size, -1, 3), dim=-1
)
# check that the l2-normed difference of all consecutive planes
# of points in world coords matches ray_directions_world
rays_directions_world_ = torch.nn.functional.normalize(
rays_points_world[:, :, -1:] - rays_points_world[:, :, :-1], dim=-1
)
self.assertClose(
rays_directions_world_normed[:, :, None].expand_as(rays_directions_world_),
rays_directions_world_,
atol=1e-4,
)
# check the ray directions rotated using camera rotation matrix
# match the ray directions of a camera with trivial extrinsics
cameras_trivial_extrinsic = cameras.clone()
cameras_trivial_extrinsic.R = eyes(
N=batch_size, dim=3, dtype=cameras.R.dtype, device=cameras.device
)
cameras_trivial_extrinsic.T = torch.zeros_like(cameras.T)
# make sure we get the same random rays in case we call the
# MonteCarloRaysampler twice below
with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
torch.random.manual_seed(42)
ray_bundle_world_fix_seed = raysampler(cameras=cameras)
torch.random.manual_seed(42)
ray_bundle_camera_fix_seed = raysampler(cameras=cameras_trivial_extrinsic)
rays_directions_camera_fix_seed_ = Rotate(
cameras.R, device=cameras.R.device
).transform_points(ray_bundle_world_fix_seed.directions.view(batch_size, -1, 3))
self.assertClose(
rays_directions_camera_fix_seed_,
ray_bundle_camera_fix_seed.directions.view(batch_size, -1, 3),
atol=1e-5,
)