mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Raysampling
Summary: Implements 3 basic raysamplers. Reviewed By: nikhilaravi Differential Revision: D24110643 fbshipit-source-id: eb67d0e56773c7871ebdcb23e7e520302dc1b3c9
This commit is contained in:
parent
1f9cf91e1b
commit
e6bc960fb5
@ -20,7 +20,16 @@ from .cameras import (
|
||||
look_at_rotation,
|
||||
look_at_view_transform,
|
||||
)
|
||||
from .implicit import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
|
||||
from .implicit import (
|
||||
AbsorptionOnlyRaymarcher,
|
||||
EmissionAbsorptionRaymarcher,
|
||||
GridRaysampler,
|
||||
MonteCarloRaysampler,
|
||||
NDCGridRaysampler,
|
||||
RayBundle,
|
||||
ray_bundle_to_ray_points,
|
||||
ray_bundle_variables_to_ray_points,
|
||||
)
|
||||
from .lighting import DirectionalLights, PointLights, diffuse, specular
|
||||
from .materials import Materials
|
||||
from .mesh import (
|
||||
|
@ -1,6 +1,12 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
|
||||
from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
|
||||
from .utils import (
|
||||
RayBundle,
|
||||
ray_bundle_to_ray_points,
|
||||
ray_bundle_variables_to_ray_points,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
320
pytorch3d/renderer/implicit/raysampling.py
Normal file
320
pytorch3d/renderer/implicit/raysampling.py
Normal file
@ -0,0 +1,320 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
import torch
|
||||
|
||||
from ..cameras import CamerasBase
|
||||
from .utils import RayBundle
|
||||
|
||||
|
||||
"""
|
||||
This file defines three raysampling techniques:
|
||||
- GridRaysampler which can be used to sample rays from pixels of an image grid
|
||||
- NDCGridRaysampler which can be used to sample rays from pixels of an image grid,
|
||||
which follows the pytorch3d convention for image grid coordinates
|
||||
- MonteCarloRaysampler which randomly selects image pixels and emits rays from them
|
||||
"""
|
||||
|
||||
|
||||
class GridRaysampler(torch.nn.Module):
|
||||
"""
|
||||
Samples a fixed number of points along rays which are regulary distributed
|
||||
in a batch of rectangular image grids. Points along each ray
|
||||
have uniformly-spaced z-coordinates between a predefined
|
||||
minimum and maximum depth.
|
||||
|
||||
The raysampler first generates a 3D coordinate grid of the following form:
|
||||
```
|
||||
/ min_x, min_y, max_depth -------------- / max_x, min_y, max_depth
|
||||
/ /|
|
||||
/ / | ^
|
||||
/ min_depth min_depth / | |
|
||||
min_x ----------------------------- max_x | | image
|
||||
min_y min_y | | height
|
||||
| | | |
|
||||
| | | v
|
||||
| | |
|
||||
| | / max_x, max_y, ^
|
||||
| | / max_depth /
|
||||
min_x max_y / / n_pts_per_ray
|
||||
max_y ----------------------------- max_x/ min_depth v
|
||||
< --- image_width --- >
|
||||
```
|
||||
|
||||
In order to generate ray points, `GridRaysampler` takes each 3D point of
|
||||
the grid (with coordinates `[x, y, depth]`) and unprojects it
|
||||
with `cameras.unproject_points([x, y, depth])`, where `cameras` are an
|
||||
additional input to the `forward` function.
|
||||
|
||||
Note that this is a generic implementation that can support any image grid
|
||||
coordinate convention. For a raysampler which follows the PyTorch3D
|
||||
coordinate conventions please refer to `NDCGridRaysampler`.
|
||||
As such, `NDCGridRaysampler` is a special case of `GridRaysampler`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
min_x: float,
|
||||
max_x: float,
|
||||
min_y: float,
|
||||
max_y: float,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
n_pts_per_ray: int,
|
||||
min_depth: float,
|
||||
max_depth: float,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
min_x: The leftmost x-coordinate of each ray's source pixel's center.
|
||||
max_x: The rightmost x-coordinate of each ray's source pixel's center.
|
||||
min_y: The topmost y-coordinate of each ray's source pixel's center.
|
||||
max_y: The bottommost y-coordinate of each ray's source pixel's center.
|
||||
image_width: The horizontal size of the image grid.
|
||||
image_height: The vertical size of the image grid.
|
||||
n_pts_per_ray: The number of points sampled along each ray.
|
||||
min_depth: The minimum depth of a ray-point.
|
||||
max_depth: The maximum depth of a ray-point.
|
||||
"""
|
||||
super().__init__()
|
||||
self._n_pts_per_ray = n_pts_per_ray
|
||||
self._min_depth = min_depth
|
||||
self._max_depth = max_depth
|
||||
|
||||
# get the initial grid of image xy coords
|
||||
_xy_grid = torch.stack(
|
||||
tuple(
|
||||
reversed(
|
||||
torch.meshgrid(
|
||||
torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
|
||||
torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
|
||||
)
|
||||
)
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
self.register_buffer("_xy_grid", _xy_grid)
|
||||
|
||||
def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
|
||||
"""
|
||||
Args:
|
||||
cameras: A batch of `batch_size` cameras from which the rays are emitted.
|
||||
Returns:
|
||||
A named tuple RayBundle with the following fields:
|
||||
origins: A tensor of shape
|
||||
`(batch_size, image_height, image_width, 3)`
|
||||
denoting the locations of ray origins in the world coordinates.
|
||||
directions: A tensor of shape
|
||||
`(batch_size, image_height, image_width, 3)`
|
||||
denoting the directions of each ray in the world coordinates.
|
||||
lengths: A tensor of shape
|
||||
`(batch_size, image_height, image_width, n_pts_per_ray)`
|
||||
containing the z-coordinate (=depth) of each ray in world units.
|
||||
xys: A tensor of shape
|
||||
`(batch_size, image_height, image_width, 2)`
|
||||
containing the 2D image coordinates of each ray.
|
||||
"""
|
||||
|
||||
batch_size = cameras.R.shape[0] # pyre-ignore
|
||||
|
||||
device = cameras.device
|
||||
|
||||
# expand the (H, W, 2) grid batch_size-times to (B, H, W, 2)
|
||||
xy_grid = self._xy_grid.to(device)[None].expand( # pyre-ignore
|
||||
batch_size, *self._xy_grid.shape
|
||||
)
|
||||
|
||||
return _xy_to_ray_bundle(
|
||||
cameras, xy_grid, self._min_depth, self._max_depth, self._n_pts_per_ray
|
||||
)
|
||||
|
||||
|
||||
class NDCGridRaysampler(GridRaysampler):
|
||||
"""
|
||||
Samples a fixed number of points along rays which are regulary distributed
|
||||
in a batch of rectangular image grids. Points along each ray
|
||||
have uniformly-spaced z-coordinates between a predefined minimum and maximum depth.
|
||||
|
||||
`NDCGridRaysampler` follows the screen conventions of the `Meshes` and `Pointclouds`
|
||||
renderers. I.e. the border of the leftmost / rightmost / topmost / bottommost pixel
|
||||
has coordinates 1.0 / -1.0 / 1.0 / -1.0 respectively.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
n_pts_per_ray: int,
|
||||
min_depth: float,
|
||||
max_depth: float,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
image_width: The horizontal size of the image grid.
|
||||
image_height: The vertical size of the image grid.
|
||||
n_pts_per_ray: The number of points sampled along each ray.
|
||||
min_depth: The minimum depth of a ray-point.
|
||||
max_depth: The maximum depth of a ray-point.
|
||||
"""
|
||||
half_pix_width = 1.0 / image_width
|
||||
half_pix_height = 1.0 / image_height
|
||||
super().__init__(
|
||||
min_x=1.0 - half_pix_width,
|
||||
max_x=-1.0 + half_pix_width,
|
||||
min_y=1.0 - half_pix_height,
|
||||
max_y=-1.0 + half_pix_height,
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
n_pts_per_ray=n_pts_per_ray,
|
||||
min_depth=min_depth,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
|
||||
|
||||
class MonteCarloRaysampler(torch.nn.Module):
|
||||
"""
|
||||
Samples a fixed number of pixels within denoted xy bounds uniformly at random.
|
||||
For each pixel, a fixed number of points is sampled along its ray at uniformly-spaced
|
||||
z-coordinates such that the z-coordinates range between a predefined minimum
|
||||
and maximum depth.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
min_x: float,
|
||||
max_x: float,
|
||||
min_y: float,
|
||||
max_y: float,
|
||||
n_rays_per_image: int,
|
||||
n_pts_per_ray: int,
|
||||
min_depth: float,
|
||||
max_depth: float,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
min_x: The smallest x-coordinate of each ray's source pixel.
|
||||
max_x: The largest x-coordinate of each ray's source pixel.
|
||||
min_y: The smallest y-coordinate of each ray's source pixel.
|
||||
max_y: The largest y-coordinate of each ray's source pixel.
|
||||
n_rays_per_image: The number of rays randomly sampled in each camera.
|
||||
n_pts_per_ray: The number of points sampled along each ray.
|
||||
min_depth: The minimum depth of each ray-point.
|
||||
max_depth: The maximum depth of each ray-point.
|
||||
"""
|
||||
super().__init__()
|
||||
self._min_x = min_x
|
||||
self._max_x = max_x
|
||||
self._min_y = min_y
|
||||
self._max_y = max_y
|
||||
self._n_rays_per_image = n_rays_per_image
|
||||
self._n_pts_per_ray = n_pts_per_ray
|
||||
self._min_depth = min_depth
|
||||
self._max_depth = max_depth
|
||||
|
||||
def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
|
||||
"""
|
||||
Args:
|
||||
cameras: A batch of `batch_size` cameras from which the rays are emitted.
|
||||
Returns:
|
||||
A named tuple RayBundle with the following fields:
|
||||
origins: A tensor of shape
|
||||
`(batch_size, n_rays_per_image, 3)`
|
||||
denoting the locations of ray origins in the world coordinates.
|
||||
directions: A tensor of shape
|
||||
`(batch_size, n_rays_per_image, 3)`
|
||||
denoting the directions of each ray in the world coordinates.
|
||||
lengths: A tensor of shape
|
||||
`(batch_size, n_rays_per_image, n_pts_per_ray)`
|
||||
containing the z-coordinate (=depth) of each ray in world units.
|
||||
xys: A tensor of shape
|
||||
`(batch_size, n_rays_per_image, 2)`
|
||||
containing the 2D image coordinates of each ray.
|
||||
"""
|
||||
|
||||
batch_size = cameras.R.shape[0] # pyre-ignore
|
||||
|
||||
device = cameras.device
|
||||
|
||||
# get the initial grid of image xy coords
|
||||
# of shape (batch_size, n_rays_per_image, 2)
|
||||
rays_xy = torch.cat(
|
||||
[
|
||||
torch.rand(
|
||||
size=(batch_size, self._n_rays_per_image, 1),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
* (high - low)
|
||||
+ low
|
||||
for low, high in (
|
||||
(self._min_x, self._max_x),
|
||||
(self._min_y, self._max_y),
|
||||
)
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
return _xy_to_ray_bundle(
|
||||
cameras, rays_xy, self._min_depth, self._max_depth, self._n_pts_per_ray
|
||||
)
|
||||
|
||||
|
||||
def _xy_to_ray_bundle(
|
||||
cameras: CamerasBase,
|
||||
xy_grid: torch.Tensor,
|
||||
min_depth: float,
|
||||
max_depth: float,
|
||||
n_pts_per_ray: int,
|
||||
) -> RayBundle:
|
||||
"""
|
||||
Extends the `xy_grid` input of shape `(batch_size, ..., 2)` to rays.
|
||||
This adds to each xy location in the grid a vector of `n_pts_per_ray` depths
|
||||
uniformly spaced between `min_depth` and `max_depth`.
|
||||
|
||||
The extended grid is then unprojected with `cameras` to yield
|
||||
ray origins, directions and depths.
|
||||
"""
|
||||
batch_size = xy_grid.shape[0]
|
||||
spatial_size = xy_grid.shape[1:-1]
|
||||
n_rays_per_image = spatial_size.numel() # pyre-ignore
|
||||
|
||||
# ray z-coords
|
||||
depths = torch.linspace(
|
||||
min_depth, max_depth, n_pts_per_ray, dtype=xy_grid.dtype, device=xy_grid.device
|
||||
)
|
||||
rays_zs = depths[None, None].expand(batch_size, n_rays_per_image, n_pts_per_ray)
|
||||
|
||||
# make two sets of points at a constant depth=1 and 2
|
||||
to_unproject = torch.cat(
|
||||
(
|
||||
xy_grid.view(batch_size, 1, n_rays_per_image, 2)
|
||||
.expand(batch_size, 2, n_rays_per_image, 2)
|
||||
.reshape(batch_size, n_rays_per_image * 2, 2),
|
||||
torch.cat(
|
||||
(
|
||||
xy_grid.new_ones(batch_size, n_rays_per_image, 1), # pyre-ignore
|
||||
2.0 * xy_grid.new_ones(batch_size, n_rays_per_image, 1),
|
||||
),
|
||||
dim=1,
|
||||
),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# unproject the points
|
||||
unprojected = cameras.unproject_points(to_unproject) # pyre-ignore
|
||||
|
||||
# split the two planes back
|
||||
rays_plane_1_world = unprojected[:, :n_rays_per_image]
|
||||
rays_plane_2_world = unprojected[:, n_rays_per_image:]
|
||||
|
||||
# directions are the differences between the two planes of points
|
||||
rays_directions_world = rays_plane_2_world - rays_plane_1_world
|
||||
|
||||
# origins are given by subtracting the ray directions from the first plane
|
||||
rays_origins_world = rays_plane_1_world - rays_directions_world
|
||||
|
||||
return RayBundle(
|
||||
rays_origins_world.view(batch_size, *spatial_size, 3),
|
||||
rays_directions_world.view(batch_size, *spatial_size, 3),
|
||||
rays_zs.view(batch_size, *spatial_size, n_pts_per_ray),
|
||||
xy_grid,
|
||||
)
|
82
pytorch3d/renderer/implicit/utils.py
Normal file
82
pytorch3d/renderer/implicit/utils.py
Normal file
@ -0,0 +1,82 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
from typing import NamedTuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class RayBundle(NamedTuple):
|
||||
"""
|
||||
RayBundle parametrizes points along projection rays by storing ray `origins`,
|
||||
`directions` vectors and `lengths` at which the ray-points are sampled.
|
||||
Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well.
|
||||
"""
|
||||
|
||||
origins: torch.Tensor
|
||||
directions: torch.Tensor
|
||||
lengths: torch.Tensor
|
||||
xys: torch.Tensor
|
||||
|
||||
|
||||
def ray_bundle_to_ray_points(ray_bundle: RayBundle) -> torch.Tensor:
|
||||
"""
|
||||
Converts rays parametrized with a `ray_bundle` (an instance of the `RayBundle`
|
||||
named tuple) to 3D points by extending each ray according to the corresponding
|
||||
length.
|
||||
|
||||
E.g. for 2 dimensional tensors `ray_bundle.origins`, `ray_bundle.directions`
|
||||
and `ray_bundle.lengths`, the ray point at position `[i, j]` is:
|
||||
```
|
||||
ray_bundle.points[i, j, :] = (
|
||||
ray_bundle.origins[i, :]
|
||||
+ ray_bundle.directions[i, :] * ray_bundle.lengths[i, j]
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
ray_bundle: A `RayBundle` object with fields:
|
||||
origins: A tensor of shape `(..., 3)`
|
||||
directions: A tensor of shape `(..., 3)`
|
||||
lengths: A tensor of shape `(..., num_points_per_ray)`
|
||||
|
||||
Returns:
|
||||
rays_points: A tensor of shape `(..., num_points_per_ray, 3)`
|
||||
containing the points sampled along each ray.
|
||||
"""
|
||||
return ray_bundle_variables_to_ray_points(
|
||||
ray_bundle.origins, ray_bundle.directions, ray_bundle.lengths
|
||||
)
|
||||
|
||||
|
||||
def ray_bundle_variables_to_ray_points(
|
||||
rays_origins: torch.Tensor,
|
||||
rays_directions: torch.Tensor,
|
||||
rays_lengths: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Converts rays parametrized with origins, directions
|
||||
to 3D points by extending each ray according to the corresponding
|
||||
ray_length:
|
||||
|
||||
E.g. for 2 dimensional input tensors `rays_origins`, `rays_directions`
|
||||
and `rays_lengths`, the ray point at position `[i, j]` is:
|
||||
```
|
||||
rays_points[i, j, :] = (
|
||||
rays_origins[i, :]
|
||||
+ rays_directions[i, :] * rays_lengths[i, j]
|
||||
)
|
||||
```
|
||||
|
||||
Args:
|
||||
rays_origins: A tensor of shape `(..., 3)`
|
||||
rays_directions: A tensor of shape `(..., 3)`
|
||||
rays_lengths: A tensor of shape `(..., num_points_per_ray)`
|
||||
|
||||
Returns:
|
||||
rays_points: A tensor of shape `(..., num_points_per_ray, 3)`
|
||||
containing the points sampled along each ray.
|
||||
"""
|
||||
rays_points = (
|
||||
rays_origins[..., None, :]
|
||||
+ rays_lengths[..., :, None] * rays_directions[..., None, :]
|
||||
)
|
||||
return rays_points
|
39
tests/bm_raysampling.py
Normal file
39
tests/bm_raysampling.py
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import itertools
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from pytorch3d.renderer import (
|
||||
GridRaysampler,
|
||||
MonteCarloRaysampler,
|
||||
NDCGridRaysampler,
|
||||
FoVOrthographicCameras,
|
||||
FoVPerspectiveCameras,
|
||||
OrthographicCameras,
|
||||
PerspectiveCameras,
|
||||
)
|
||||
from test_raysampling import TestRaysampling
|
||||
|
||||
|
||||
def bm_raysampling() -> None:
|
||||
case_grid = {
|
||||
"raysampler_type": [GridRaysampler, NDCGridRaysampler, MonteCarloRaysampler],
|
||||
"camera_type": [
|
||||
PerspectiveCameras,
|
||||
OrthographicCameras,
|
||||
FoVPerspectiveCameras,
|
||||
FoVOrthographicCameras,
|
||||
],
|
||||
"batch_size": [1, 10],
|
||||
"n_pts_per_ray": [10, 1000, 10000],
|
||||
"image_width": [10, 300],
|
||||
"image_height": [10, 300],
|
||||
}
|
||||
test_cases = itertools.product(*case_grid.values())
|
||||
kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]
|
||||
|
||||
benchmark(TestRaysampling.raysampler, "RAYSAMPLER", kwargs_list, warmup_iters=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_raysampling()
|
423
tests/test_raysampling.py
Normal file
423
tests/test_raysampling.py
Normal file
@ -0,0 +1,423 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.ops import eyes
|
||||
from pytorch3d.renderer import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
|
||||
from pytorch3d.renderer.cameras import (
|
||||
FoVOrthographicCameras,
|
||||
FoVPerspectiveCameras,
|
||||
OrthographicCameras,
|
||||
PerspectiveCameras,
|
||||
)
|
||||
from pytorch3d.renderer.implicit.utils import (
|
||||
ray_bundle_to_ray_points,
|
||||
ray_bundle_variables_to_ray_points,
|
||||
)
|
||||
from pytorch3d.transforms import Rotate
|
||||
from test_cameras import init_random_cameras
|
||||
|
||||
|
||||
class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
torch.manual_seed(42)
|
||||
|
||||
@staticmethod
|
||||
def raysampler(
|
||||
raysampler_type=GridRaysampler,
|
||||
camera_type=PerspectiveCameras,
|
||||
n_pts_per_ray=10,
|
||||
batch_size=1,
|
||||
image_width=10,
|
||||
image_height=20,
|
||||
):
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
# init raysamplers
|
||||
raysampler = TestRaysampling.init_raysampler(
|
||||
raysampler_type=raysampler_type,
|
||||
min_x=-1.0,
|
||||
max_x=1.0,
|
||||
min_y=-1.0,
|
||||
max_y=1.0,
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
min_depth=1.0,
|
||||
max_depth=10.0,
|
||||
n_pts_per_ray=n_pts_per_ray,
|
||||
).to(device)
|
||||
|
||||
# init a batch of random cameras
|
||||
cameras = init_random_cameras(camera_type, batch_size, random_z=True).to(device)
|
||||
|
||||
def run_raysampler():
|
||||
raysampler(cameras=cameras)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return run_raysampler
|
||||
|
||||
@staticmethod
|
||||
def init_raysampler(
|
||||
raysampler_type=GridRaysampler,
|
||||
min_x=-1.0,
|
||||
max_x=1.0,
|
||||
min_y=-1.0,
|
||||
max_y=1.0,
|
||||
image_width=10,
|
||||
image_height=20,
|
||||
min_depth=1.0,
|
||||
max_depth=10.0,
|
||||
n_pts_per_ray=10,
|
||||
):
|
||||
raysampler_params = {
|
||||
"min_x": min_x,
|
||||
"max_x": max_x,
|
||||
"min_y": min_y,
|
||||
"max_y": max_y,
|
||||
"n_pts_per_ray": n_pts_per_ray,
|
||||
"min_depth": min_depth,
|
||||
"max_depth": max_depth,
|
||||
}
|
||||
|
||||
if issubclass(raysampler_type, GridRaysampler):
|
||||
raysampler_params.update(
|
||||
{"image_width": image_width, "image_height": image_height}
|
||||
)
|
||||
elif issubclass(raysampler_type, MonteCarloRaysampler):
|
||||
raysampler_params["n_rays_per_image"] = image_width * image_height
|
||||
else:
|
||||
raise ValueError(str(raysampler_type))
|
||||
|
||||
if issubclass(raysampler_type, NDCGridRaysampler):
|
||||
# NDCGridRaysampler does not use min/max_x/y
|
||||
for k in ("min_x", "max_x", "min_y", "max_y"):
|
||||
del raysampler_params[k]
|
||||
|
||||
# instantiate the raysampler
|
||||
raysampler = raysampler_type(**raysampler_params)
|
||||
|
||||
return raysampler
|
||||
|
||||
def test_raysamplers(
|
||||
self,
|
||||
batch_size=25,
|
||||
min_x=-1.0,
|
||||
max_x=1.0,
|
||||
min_y=-1.0,
|
||||
max_y=1.0,
|
||||
image_width=10,
|
||||
image_height=20,
|
||||
min_depth=1.0,
|
||||
max_depth=10.0,
|
||||
):
|
||||
"""
|
||||
Tests the shapes and outputs of MC and GridRaysamplers for randomly
|
||||
generated cameras and different numbers of points per ray.
|
||||
"""
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
for n_pts_per_ray in (100, 1):
|
||||
|
||||
for raysampler_type in (
|
||||
MonteCarloRaysampler,
|
||||
GridRaysampler,
|
||||
NDCGridRaysampler,
|
||||
):
|
||||
|
||||
raysampler = TestRaysampling.init_raysampler(
|
||||
raysampler_type=raysampler_type,
|
||||
min_x=min_x,
|
||||
max_x=max_x,
|
||||
min_y=min_y,
|
||||
max_y=max_y,
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
min_depth=min_depth,
|
||||
max_depth=max_depth,
|
||||
n_pts_per_ray=n_pts_per_ray,
|
||||
)
|
||||
|
||||
if issubclass(raysampler_type, NDCGridRaysampler):
|
||||
# adjust the gt bounds for NDCGridRaysampler
|
||||
half_pix_width = 1.0 / image_width
|
||||
half_pix_height = 1.0 / image_height
|
||||
min_x_ = 1.0 - half_pix_width
|
||||
max_x_ = -1.0 + half_pix_width
|
||||
min_y_ = 1.0 - half_pix_height
|
||||
max_y_ = -1.0 + half_pix_height
|
||||
else:
|
||||
min_x_ = min_x
|
||||
max_x_ = max_x
|
||||
min_y_ = min_y
|
||||
max_y_ = max_y
|
||||
|
||||
# carry out the test over several camera types
|
||||
for cam_type in (
|
||||
FoVPerspectiveCameras,
|
||||
FoVOrthographicCameras,
|
||||
OrthographicCameras,
|
||||
PerspectiveCameras,
|
||||
):
|
||||
|
||||
# init a batch of random cameras
|
||||
cameras = init_random_cameras(
|
||||
cam_type, batch_size, random_z=True
|
||||
).to(device)
|
||||
|
||||
# call the raysampler
|
||||
ray_bundle = raysampler(cameras=cameras)
|
||||
|
||||
# check the shapes of the raysampler outputs
|
||||
self._check_raysampler_output_shapes(
|
||||
raysampler,
|
||||
ray_bundle,
|
||||
batch_size,
|
||||
image_width,
|
||||
image_height,
|
||||
n_pts_per_ray,
|
||||
)
|
||||
|
||||
# check the points sampled along each ray
|
||||
self._check_raysampler_ray_points(
|
||||
raysampler,
|
||||
cameras,
|
||||
ray_bundle,
|
||||
min_x_,
|
||||
max_x_,
|
||||
min_y_,
|
||||
max_y_,
|
||||
image_width,
|
||||
image_height,
|
||||
min_depth,
|
||||
max_depth,
|
||||
)
|
||||
|
||||
# check the output direction vectors
|
||||
self._check_raysampler_ray_directions(
|
||||
cameras, raysampler, ray_bundle
|
||||
)
|
||||
|
||||
def _check_grid_shape(self, grid, batch_size, spatial_size, n_pts_per_ray, dim):
|
||||
"""
|
||||
A helper for checking the desired size of a variable output by a raysampler.
|
||||
"""
|
||||
tgt_shape = [
|
||||
x for x in [batch_size, *spatial_size, n_pts_per_ray, dim] if x > 0
|
||||
]
|
||||
self.assertTrue(all(sz1 == sz2 for sz1, sz2 in zip(grid.shape, tgt_shape)))
|
||||
|
||||
def _check_raysampler_output_shapes(
|
||||
self,
|
||||
raysampler,
|
||||
ray_bundle,
|
||||
batch_size,
|
||||
image_width,
|
||||
image_height,
|
||||
n_pts_per_ray,
|
||||
):
|
||||
"""
|
||||
Checks the shapes of raysampler outputs.
|
||||
"""
|
||||
|
||||
if isinstance(raysampler, GridRaysampler):
|
||||
spatial_size = [image_height, image_width]
|
||||
elif isinstance(raysampler, MonteCarloRaysampler):
|
||||
spatial_size = [image_height * image_width]
|
||||
else:
|
||||
raise ValueError(str(type(raysampler)))
|
||||
|
||||
self._check_grid_shape(ray_bundle.xys, batch_size, spatial_size, 0, 2)
|
||||
self._check_grid_shape(ray_bundle.origins, batch_size, spatial_size, 0, 3)
|
||||
self._check_grid_shape(ray_bundle.directions, batch_size, spatial_size, 0, 3)
|
||||
self._check_grid_shape(
|
||||
ray_bundle.lengths, batch_size, spatial_size, n_pts_per_ray, 0
|
||||
)
|
||||
|
||||
def _check_raysampler_ray_points(
|
||||
self,
|
||||
raysampler,
|
||||
cameras,
|
||||
ray_bundle,
|
||||
min_x,
|
||||
max_x,
|
||||
min_y,
|
||||
max_y,
|
||||
image_width,
|
||||
image_height,
|
||||
min_depth,
|
||||
max_depth,
|
||||
):
|
||||
"""
|
||||
Check rays_points_world and rays_zs outputs of raysamplers.
|
||||
"""
|
||||
|
||||
batch_size = cameras.R.shape[0]
|
||||
|
||||
# convert to ray points
|
||||
rays_points_world = ray_bundle_variables_to_ray_points(
|
||||
ray_bundle.origins, ray_bundle.directions, ray_bundle.lengths
|
||||
)
|
||||
n_pts_per_ray = rays_points_world.shape[-2]
|
||||
|
||||
# check that the outputs if ray_bundle_variables_to_ray_points and
|
||||
# ray_bundle_to_ray_points match
|
||||
rays_points_world_ = ray_bundle_to_ray_points(ray_bundle)
|
||||
self.assertClose(rays_points_world, rays_points_world_)
|
||||
|
||||
# check that the depth of each ray point in camera coords
|
||||
# matches the expected linearly-spaced depth
|
||||
depth_expected = torch.linspace(
|
||||
min_depth,
|
||||
max_depth,
|
||||
n_pts_per_ray,
|
||||
dtype=torch.float32,
|
||||
device=rays_points_world.device,
|
||||
)
|
||||
ray_points_camera = (
|
||||
cameras.get_world_to_view_transform()
|
||||
.transform_points(rays_points_world.view(batch_size, -1, 3))
|
||||
.view(batch_size, -1, n_pts_per_ray, 3)
|
||||
)
|
||||
self.assertClose(
|
||||
ray_points_camera[..., 2],
|
||||
depth_expected[None, None, :].expand_as(ray_points_camera[..., 2]),
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
# check also that rays_zs is consistent with depth_expected
|
||||
self.assertClose(
|
||||
ray_bundle.lengths.view(batch_size, -1, n_pts_per_ray),
|
||||
depth_expected[None, None, :].expand_as(ray_points_camera[..., 2]),
|
||||
atol=1e-6,
|
||||
)
|
||||
|
||||
# project the world ray points back to screen space
|
||||
ray_points_projected = cameras.transform_points(
|
||||
rays_points_world.view(batch_size, -1, 3)
|
||||
).view(rays_points_world.shape)
|
||||
|
||||
# check that ray_xy matches rays_points_projected xy
|
||||
rays_xy_projected = ray_points_projected[..., :2].view(
|
||||
batch_size, -1, n_pts_per_ray, 2
|
||||
)
|
||||
self.assertClose(
|
||||
ray_bundle.xys.view(batch_size, -1, 1, 2).expand_as(rays_xy_projected),
|
||||
rays_xy_projected,
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
# check that projected world points' xy coordinates
|
||||
# range correctly between [minx/y, max/y]
|
||||
if isinstance(raysampler, GridRaysampler):
|
||||
# get the expected coordinates along each grid axis
|
||||
ys, xs = [
|
||||
torch.linspace(
|
||||
low, high, sz, dtype=torch.float32, device=rays_points_world.device
|
||||
)
|
||||
for low, high, sz in (
|
||||
(min_y, max_y, image_height),
|
||||
(min_x, max_x, image_width),
|
||||
)
|
||||
]
|
||||
# compare expected xy with the output xy
|
||||
for dim, gt_axis in zip(
|
||||
(0, 1), (xs[None, None, :, None], ys[None, :, None, None])
|
||||
):
|
||||
self.assertClose(
|
||||
ray_points_projected[..., dim],
|
||||
gt_axis.expand_as(ray_points_projected[..., dim]),
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
elif isinstance(raysampler, MonteCarloRaysampler):
|
||||
# check that the randomly sampled locations
|
||||
# are within the allowed bounds for both x and y axes
|
||||
for dim, axis_bounds in zip((0, 1), ((min_x, max_x), (min_y, max_y))):
|
||||
self.assertTrue(
|
||||
(
|
||||
(ray_points_projected[..., dim] <= axis_bounds[1])
|
||||
& (ray_points_projected[..., dim] >= axis_bounds[0])
|
||||
).all()
|
||||
)
|
||||
|
||||
# also check that x,y along each ray is constant
|
||||
if n_pts_per_ray > 1:
|
||||
self.assertClose(
|
||||
ray_points_projected[..., :2].std(dim=-2),
|
||||
torch.zeros_like(ray_points_projected[..., 0, :2]),
|
||||
atol=1e-5,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(str(type(raysampler)))
|
||||
|
||||
def _check_raysampler_ray_directions(self, cameras, raysampler, ray_bundle):
|
||||
"""
|
||||
Check the rays_directions_world output of raysamplers.
|
||||
"""
|
||||
|
||||
batch_size = cameras.R.shape[0]
|
||||
n_pts_per_ray = ray_bundle.lengths.shape[-1]
|
||||
spatial_size = ray_bundle.xys.shape[1:-1]
|
||||
n_rays_per_image = spatial_size.numel()
|
||||
|
||||
# obtain the ray points in world coords
|
||||
rays_points_world = cameras.unproject_points(
|
||||
torch.cat(
|
||||
(
|
||||
ray_bundle.xys.view(batch_size, n_rays_per_image, 1, 2).expand(
|
||||
batch_size, n_rays_per_image, n_pts_per_ray, 2
|
||||
),
|
||||
ray_bundle.lengths.view(
|
||||
batch_size, n_rays_per_image, n_pts_per_ray, 1
|
||||
),
|
||||
),
|
||||
dim=-1,
|
||||
).view(batch_size, -1, 3)
|
||||
).view(batch_size, -1, n_pts_per_ray, 3)
|
||||
|
||||
# reshape to common testing size
|
||||
rays_directions_world_normed = torch.nn.functional.normalize(
|
||||
ray_bundle.directions.view(batch_size, -1, 3), dim=-1
|
||||
)
|
||||
|
||||
# check that the l2-normed difference of all consecutive planes
|
||||
# of points in world coords matches ray_directions_world
|
||||
rays_directions_world_ = torch.nn.functional.normalize(
|
||||
rays_points_world[:, :, -1:] - rays_points_world[:, :, :-1], dim=-1
|
||||
)
|
||||
self.assertClose(
|
||||
rays_directions_world_normed[:, :, None].expand_as(rays_directions_world_),
|
||||
rays_directions_world_,
|
||||
atol=1e-4,
|
||||
)
|
||||
|
||||
# check the ray directions rotated using camera rotation matrix
|
||||
# match the ray directions of a camera with trivial extrinsics
|
||||
cameras_trivial_extrinsic = cameras.clone()
|
||||
cameras_trivial_extrinsic.R = eyes(
|
||||
N=batch_size, dim=3, dtype=cameras.R.dtype, device=cameras.device
|
||||
)
|
||||
cameras_trivial_extrinsic.T = torch.zeros_like(cameras.T)
|
||||
|
||||
# make sure we get the same random rays in case we call the
|
||||
# MonteCarloRaysampler twice below
|
||||
with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
|
||||
torch.random.manual_seed(42)
|
||||
ray_bundle_world_fix_seed = raysampler(cameras=cameras)
|
||||
torch.random.manual_seed(42)
|
||||
ray_bundle_camera_fix_seed = raysampler(cameras=cameras_trivial_extrinsic)
|
||||
|
||||
rays_directions_camera_fix_seed_ = Rotate(
|
||||
cameras.R, device=cameras.R.device
|
||||
).transform_points(ray_bundle_world_fix_seed.directions.view(batch_size, -1, 3))
|
||||
|
||||
self.assertClose(
|
||||
rays_directions_camera_fix_seed_,
|
||||
ray_bundle_camera_fix_seed.directions.view(batch_size, -1, 3),
|
||||
atol=1e-5,
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user