pulsar interface unification.

Summary:
This diff builds on top of the `pulsar integration` diff to provide a unified interface for the existing PyTorch3D point renderer and Pulsar. For more information about the pulsar backend, see the release notes and the paper (https://arxiv.org/abs/2004.07484). For information on how to use the backend, see the point cloud rendering notebook and the examples in the folder docs/examples.

The unified interfaces are fully consistent. Switching the render backend is as easy as using `renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)` instead of `renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)` and adding the `gamma` parameter to the forward function. All PyTorch3D camera types are supported as far as possible; keyword arguments are properly forwarded to the cameras. `PerspectiveCameras` and `OrthographicCameras` require `znear` and `zfar` as additional parameters for the forward pass.
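
A minimal sketch of the switch (`point_clouds` and all parameter values here are placeholders, not part of this diff):

    import torch
    from pytorch3d.renderer import (
        AlphaCompositor,
        PerspectiveCameras,
        PointsRasterizationSettings,
        PointsRasterizer,
        PointsRenderer,
        PulsarPointsRenderer,
    )

    device = torch.device("cuda")
    cameras = PerspectiveCameras(device=device)
    rasterizer = PointsRasterizer(
        cameras=cameras,
        raster_settings=PointsRasterizationSettings(image_size=256, radius=0.01),
    )

    # Existing PyTorch3D point renderer:
    renderer = PointsRenderer(rasterizer=rasterizer, compositor=AlphaCompositor())
    images = renderer(point_clouds)

    # Pulsar backend: only the construction and the `gamma` kwarg change;
    # PerspectiveCameras additionally need znear/zfar in the forward pass.
    renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)
    images = renderer(point_clouds, gamma=(1e-4,), znear=(1.0,), zfar=(100.0,))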

Reviewed By: nikhilaravi

Differential Revision: D21421443

fbshipit-source-id: 4aa0a83a419592d9a0bb5d62486a1cdea9d73ce6
Author: Christoph Lassner
Date: 2020-11-03 13:05:02 -08:00
Committed by: Facebook GitHub Bot
parent b19fe1de2f
commit 960fd6d8b6
18 changed files with 695 additions and 313 deletions

pytorch3d/renderer/points/__init__.py

@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .compositor import AlphaCompositor, NormWeightedCompositor
from .pulsar.unified import PulsarPointsRenderer
from .rasterize_points import rasterize_points
from .rasterizer import PointsRasterizationSettings, PointsRasterizer
from .renderer import PointsRenderer

pytorch3d/renderer/points/pulsar/renderer.py

@@ -369,6 +369,7 @@ class Renderer(torch.nn.Module):
height: int,
orthogonal: bool,
right_handed: bool,
first_R_then_T: bool = False,
) -> Tuple[
torch.Tensor,
torch.Tensor,
@@ -401,6 +402,8 @@ class Renderer(torch.nn.Module):
(does not use focal length).
* right_handed: bool, whether to use a right handed system
(negative z in camera direction).
* first_R_then_T: bool, whether to first rotate, then translate
the camera (PyTorch3D convention).
Returns:
* pos_vec: the position vector in 3D,
@@ -460,16 +463,18 @@ class Renderer(torch.nn.Module):
# Always get quadratic pixels.
pixel_size_x = sensor_size_x / float(width)
sensor_size_y = height * pixel_size_x
if continuous_rep:
rot_mat = rotation_6d_to_matrix(rot_vec)
else:
rot_mat = axis_angle_to_matrix(rot_vec)
if first_R_then_T:
pos_vec = torch.matmul(rot_mat, pos_vec[..., None])[:, :, 0]
LOGGER.debug(
"Camera position: %s, rotation: %s. Focal length: %s.",
str(pos_vec),
str(rot_vec),
str(focal_length),
)
sensor_dir_x = torch.matmul(
rot_mat,
torch.tensor(
@@ -576,6 +581,7 @@ class Renderer(torch.nn.Module):
max_n_hits: int = _C.MAX_UINT,
mode: int = 0,
return_forward_info: bool = False,
first_R_then_T: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
"""
Rendering pass to create an image from the provided spheres and camera
@@ -616,6 +622,8 @@ class Renderer(torch.nn.Module):
the float encoded integer index of a sphere and its weight. They are the
five spheres with the highest color contribution to this pixel color,
ordered descending. Default: False.
* first_R_then_T: bool, whether to first apply rotation to the camera,
then translation (PyTorch3D convention). Default: False.
Returns:
* image: [Bx]HxWx3 float tensor with the resulting image.
@@ -638,6 +646,7 @@ class Renderer(torch.nn.Module):
self._renderer.height,
self._renderer.orthogonal,
self._renderer.right_handed,
first_R_then_T=first_R_then_T,
)
if (
focal_lengths.min().item() > 0.0

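To make the new flag concrete, a small sketch with placeholder tensors (not part of the diff). Note that the unified wrapper introduced below performs this rotation itself and rejects `first_R_then_T` as a forward kwarg:

    import torch

    rot_mat = torch.eye(3)  # decoded camera rotation (placeholder)
    pos_vec = torch.tensor([[0.0, 0.0, 5.0]])  # position parameter (placeholder)

    # With first_R_then_T=False (default), pos_vec is the camera position as-is.
    # With first_R_then_T=True, it is rotated first, so a PyTorch3D-style
    # translation T can be passed through unchanged:
    pos_vec = torch.matmul(rot_mat, pos_vec[..., None])[:, :, 0]
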
pytorch3d/renderer/points/pulsar/unified.py

@@ -0,0 +1,524 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import math
import warnings
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ....transforms import matrix_to_rotation_6d
from ...cameras import (
FoVOrthographicCameras,
FoVPerspectiveCameras,
OrthographicCameras,
PerspectiveCameras,
)
from ..compositor import AlphaCompositor, NormWeightedCompositor
from ..rasterizer import PointsRasterizer
from .renderer import Renderer as PulsarRenderer
class PulsarPointsRenderer(nn.Module):
"""
This renderer is a PyTorch3D interface wrapper around the pulsar renderer.
It provides an interface consistent with PyTorch3D Pointcloud rendering.
It will extract all necessary information from the rasterizer and compositor
objects and convert them to the pulsar required format, then invoke rendering
in the pulsar renderer. All gradients are handled appropriately through the
wrapper and the wrapper should provide equivalent results to using the pulsar
renderer directly.
"""
def __init__(
self,
rasterizer: PointsRasterizer,
compositor: Optional[Union[NormWeightedCompositor, AlphaCompositor]] = None,
n_channels: int = 3,
max_num_spheres: int = int(1e6), # noqa: B008
):
"""
rasterizer (PointsRasterizer): An object encapsulating rasterization parameters.
compositor (ignored): Only keeping this for interface consistency. Default: None.
n_channels (int): The number of channels of the resulting image. Default: 3.
        max_num_spheres (int): The maximum number of spheres intended to be
            rendered with this renderer. Default: 1e6.
"""
super().__init__()
self.rasterizer = rasterizer
if compositor is not None:
warnings.warn(
"Creating a `PulsarPointsRenderer` with a compositor object! "
"This object is ignored and just allowed as an argument for interface "
"compatibility."
)
# Initialize the pulsar renderers.
if not isinstance(
rasterizer.cameras,
(
FoVOrthographicCameras,
FoVPerspectiveCameras,
PerspectiveCameras,
OrthographicCameras,
),
):
raise ValueError(
"Only FoVPerspectiveCameras, PerspectiveCameras, "
"FoVOrthographicCameras and OrthographicCameras are supported "
"by the pulsar backend."
)
if isinstance(rasterizer.raster_settings.image_size, tuple):
width, height = rasterizer.raster_settings.image_size
else:
width = rasterizer.raster_settings.image_size
height = rasterizer.raster_settings.image_size
        # Make sure these are integer types.
width = int(width)
height = int(height)
max_num_spheres = int(max_num_spheres)
orthogonal_projection = isinstance(
rasterizer.cameras, (FoVOrthographicCameras, OrthographicCameras)
)
n_channels = int(n_channels)
self.renderer = PulsarRenderer(
width=width,
height=height,
max_num_balls=max_num_spheres,
orthogonal_projection=orthogonal_projection,
right_handed_system=True,
n_channels=n_channels,
)
def _conf_check(self, point_clouds, kwargs: Dict[str, Any]) -> bool:
"""
Verify internal configuration state with kwargs and pointclouds.
This method will raise ValueError's for any inconsistencies found. It
returns whether an orthogonal projection will be used.
"""
if "gamma" not in kwargs.keys():
raise ValueError(
"gamma is a required keyword argument for the PulsarPointsRenderer!"
)
if (
len(point_clouds) != len(self.rasterizer.cameras)
and len(self.rasterizer.cameras) != 1
):
raise ValueError(
(
"The len(point_clouds) must either be equal to len(rasterizer.cameras) or "
"only one camera must be used. len(point_clouds): %d, "
"len(rasterizer.cameras): %d."
)
% (
len(point_clouds),
len(self.rasterizer.cameras),
)
)
# Make sure the rasterizer and cameras objects have no
# changes that can't be matched.
orthogonal_projection = isinstance(
self.rasterizer.cameras, (FoVOrthographicCameras, OrthographicCameras)
)
if orthogonal_projection != self.renderer._renderer.orthogonal:
            raise ValueError(
                "The camera type can not be changed after renderer initialization! "
                "Current camera orthogonal: %r. Original orthogonal: %r."
                % (orthogonal_projection, self.renderer._renderer.orthogonal)
            )
if (
isinstance(self.rasterizer.raster_settings.image_size, tuple)
and self.rasterizer.raster_settings.image_size[0]
!= self.renderer._renderer.width
) or (
not isinstance(self.rasterizer.raster_settings.image_size, tuple)
and self.rasterizer.raster_settings.image_size
!= self.renderer._renderer.width
):
raise ValueError(
(
"The rasterizer width and height can not be changed after renderer "
"initialization! Current width: %d. Original width: %d."
)
% (
self.rasterizer.raster_settings.image_size,
self.renderer._renderer.width,
)
)
if (
isinstance(self.rasterizer.raster_settings.image_size, tuple)
and self.rasterizer.raster_settings.image_size[1]
!= self.renderer._renderer.height
) or (
not isinstance(self.rasterizer.raster_settings.image_size, tuple)
and self.rasterizer.raster_settings.image_size
!= self.renderer._renderer.height
):
raise ValueError(
(
"The rasterizer width and height can not be changed after renderer "
"initialization! Current height: %d. Original height: %d."
)
% (
self.rasterizer.raster_settings.image_size,
self.renderer._renderer.height,
)
)
return orthogonal_projection
def _extract_intrinsics(
self, orthogonal_projection, kwargs, cloud_idx
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, float]:
"""
Translate the camera intrinsics from PyTorch3D format to pulsar format.
"""
# Shorthand:
cameras = self.rasterizer.cameras
if orthogonal_projection:
focal_length = 0.0
if isinstance(cameras, FoVOrthographicCameras):
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `znear`.
znear = kwargs.get("znear", cameras.znear)[cloud_idx]
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `zfar`.
zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx]
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `max_y`.
max_y = kwargs.get("max_y", cameras.max_y)[cloud_idx]
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `min_y`.
min_y = kwargs.get("min_y", cameras.min_y)[cloud_idx]
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `max_x`.
max_x = kwargs.get("max_x", cameras.max_x)[cloud_idx]
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `min_x`.
min_x = kwargs.get("min_x", cameras.min_x)[cloud_idx]
if max_y != -min_y:
raise ValueError(
"The orthographic camera must be centered around 0. "
f"Max is {max_y} and min is {min_y}."
)
if max_x != -min_x:
raise ValueError(
"The orthographic camera must be centered around 0. "
f"Max is {max_x} and min is {min_x}."
)
if not torch.all(
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `scale_xyz`.
kwargs.get("scale_xyz", cameras.scale_xyz)[cloud_idx]
== 1.0
):
raise ValueError(
"The orthographic camera scale must be ((1.0, 1.0, 1.0),). "
f"{kwargs.get('scale_xyz', cameras.scale_xyz)[cloud_idx]}."
)
sensor_width = max_x - min_x
if not sensor_width > 0.0:
raise ValueError(
f"The orthographic camera must have positive size! Is: {sensor_width}." # noqa: B950
)
principal_point_x, principal_point_y = 0.0, 0.0
else:
# Currently, this means it must be an 'OrthographicCameras' object.
focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
cloud_idx
]
            if (
                focal_length_conf.numel() == 2
                and abs(
                    focal_length_conf[0] * self.renderer._renderer.width
                    - focal_length_conf[1] * self.renderer._renderer.height
                )
                > 1e-5
            ):
raise ValueError(
"Pulsar only supports a single focal length! "
"Provided: %s." % (str(focal_length_conf))
)
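            # In PyTorch3D's NDC convention the screen spans [-1, 1], so an
            # orthographic focal length (scale) f corresponds to a visible
            # world width, and hence pulsar sensor width, of 2 / f.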
if focal_length_conf.numel() == 2:
sensor_width = 2.0 / focal_length_conf[0]
else:
if focal_length_conf.numel() != 1:
raise ValueError(
"Focal length not parsable: %s." % (str(focal_length_conf))
)
sensor_width = 2.0 / focal_length_conf
if "znear" not in kwargs.keys() or "zfar" not in kwargs.keys():
                    raise ValueError(
                        "pulsar needs znear and zfar values for "
                        "the OrthographicCameras. Please provide them as keyword "
                        "arguments to the forward method."
                    )
znear = kwargs["znear"][cloud_idx]
zfar = kwargs["zfar"][cloud_idx]
principal_point_x = (
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
* 0.5
* self.renderer._renderer.width
)
principal_point_y = (
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1]
* 0.5
* self.renderer._renderer.height
)
else:
if not isinstance(cameras, PerspectiveCameras):
# Create a virtual focal length that is closer than znear.
znear = kwargs.get("znear", cameras.znear)[cloud_idx]
zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx]
focal_length = znear - 1e-6
# Create a sensor size that matches the expected fov assuming this f.
afov = kwargs.get("fov", cameras.fov)[cloud_idx]
if kwargs.get("degrees", cameras.degrees):
afov *= math.pi / 180.0
sensor_width = math.tan(afov / 2.0) * 2.0 * focal_length
                if not (
                    abs(
                        kwargs.get("aspect_ratio", cameras.aspect_ratio)[cloud_idx]
                        - self.renderer._renderer.width
                        / self.renderer._renderer.height
                    )
                    < 1e-6
                ):
raise ValueError(
"The aspect ratio ("
f"{kwargs.get('aspect_ratio', cameras.aspect_ratio)[cloud_idx]}) "
"must agree with the resolution width / height ("
f"{self.renderer._renderer.width / self.renderer._renderer.height})." # noqa: B950
)
principal_point_x, principal_point_y = 0.0, 0.0
else:
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `focal_length`.
focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
cloud_idx
]
            if (
                focal_length_conf.numel() == 2
                and abs(
                    focal_length_conf[0] * self.renderer._renderer.width
                    - focal_length_conf[1] * self.renderer._renderer.height
                )
                > 1e-5
            ):
raise ValueError(
"Pulsar only supports a single focal length! "
"Provided: %s." % (str(focal_length_conf))
)
if "znear" not in kwargs.keys() or "zfar" not in kwargs.keys():
                    raise ValueError(
                        "pulsar needs znear and zfar values for "
                        "the PerspectiveCameras. Please provide them as keyword "
                        "arguments to the forward method."
                    )
znear = kwargs["znear"][cloud_idx]
zfar = kwargs["zfar"][cloud_idx]
if focal_length_conf.numel() == 2:
focal_length_px = focal_length_conf[0]
else:
if focal_length_conf.numel() != 1:
raise ValueError(
"Focal length not parsable: %s." % (str(focal_length_conf))
)
focal_length_px = focal_length_conf
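                # Use a virtual focal length just inside znear; sensor_width is
                # scaled below so that the field of view implied by the NDC
                # focal length is preserved.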
focal_length = znear - 1e-6
sensor_width = focal_length / focal_length_px * 2.0
principal_point_x = (
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
* 0.5
* self.renderer._renderer.width
)
principal_point_y = (
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1]
* 0.5
* self.renderer._renderer.height
)
return (
focal_length,
sensor_width,
principal_point_x,
principal_point_y,
znear,
zfar,
)
def _extract_extrinsics(
self, kwargs, cloud_idx
) -> Tuple[torch.Tensor, torch.Tensor]:
# Shorthand:
cameras = self.rasterizer.cameras
R = kwargs.get("R", cameras.R)[cloud_idx]
T = kwargs.get("T", cameras.T)[cloud_idx]
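        # diag(-1, 1, -1) flips the x and z axes to translate between the
        # PyTorch3D camera convention and pulsar's right-handed system.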
norm_mat = torch.tensor(
[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
dtype=torch.float32,
device=R.device,
)
cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...])
cam_pos = torch.flatten(torch.matmul(cam_rot, T[..., None]))
cam_rot = torch.flatten(matrix_to_rotation_6d(cam_rot))
return cam_pos, cam_rot
def _get_vert_rad(
self, vert_pos, cam_pos, orthogonal_projection, focal_length, kwargs, cloud_idx
) -> torch.Tensor:
"""
        Get point radii.
        These can depend on the camera position in the case of a perspective
        transform.
"""
        # Normalize point radii.
# `self.rasterizer.raster_settings.radius` can either be a float
# or itself a tensor.
raster_rad = self.rasterizer.raster_settings.radius
if kwargs.get("radius_world", False):
return raster_rad
if isinstance(raster_rad, torch.Tensor) and raster_rad.numel() > 1:
# In this case it must be a batched torch tensor.
raster_rad = raster_rad[cloud_idx]
if orthogonal_projection:
vert_rad = (
torch.ones(
(vert_pos.shape[0],), dtype=torch.float32, device=vert_pos.device
)
* raster_rad
)
else:
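            # Perspective case: by similar triangles, a fixed screen-space
            # radius corresponds to a world-space radius that grows linearly
            # with the distance from the camera.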
point_dists = torch.norm((vert_pos - cam_pos), p=2, dim=1, keepdim=False)
vert_rad = raster_rad / focal_length * point_dists
if isinstance(self.rasterizer.cameras, PerspectiveCameras):
# NDC normalization happens through adjusted focal length.
pass
else:
vert_rad = vert_rad / 2.0 # NDC normalization.
return vert_rad
def forward(self, point_clouds, **kwargs) -> torch.Tensor:
"""
Get the rendering of the provided `Pointclouds`.
        The number of point clouds in the `Pointclouds` object determines the
        number of resulting images. The number of provided cameras can be either
        1 or equal to the number of point clouds (in the first case, the same
        camera is used for all clouds; in the latter case, each point cloud is
        rendered with the corresponding camera).
        The following kwargs are supported from PyTorch3D (depending on the
        selected camera model, potentially overriding camera parameters):
            radius_world (bool): use the provided radii from the raster_settings
                directly as radii in world space. Default: False.
            znear (Iterable[float]): near geometry cutoff. Required for
                OrthographicCameras and PerspectiveCameras.
            zfar (Iterable[float]): far geometry cutoff. Required for
                OrthographicCameras and PerspectiveCameras.
R (torch.Tensor): [Bx3x3] camera rotation matrices.
T (torch.Tensor): [Bx3] camera translation vectors.
principal_point (torch.Tensor): [Bx2] camera intrinsic principal
point offset vectors.
focal_length (torch.Tensor): [Bx1] camera intrinsic focal lengths.
aspect_ratio (Iterable[float]): camera aspect ratios.
fov (Iterable[float]): camera FOVs.
degrees (bool): whether FOVs are specified in degrees or
radians.
min_x (Iterable[float]): minimum x for the FoVOrthographicCameras.
max_x (Iterable[float]): maximum x for the FoVOrthographicCameras.
min_y (Iterable[float]): minimum y for the FoVOrthographicCameras.
max_y (Iterable[float]): maximum y for the FoVOrthographicCameras.
The following kwargs are supported from pulsar:
            gamma (float): The gamma value to use. This defines the transparency
                for differentiability (see the pulsar paper for details). Must be
                in [1e-5, 1.] with 1.0 being mostly transparent. This keyword
                argument is *required*!
bg_col (torch.Tensor): The background color. Must be a tensor on the same
device as the point clouds, with as many channels as features (no batch
dimension - it is the same for all images in the batch).
Default: 0.0 for all channels.
            percent_allowed_difference (float): a value in [0., 1.) with the
                maximum allowed difference in channel space. This is used to
                speed up the computation. Default: 0.01.
max_n_hits (int): a hard limit on the number of sphere hits per ray.
Default: max int.
mode (int): render mode in {0, 1}. 0: render image; 1: render hit map.
"""
orthogonal_projection: bool = self._conf_check(point_clouds, kwargs)
# Get access to inputs. We're using the list accessor and process
# them sequentially.
position_list = point_clouds.points_list()
features_list = point_clouds.features_list()
# Result list.
images = []
for cloud_idx, (vert_pos, vert_col) in enumerate(
zip(position_list, features_list)
):
# Get intrinsics.
(
focal_length,
sensor_width,
principal_point_x,
principal_point_y,
znear,
zfar,
) = self._extract_intrinsics(orthogonal_projection, kwargs, cloud_idx)
# Get extrinsics.
cam_pos, cam_rot = self._extract_extrinsics(kwargs, cloud_idx)
# Put everything together.
cam_params = torch.cat(
(
cam_pos,
cam_rot,
torch.tensor(
[
focal_length,
sensor_width,
principal_point_x,
principal_point_y,
],
dtype=torch.float32,
device=cam_pos.device,
),
)
)
            # Get point radii (can depend on camera position).
vert_rad = self._get_vert_rad(
vert_pos,
cam_pos,
orthogonal_projection,
focal_length,
kwargs,
cloud_idx,
)
# Clean kwargs for passing on.
gamma = kwargs["gamma"][cloud_idx]
if "first_R_then_T" in kwargs.keys():
raise ValueError("`first_R_then_T` is not supported in this interface.")
otherargs = {
argn: argv
for argn, argv in kwargs.items()
if argn
not in [
"radius_world",
"gamma",
"znear",
"zfar",
"R",
"T",
"principal_point",
"focal_length",
"aspect_ratio",
"fov",
"degrees",
"min_x",
"max_x",
"min_y",
"max_y",
]
}
# background color
if "bg_col" not in otherargs:
bg_col = torch.zeros(
vert_col.shape[1], device=cam_params.device, dtype=torch.float32
)
otherargs["bg_col"] = bg_col
# Go!
images.append(
self.renderer(
vert_pos=vert_pos,
vert_col=vert_col,
vert_rad=vert_rad,
cam_params=cam_params,
gamma=gamma,
max_depth=zfar,
min_depth=znear,
**otherargs,
)
)
return torch.stack(images, dim=0)

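A hypothetical end-to-end call of the unified renderer with a batch of two clouds (all names and values are illustrative; FoV cameras carry their own znear/zfar, so only `gamma` is required):

    import torch
    from pytorch3d.renderer import (
        FoVPerspectiveCameras,
        PointsRasterizationSettings,
        PointsRasterizer,
        PulsarPointsRenderer,
    )
    from pytorch3d.structures import Pointclouds

    device = torch.device("cuda")
    point_clouds = Pointclouds(
        points=torch.rand((2, 1000, 3), device=device) - 0.5,
        features=torch.rand((2, 1000, 3), device=device),
    )

    # One camera per cloud; R and T follow the PyTorch3D convention.
    cameras = FoVPerspectiveCameras(
        R=torch.eye(3, device=device).expand(2, 3, 3),
        T=torch.tensor([[0.0, 0.0, 3.0]], device=device).expand(2, 3),
        device=device,
    )
    rasterizer = PointsRasterizer(
        cameras=cameras,
        raster_settings=PointsRasterizationSettings(image_size=256, radius=0.01),
    )
    renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)

    # `gamma` is indexed per cloud, so provide one value per point cloud.
    images = renderer(point_clouds, gamma=(1e-4, 1e-4))  # shape (2, 256, 256, 3)
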
pytorch3d/renderer/points/renderer.py

@@ -1,7 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
@@ -48,7 +46,7 @@ class PointsRenderer(nn.Module):
fragments.idx.long().permute(0, 3, 1, 2),
weights,
point_clouds.features_packed().permute(1, 0),
**kwargs,
)
# permute so image comes at the end