mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00

commit e6bc960fb5 (parent 1f9cf91e1b)

Raysampling

Summary: Implements 3 basic raysamplers.

Reviewed By: nikhilaravi

Differential Revision: D24110643

fbshipit-source-id: eb67d0e56773c7871ebdcb23e7e520302dc1b3c9
@@ -20,7 +20,16 @@ from .cameras import (
     look_at_rotation,
     look_at_view_transform,
 )
-from .implicit import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
+from .implicit import (
+    AbsorptionOnlyRaymarcher,
+    EmissionAbsorptionRaymarcher,
+    GridRaysampler,
+    MonteCarloRaysampler,
+    NDCGridRaysampler,
+    RayBundle,
+    ray_bundle_to_ray_points,
+    ray_bundle_variables_to_ray_points,
+)
 from .lighting import DirectionalLights, PointLights, diffuse, specular
 from .materials import Materials
 from .mesh import (
@@ -1,6 +1,12 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

 from .raymarching import AbsorptionOnlyRaymarcher, EmissionAbsorptionRaymarcher
+from .raysampling import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
+from .utils import (
+    RayBundle,
+    ray_bundle_to_ray_points,
+    ray_bundle_variables_to_ray_points,
+)


 __all__ = [k for k in globals().keys() if not k.startswith("_")]
pytorch3d/renderer/implicit/raysampling.py (new file, 320 lines)
@@ -0,0 +1,320 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch

from ..cameras import CamerasBase
from .utils import RayBundle


"""
This file defines three raysampling techniques:
    - GridRaysampler which can be used to sample rays from pixels of an image grid
    - NDCGridRaysampler which can be used to sample rays from pixels of an image grid,
        which follows the pytorch3d convention for image grid coordinates
    - MonteCarloRaysampler which randomly selects image pixels and emits rays from them
"""


class GridRaysampler(torch.nn.Module):
    """
    Samples a fixed number of points along rays which are regularly distributed
    in a batch of rectangular image grids. Points along each ray
    have uniformly-spaced z-coordinates between a predefined
    minimum and maximum depth.

    The raysampler first generates a 3D coordinate grid of the following form:
    ```
       / min_x, min_y, max_depth -------------- / max_x, min_y, max_depth
      /                                        /|
     /                                        / |     ^
    / min_depth                    min_depth /  |     |
    min_x ----------------------------- max_x   |     | image
    min_y                               min_y   |     | height
    |                                       |   |     |
    |                                       |   |     v
    |                                       |   |
    |                                       |   / max_x, max_y,     ^
    |                                       |  / max_depth         /
    min_x                               max_y /                   / n_pts_per_ray
    max_y ----------------------------- max_x/ min_depth         v
              < --- image_width --- >
    ```

    In order to generate ray points, `GridRaysampler` takes each 3D point of
    the grid (with coordinates `[x, y, depth]`) and unprojects it
    with `cameras.unproject_points([x, y, depth])`, where `cameras` are an
    additional input to the `forward` function.

    Note that this is a generic implementation that can support any image grid
    coordinate convention. For a raysampler which follows the PyTorch3D
    coordinate conventions please refer to `NDCGridRaysampler`.
    As such, `NDCGridRaysampler` is a special case of `GridRaysampler`.
    """

    def __init__(
        self,
        min_x: float,
        max_x: float,
        min_y: float,
        max_y: float,
        image_width: int,
        image_height: int,
        n_pts_per_ray: int,
        min_depth: float,
        max_depth: float,
    ):
        """
        Args:
            min_x: The leftmost x-coordinate of each ray's source pixel's center.
            max_x: The rightmost x-coordinate of each ray's source pixel's center.
            min_y: The topmost y-coordinate of each ray's source pixel's center.
            max_y: The bottommost y-coordinate of each ray's source pixel's center.
            image_width: The horizontal size of the image grid.
            image_height: The vertical size of the image grid.
            n_pts_per_ray: The number of points sampled along each ray.
            min_depth: The minimum depth of a ray-point.
            max_depth: The maximum depth of a ray-point.
        """
        super().__init__()
        self._n_pts_per_ray = n_pts_per_ray
        self._min_depth = min_depth
        self._max_depth = max_depth

        # get the initial grid of image xy coords
        _xy_grid = torch.stack(
            tuple(
                reversed(
                    torch.meshgrid(
                        torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
                        torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
                    )
                )
            ),
            dim=-1,
        )
        self.register_buffer("_xy_grid", _xy_grid)

    def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
        """
        Args:
            cameras: A batch of `batch_size` cameras from which the rays are emitted.
        Returns:
            A named tuple RayBundle with the following fields:
            origins: A tensor of shape
                `(batch_size, image_height, image_width, 3)`
                denoting the locations of ray origins in the world coordinates.
            directions: A tensor of shape
                `(batch_size, image_height, image_width, 3)`
                denoting the directions of each ray in the world coordinates.
            lengths: A tensor of shape
                `(batch_size, image_height, image_width, n_pts_per_ray)`
                containing the z-coordinate (=depth) of each ray in world units.
            xys: A tensor of shape
                `(batch_size, image_height, image_width, 2)`
                containing the 2D image coordinates of each ray.
        """

        batch_size = cameras.R.shape[0]  # pyre-ignore

        device = cameras.device

        # expand the (H, W, 2) grid batch_size-times to (B, H, W, 2)
        xy_grid = self._xy_grid.to(device)[None].expand(  # pyre-ignore
            batch_size, *self._xy_grid.shape
        )

        return _xy_to_ray_bundle(
            cameras, xy_grid, self._min_depth, self._max_depth, self._n_pts_per_ray
        )

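As a quick illustration of how the sampler above can be driven (a minimal sketch, not part of the commit: the `FoVPerspectiveCameras` setup and the concrete sizes are assumed, and the xy bounds mimic the PyTorch3D screen convention described for `NDCGridRaysampler` below), the output shapes follow the `forward` docstring:

```python
import torch
from pytorch3d.renderer import FoVPerspectiveCameras, GridRaysampler
from pytorch3d.renderer import ray_bundle_to_ray_points

# two cameras with identity rotation and zero translation (assumed for the example)
cameras = FoVPerspectiveCameras(R=torch.eye(3).repeat(2, 1, 1), T=torch.zeros(2, 3))

# one ray per pixel of a 64 x 48 grid, 16 depths per ray between 1.0 and 5.0
raysampler = GridRaysampler(
    min_x=1.0, max_x=-1.0, min_y=1.0, max_y=-1.0,
    image_width=64, image_height=48,
    n_pts_per_ray=16, min_depth=1.0, max_depth=5.0,
)

ray_bundle = raysampler(cameras=cameras)
print(ray_bundle.origins.shape)     # (2, 48, 64, 3)
print(ray_bundle.directions.shape)  # (2, 48, 64, 3)
print(ray_bundle.lengths.shape)     # (2, 48, 64, 16)

# extend each ray by its lengths to obtain world-space points
pts = ray_bundle_to_ray_points(ray_bundle)  # (2, 48, 64, 16, 3)
```
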
class NDCGridRaysampler(GridRaysampler):
    """
    Samples a fixed number of points along rays which are regularly distributed
    in a batch of rectangular image grids. Points along each ray
    have uniformly-spaced z-coordinates between a predefined minimum and maximum depth.

    `NDCGridRaysampler` follows the screen conventions of the `Meshes` and `Pointclouds`
    renderers. I.e. the border of the leftmost / rightmost / topmost / bottommost pixel
    has coordinates 1.0 / -1.0 / 1.0 / -1.0 respectively.
    """

    def __init__(
        self,
        image_width: int,
        image_height: int,
        n_pts_per_ray: int,
        min_depth: float,
        max_depth: float,
    ):
        """
        Args:
            image_width: The horizontal size of the image grid.
            image_height: The vertical size of the image grid.
            n_pts_per_ray: The number of points sampled along each ray.
            min_depth: The minimum depth of a ray-point.
            max_depth: The maximum depth of a ray-point.
        """
        half_pix_width = 1.0 / image_width
        half_pix_height = 1.0 / image_height
        super().__init__(
            min_x=1.0 - half_pix_width,
            max_x=-1.0 + half_pix_width,
            min_y=1.0 - half_pix_height,
            max_y=-1.0 + half_pix_height,
            image_width=image_width,
            image_height=image_height,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=min_depth,
            max_depth=max_depth,
        )

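To make the half-pixel offset concrete (a small check, not part of the commit; it peeks at the internal `_xy_grid` buffer purely for illustration), a 4 x 2 image gives x-centers running from `1 - 1/4 = 0.75` down to `-0.75` and y-centers from `0.5` down to `-0.5`:

```python
from pytorch3d.renderer import NDCGridRaysampler

raysampler = NDCGridRaysampler(
    image_width=4, image_height=2, n_pts_per_ray=8, min_depth=1.0, max_depth=2.0
)

xy = raysampler._xy_grid  # registered buffer of per-pixel xy centers, shape (2, 4, 2)
print(xy[0, :, 0])  # tensor([ 0.7500,  0.2500, -0.2500, -0.7500])  x along the width
print(xy[:, 0, 1])  # tensor([ 0.5000, -0.5000])                    y along the height
```
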
class MonteCarloRaysampler(torch.nn.Module):
    """
    Samples a fixed number of pixels within denoted xy bounds uniformly at random.
    For each pixel, a fixed number of points is sampled along its ray at uniformly-spaced
    z-coordinates such that the z-coordinates range between a predefined minimum
    and maximum depth.
    """

    def __init__(
        self,
        min_x: float,
        max_x: float,
        min_y: float,
        max_y: float,
        n_rays_per_image: int,
        n_pts_per_ray: int,
        min_depth: float,
        max_depth: float,
    ):
        """
        Args:
            min_x: The smallest x-coordinate of each ray's source pixel.
            max_x: The largest x-coordinate of each ray's source pixel.
            min_y: The smallest y-coordinate of each ray's source pixel.
            max_y: The largest y-coordinate of each ray's source pixel.
            n_rays_per_image: The number of rays randomly sampled in each camera.
            n_pts_per_ray: The number of points sampled along each ray.
            min_depth: The minimum depth of each ray-point.
            max_depth: The maximum depth of each ray-point.
        """
        super().__init__()
        self._min_x = min_x
        self._max_x = max_x
        self._min_y = min_y
        self._max_y = max_y
        self._n_rays_per_image = n_rays_per_image
        self._n_pts_per_ray = n_pts_per_ray
        self._min_depth = min_depth
        self._max_depth = max_depth

    def forward(self, cameras: CamerasBase, **kwargs) -> RayBundle:
        """
        Args:
            cameras: A batch of `batch_size` cameras from which the rays are emitted.
        Returns:
            A named tuple RayBundle with the following fields:
            origins: A tensor of shape
                `(batch_size, n_rays_per_image, 3)`
                denoting the locations of ray origins in the world coordinates.
            directions: A tensor of shape
                `(batch_size, n_rays_per_image, 3)`
                denoting the directions of each ray in the world coordinates.
            lengths: A tensor of shape
                `(batch_size, n_rays_per_image, n_pts_per_ray)`
                containing the z-coordinate (=depth) of each ray in world units.
            xys: A tensor of shape
                `(batch_size, n_rays_per_image, 2)`
                containing the 2D image coordinates of each ray.
        """

        batch_size = cameras.R.shape[0]  # pyre-ignore

        device = cameras.device

        # get the initial grid of image xy coords
        # of shape (batch_size, n_rays_per_image, 2)
        rays_xy = torch.cat(
            [
                torch.rand(
                    size=(batch_size, self._n_rays_per_image, 1),
                    dtype=torch.float32,
                    device=device,
                )
                * (high - low)
                + low
                for low, high in (
                    (self._min_x, self._max_x),
                    (self._min_y, self._max_y),
                )
            ],
            dim=2,
        )

        return _xy_to_ray_bundle(
            cameras, rays_xy, self._min_depth, self._max_depth, self._n_pts_per_ray
        )

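For comparison with the grid samplers, a similar sketch for `MonteCarloRaysampler` (again not part of the commit; the camera setup is assumed as in the earlier example). Its outputs are flat over rays rather than laid out on an image grid:

```python
import torch
from pytorch3d.renderer import FoVPerspectiveCameras, MonteCarloRaysampler

cameras = FoVPerspectiveCameras(R=torch.eye(3).repeat(2, 1, 1), T=torch.zeros(2, 3))

# 1024 random pixels per camera, 16 depths per ray between 1.0 and 5.0
raysampler = MonteCarloRaysampler(
    min_x=-1.0, max_x=1.0, min_y=-1.0, max_y=1.0,
    n_rays_per_image=1024, n_pts_per_ray=16, min_depth=1.0, max_depth=5.0,
)

ray_bundle = raysampler(cameras=cameras)
print(ray_bundle.origins.shape)  # (2, 1024, 3)
print(ray_bundle.lengths.shape)  # (2, 1024, 16)
print(ray_bundle.xys.shape)      # (2, 1024, 2)
```
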
def _xy_to_ray_bundle(
    cameras: CamerasBase,
    xy_grid: torch.Tensor,
    min_depth: float,
    max_depth: float,
    n_pts_per_ray: int,
) -> RayBundle:
    """
    Extends the `xy_grid` input of shape `(batch_size, ..., 2)` to rays.
    This adds to each xy location in the grid a vector of `n_pts_per_ray` depths
    uniformly spaced between `min_depth` and `max_depth`.

    The extended grid is then unprojected with `cameras` to yield
    ray origins, directions and depths.
    """
    batch_size = xy_grid.shape[0]
    spatial_size = xy_grid.shape[1:-1]
    n_rays_per_image = spatial_size.numel()  # pyre-ignore

    # ray z-coords
    depths = torch.linspace(
        min_depth, max_depth, n_pts_per_ray, dtype=xy_grid.dtype, device=xy_grid.device
    )
    rays_zs = depths[None, None].expand(batch_size, n_rays_per_image, n_pts_per_ray)

    # make two sets of points at a constant depth=1 and 2
    to_unproject = torch.cat(
        (
            xy_grid.view(batch_size, 1, n_rays_per_image, 2)
            .expand(batch_size, 2, n_rays_per_image, 2)
            .reshape(batch_size, n_rays_per_image * 2, 2),
            torch.cat(
                (
                    xy_grid.new_ones(batch_size, n_rays_per_image, 1),  # pyre-ignore
                    2.0 * xy_grid.new_ones(batch_size, n_rays_per_image, 1),
                ),
                dim=1,
            ),
        ),
        dim=-1,
    )

    # unproject the points
    unprojected = cameras.unproject_points(to_unproject)  # pyre-ignore

    # split the two planes back
    rays_plane_1_world = unprojected[:, :n_rays_per_image]
    rays_plane_2_world = unprojected[:, n_rays_per_image:]

    # directions are the differences between the two planes of points
    rays_directions_world = rays_plane_2_world - rays_plane_1_world

    # origins are given by subtracting the ray directions from the first plane
    rays_origins_world = rays_plane_1_world - rays_directions_world

    return RayBundle(
        rays_origins_world.view(batch_size, *spatial_size, 3),
        rays_directions_world.view(batch_size, *spatial_size, 3),
        rays_zs.view(batch_size, *spatial_size, n_pts_per_ray),
        xy_grid,
    )

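The two-plane construction above deserves a short gloss: for every xy location the function unprojects the same pixel at depths 1 and 2, takes `direction = p2 - p1` and `origin = p1 - direction`, i.e. the extrapolation of the ray to depth 0. For the PyTorch3D cameras used here, unprojection at a fixed xy is affine in depth, so `origin + depth * direction` reproduces the unprojection at any depth, which is what makes the returned `lengths` usable as depths. A small numerical check of that identity (a sketch, not part of the commit; a single default `FoVPerspectiveCameras` and an arbitrary pixel are assumed):

```python
import torch
from pytorch3d.renderer import FoVPerspectiveCameras

cameras = FoVPerspectiveCameras()      # a single default camera
xy = torch.tensor([[[0.25, -0.5]]])    # one pixel location, shape (1, 1, 2)

def unproject_at(depth):
    # unproject the pixel at the given depth, as _xy_to_ray_bundle does for depths 1 and 2
    xy_depth = torch.cat([xy, torch.full((1, 1, 1), depth)], dim=-1)
    return cameras.unproject_points(xy_depth)

p1, p2 = unproject_at(1.0), unproject_at(2.0)
direction = p2 - p1
origin = p1 - direction  # the ray extrapolated back to depth 0

# the ray parametrization matches a direct unprojection at an arbitrary depth
print(torch.allclose(origin + 7.5 * direction, unproject_at(7.5), atol=1e-4))  # True
```
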
pytorch3d/renderer/implicit/utils.py (new file, 82 lines)
@@ -0,0 +1,82 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import NamedTuple

import torch


class RayBundle(NamedTuple):
    """
    RayBundle parametrizes points along projection rays by storing ray `origins`,
    `directions` vectors and `lengths` at which the ray-points are sampled.
    Furthermore, the xy-locations (`xys`) of the ray pixels are stored as well.
    """

    origins: torch.Tensor
    directions: torch.Tensor
    lengths: torch.Tensor
    xys: torch.Tensor


def ray_bundle_to_ray_points(ray_bundle: RayBundle) -> torch.Tensor:
    """
    Converts rays parametrized with a `ray_bundle` (an instance of the `RayBundle`
    named tuple) to 3D points by extending each ray according to the corresponding
    length.

    E.g. for 2 dimensional tensors `ray_bundle.origins`, `ray_bundle.directions`
    and `ray_bundle.lengths`, the ray point at position `[i, j]` is:
    ```
        ray_bundle.points[i, j, :] = (
            ray_bundle.origins[i, :]
            + ray_bundle.directions[i, :] * ray_bundle.lengths[i, j]
        )
    ```

    Args:
        ray_bundle: A `RayBundle` object with fields:
            origins: A tensor of shape `(..., 3)`
            directions: A tensor of shape `(..., 3)`
            lengths: A tensor of shape `(..., num_points_per_ray)`

    Returns:
        rays_points: A tensor of shape `(..., num_points_per_ray, 3)`
            containing the points sampled along each ray.
    """
    return ray_bundle_variables_to_ray_points(
        ray_bundle.origins, ray_bundle.directions, ray_bundle.lengths
    )


def ray_bundle_variables_to_ray_points(
    rays_origins: torch.Tensor,
    rays_directions: torch.Tensor,
    rays_lengths: torch.Tensor,
) -> torch.Tensor:
    """
    Converts rays parametrized with origins, directions
    to 3D points by extending each ray according to the corresponding
    ray_length:

    E.g. for 2 dimensional input tensors `rays_origins`, `rays_directions`
    and `rays_lengths`, the ray point at position `[i, j]` is:
    ```
        rays_points[i, j, :] = (
            rays_origins[i, :]
            + rays_directions[i, :] * rays_lengths[i, j]
        )
    ```

    Args:
        rays_origins: A tensor of shape `(..., 3)`
        rays_directions: A tensor of shape `(..., 3)`
        rays_lengths: A tensor of shape `(..., num_points_per_ray)`

    Returns:
        rays_points: A tensor of shape `(..., num_points_per_ray, 3)`
            containing the points sampled along each ray.
    """
    rays_points = (
        rays_origins[..., None, :]
        + rays_lengths[..., :, None] * rays_directions[..., None, :]
    )
    return rays_points

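The broadcasting in `ray_bundle_variables_to_ray_points` is easiest to see with a single ray (a tiny check, not part of the commit): a ray from the origin along +z with lengths `[1, 2, 3]` yields the points `(0, 0, 1)`, `(0, 0, 2)`, `(0, 0, 3)`.

```python
import torch
from pytorch3d.renderer import ray_bundle_variables_to_ray_points

origins = torch.zeros(1, 3)                   # one ray starting at the world origin
directions = torch.tensor([[0.0, 0.0, 1.0]])  # pointing along +z
lengths = torch.tensor([[1.0, 2.0, 3.0]])     # three samples along the ray

points = ray_bundle_variables_to_ray_points(origins, directions, lengths)
print(points.shape)  # torch.Size([1, 3, 3])
print(points[0])
# tensor([[0., 0., 1.],
#         [0., 0., 2.],
#         [0., 0., 3.]])
```
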
tests/bm_raysampling.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import itertools

from fvcore.common.benchmark import benchmark
from pytorch3d.renderer import (
    GridRaysampler,
    MonteCarloRaysampler,
    NDCGridRaysampler,
    FoVOrthographicCameras,
    FoVPerspectiveCameras,
    OrthographicCameras,
    PerspectiveCameras,
)
from test_raysampling import TestRaysampling


def bm_raysampling() -> None:
    case_grid = {
        "raysampler_type": [GridRaysampler, NDCGridRaysampler, MonteCarloRaysampler],
        "camera_type": [
            PerspectiveCameras,
            OrthographicCameras,
            FoVPerspectiveCameras,
            FoVOrthographicCameras,
        ],
        "batch_size": [1, 10],
        "n_pts_per_ray": [10, 1000, 10000],
        "image_width": [10, 300],
        "image_height": [10, 300],
    }
    test_cases = itertools.product(*case_grid.values())
    kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]

    benchmark(TestRaysampling.raysampler, "RAYSAMPLER", kwargs_list, warmup_iters=1)


if __name__ == "__main__":
    bm_raysampling()

tests/test_raysampling.py (new file, 423 lines)
@@ -0,0 +1,423 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import unittest

import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import eyes
from pytorch3d.renderer import GridRaysampler, MonteCarloRaysampler, NDCGridRaysampler
from pytorch3d.renderer.cameras import (
    FoVOrthographicCameras,
    FoVPerspectiveCameras,
    OrthographicCameras,
    PerspectiveCameras,
)
from pytorch3d.renderer.implicit.utils import (
    ray_bundle_to_ray_points,
    ray_bundle_variables_to_ray_points,
)
from pytorch3d.transforms import Rotate
from test_cameras import init_random_cameras


class TestRaysampling(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        torch.manual_seed(42)

    @staticmethod
    def raysampler(
        raysampler_type=GridRaysampler,
        camera_type=PerspectiveCameras,
        n_pts_per_ray=10,
        batch_size=1,
        image_width=10,
        image_height=20,
    ):

        device = torch.device("cuda")

        # init raysamplers
        raysampler = TestRaysampling.init_raysampler(
            raysampler_type=raysampler_type,
            min_x=-1.0,
            max_x=1.0,
            min_y=-1.0,
            max_y=1.0,
            image_width=image_width,
            image_height=image_height,
            min_depth=1.0,
            max_depth=10.0,
            n_pts_per_ray=n_pts_per_ray,
        ).to(device)

        # init a batch of random cameras
        cameras = init_random_cameras(camera_type, batch_size, random_z=True).to(device)

        def run_raysampler():
            raysampler(cameras=cameras)
            torch.cuda.synchronize()

        return run_raysampler

    @staticmethod
    def init_raysampler(
        raysampler_type=GridRaysampler,
        min_x=-1.0,
        max_x=1.0,
        min_y=-1.0,
        max_y=1.0,
        image_width=10,
        image_height=20,
        min_depth=1.0,
        max_depth=10.0,
        n_pts_per_ray=10,
    ):
        raysampler_params = {
            "min_x": min_x,
            "max_x": max_x,
            "min_y": min_y,
            "max_y": max_y,
            "n_pts_per_ray": n_pts_per_ray,
            "min_depth": min_depth,
            "max_depth": max_depth,
        }

        if issubclass(raysampler_type, GridRaysampler):
            raysampler_params.update(
                {"image_width": image_width, "image_height": image_height}
            )
        elif issubclass(raysampler_type, MonteCarloRaysampler):
            raysampler_params["n_rays_per_image"] = image_width * image_height
        else:
            raise ValueError(str(raysampler_type))

        if issubclass(raysampler_type, NDCGridRaysampler):
            # NDCGridRaysampler does not use min/max_x/y
            for k in ("min_x", "max_x", "min_y", "max_y"):
                del raysampler_params[k]

        # instantiate the raysampler
        raysampler = raysampler_type(**raysampler_params)

        return raysampler

    def test_raysamplers(
        self,
        batch_size=25,
        min_x=-1.0,
        max_x=1.0,
        min_y=-1.0,
        max_y=1.0,
        image_width=10,
        image_height=20,
        min_depth=1.0,
        max_depth=10.0,
    ):
        """
        Tests the shapes and outputs of MC and GridRaysamplers for randomly
        generated cameras and different numbers of points per ray.
        """

        device = torch.device("cuda")

        for n_pts_per_ray in (100, 1):

            for raysampler_type in (
                MonteCarloRaysampler,
                GridRaysampler,
                NDCGridRaysampler,
            ):

                raysampler = TestRaysampling.init_raysampler(
                    raysampler_type=raysampler_type,
                    min_x=min_x,
                    max_x=max_x,
                    min_y=min_y,
                    max_y=max_y,
                    image_width=image_width,
                    image_height=image_height,
                    min_depth=min_depth,
                    max_depth=max_depth,
                    n_pts_per_ray=n_pts_per_ray,
                )

                if issubclass(raysampler_type, NDCGridRaysampler):
                    # adjust the gt bounds for NDCGridRaysampler
                    half_pix_width = 1.0 / image_width
                    half_pix_height = 1.0 / image_height
                    min_x_ = 1.0 - half_pix_width
                    max_x_ = -1.0 + half_pix_width
                    min_y_ = 1.0 - half_pix_height
                    max_y_ = -1.0 + half_pix_height
                else:
                    min_x_ = min_x
                    max_x_ = max_x
                    min_y_ = min_y
                    max_y_ = max_y

                # carry out the test over several camera types
                for cam_type in (
                    FoVPerspectiveCameras,
                    FoVOrthographicCameras,
                    OrthographicCameras,
                    PerspectiveCameras,
                ):

                    # init a batch of random cameras
                    cameras = init_random_cameras(
                        cam_type, batch_size, random_z=True
                    ).to(device)

                    # call the raysampler
                    ray_bundle = raysampler(cameras=cameras)

                    # check the shapes of the raysampler outputs
                    self._check_raysampler_output_shapes(
                        raysampler,
                        ray_bundle,
                        batch_size,
                        image_width,
                        image_height,
                        n_pts_per_ray,
                    )

                    # check the points sampled along each ray
                    self._check_raysampler_ray_points(
                        raysampler,
                        cameras,
                        ray_bundle,
                        min_x_,
                        max_x_,
                        min_y_,
                        max_y_,
                        image_width,
                        image_height,
                        min_depth,
                        max_depth,
                    )

                    # check the output direction vectors
                    self._check_raysampler_ray_directions(
                        cameras, raysampler, ray_bundle
                    )

    def _check_grid_shape(self, grid, batch_size, spatial_size, n_pts_per_ray, dim):
        """
        A helper for checking the desired size of a variable output by a raysampler.
        """
        tgt_shape = [
            x for x in [batch_size, *spatial_size, n_pts_per_ray, dim] if x > 0
        ]
        self.assertTrue(all(sz1 == sz2 for sz1, sz2 in zip(grid.shape, tgt_shape)))

    def _check_raysampler_output_shapes(
        self,
        raysampler,
        ray_bundle,
        batch_size,
        image_width,
        image_height,
        n_pts_per_ray,
    ):
        """
        Checks the shapes of raysampler outputs.
        """

        if isinstance(raysampler, GridRaysampler):
            spatial_size = [image_height, image_width]
        elif isinstance(raysampler, MonteCarloRaysampler):
            spatial_size = [image_height * image_width]
        else:
            raise ValueError(str(type(raysampler)))

        self._check_grid_shape(ray_bundle.xys, batch_size, spatial_size, 0, 2)
        self._check_grid_shape(ray_bundle.origins, batch_size, spatial_size, 0, 3)
        self._check_grid_shape(ray_bundle.directions, batch_size, spatial_size, 0, 3)
        self._check_grid_shape(
            ray_bundle.lengths, batch_size, spatial_size, n_pts_per_ray, 0
        )

    def _check_raysampler_ray_points(
        self,
        raysampler,
        cameras,
        ray_bundle,
        min_x,
        max_x,
        min_y,
        max_y,
        image_width,
        image_height,
        min_depth,
        max_depth,
    ):
        """
        Check rays_points_world and rays_zs outputs of raysamplers.
        """

        batch_size = cameras.R.shape[0]

        # convert to ray points
        rays_points_world = ray_bundle_variables_to_ray_points(
            ray_bundle.origins, ray_bundle.directions, ray_bundle.lengths
        )
        n_pts_per_ray = rays_points_world.shape[-2]

        # check that the outputs of ray_bundle_variables_to_ray_points and
        # ray_bundle_to_ray_points match
        rays_points_world_ = ray_bundle_to_ray_points(ray_bundle)
        self.assertClose(rays_points_world, rays_points_world_)

        # check that the depth of each ray point in camera coords
        # matches the expected linearly-spaced depth
        depth_expected = torch.linspace(
            min_depth,
            max_depth,
            n_pts_per_ray,
            dtype=torch.float32,
            device=rays_points_world.device,
        )
        ray_points_camera = (
            cameras.get_world_to_view_transform()
            .transform_points(rays_points_world.view(batch_size, -1, 3))
            .view(batch_size, -1, n_pts_per_ray, 3)
        )
        self.assertClose(
            ray_points_camera[..., 2],
            depth_expected[None, None, :].expand_as(ray_points_camera[..., 2]),
            atol=1e-4,
        )

        # check also that rays_zs is consistent with depth_expected
        self.assertClose(
            ray_bundle.lengths.view(batch_size, -1, n_pts_per_ray),
            depth_expected[None, None, :].expand_as(ray_points_camera[..., 2]),
            atol=1e-6,
        )

        # project the world ray points back to screen space
        ray_points_projected = cameras.transform_points(
            rays_points_world.view(batch_size, -1, 3)
        ).view(rays_points_world.shape)

        # check that ray_xy matches rays_points_projected xy
        rays_xy_projected = ray_points_projected[..., :2].view(
            batch_size, -1, n_pts_per_ray, 2
        )
        self.assertClose(
            ray_bundle.xys.view(batch_size, -1, 1, 2).expand_as(rays_xy_projected),
            rays_xy_projected,
            atol=1e-4,
        )

        # check that projected world points' xy coordinates
        # range correctly between [min_x/y, max_x/y]
        if isinstance(raysampler, GridRaysampler):
            # get the expected coordinates along each grid axis
            ys, xs = [
                torch.linspace(
                    low, high, sz, dtype=torch.float32, device=rays_points_world.device
                )
                for low, high, sz in (
                    (min_y, max_y, image_height),
                    (min_x, max_x, image_width),
                )
            ]
            # compare expected xy with the output xy
            for dim, gt_axis in zip(
                (0, 1), (xs[None, None, :, None], ys[None, :, None, None])
            ):
                self.assertClose(
                    ray_points_projected[..., dim],
                    gt_axis.expand_as(ray_points_projected[..., dim]),
                    atol=1e-4,
                )

        elif isinstance(raysampler, MonteCarloRaysampler):
            # check that the randomly sampled locations
            # are within the allowed bounds for both x and y axes
            for dim, axis_bounds in zip((0, 1), ((min_x, max_x), (min_y, max_y))):
                self.assertTrue(
                    (
                        (ray_points_projected[..., dim] <= axis_bounds[1])
                        & (ray_points_projected[..., dim] >= axis_bounds[0])
                    ).all()
                )

            # also check that x,y along each ray is constant
            if n_pts_per_ray > 1:
                self.assertClose(
                    ray_points_projected[..., :2].std(dim=-2),
                    torch.zeros_like(ray_points_projected[..., 0, :2]),
                    atol=1e-5,
                )

        else:
            raise ValueError(str(type(raysampler)))

    def _check_raysampler_ray_directions(self, cameras, raysampler, ray_bundle):
        """
        Check the rays_directions_world output of raysamplers.
        """

        batch_size = cameras.R.shape[0]
        n_pts_per_ray = ray_bundle.lengths.shape[-1]
        spatial_size = ray_bundle.xys.shape[1:-1]
        n_rays_per_image = spatial_size.numel()

        # obtain the ray points in world coords
        rays_points_world = cameras.unproject_points(
            torch.cat(
                (
                    ray_bundle.xys.view(batch_size, n_rays_per_image, 1, 2).expand(
                        batch_size, n_rays_per_image, n_pts_per_ray, 2
                    ),
                    ray_bundle.lengths.view(
                        batch_size, n_rays_per_image, n_pts_per_ray, 1
                    ),
                ),
                dim=-1,
            ).view(batch_size, -1, 3)
        ).view(batch_size, -1, n_pts_per_ray, 3)

        # reshape to common testing size
        rays_directions_world_normed = torch.nn.functional.normalize(
            ray_bundle.directions.view(batch_size, -1, 3), dim=-1
        )

        # check that the l2-normed difference of all consecutive planes
        # of points in world coords matches ray_directions_world
        rays_directions_world_ = torch.nn.functional.normalize(
            rays_points_world[:, :, -1:] - rays_points_world[:, :, :-1], dim=-1
        )
        self.assertClose(
            rays_directions_world_normed[:, :, None].expand_as(rays_directions_world_),
            rays_directions_world_,
            atol=1e-4,
        )

        # check the ray directions rotated using camera rotation matrix
        # match the ray directions of a camera with trivial extrinsics
        cameras_trivial_extrinsic = cameras.clone()
        cameras_trivial_extrinsic.R = eyes(
            N=batch_size, dim=3, dtype=cameras.R.dtype, device=cameras.device
        )
        cameras_trivial_extrinsic.T = torch.zeros_like(cameras.T)

        # make sure we get the same random rays in case we call the
        # MonteCarloRaysampler twice below
        with torch.random.fork_rng(devices=range(torch.cuda.device_count())):
            torch.random.manual_seed(42)
            ray_bundle_world_fix_seed = raysampler(cameras=cameras)
            torch.random.manual_seed(42)
            ray_bundle_camera_fix_seed = raysampler(cameras=cameras_trivial_extrinsic)

        rays_directions_camera_fix_seed_ = Rotate(
            cameras.R, device=cameras.R.device
        ).transform_points(ray_bundle_world_fix_seed.directions.view(batch_size, -1, 3))

        self.assertClose(
            rays_directions_camera_fix_seed_,
            ray_bundle_camera_fix_seed.directions.view(batch_size, -1, 3),
            atol=1e-5,
        )