examples and docs.

Summary: This diff updates the documentation and tutorials with information about the new pulsar backend. For more information about the pulsar backend, see the release notes and the paper (https://arxiv.org/abs/2004.07484). For information on how to use the backend, see the point cloud rendering notebook and the examples in the folder docs/examples.

Reviewed By: nikhilaravi

Differential Revision: D24498129

fbshipit-source-id: e312b0169a72b13590df6e4db36bfe6190d742f9
This commit is contained in:
Christoph Lassner
2020-11-03 13:05:02 -08:00
committed by Facebook GitHub Bot
parent 960fd6d8b6
commit 039e02601d
21 changed files with 759 additions and 60 deletions

View File

@@ -18,6 +18,15 @@ from ..rasterizer import PointsRasterizer
from .renderer import Renderer as PulsarRenderer
def _ensure_float_tensor(val_in, device):
    """
    Make sure that the value provided is wrapped in a PyTorch float tensor.

    Args:
        val_in: the value to wrap. Either a `torch.Tensor` or anything
            `torch.tensor` accepts (e.g., a Python scalar). It must hold
            exactly one element, because the result is reshaped to `(1,)`.
        device: the device on which the resulting tensor is placed.

    Returns:
        A `torch.float32` tensor of shape `(1,)` located on `device`.
    """
    if isinstance(val_in, torch.Tensor):
        # Convert dtype and device in one call so that no intermediate
        # float32 copy is created on the source device.
        val_out = val_in.to(device=device, dtype=torch.float32)
    else:
        val_out = torch.tensor(val_in, dtype=torch.float32, device=device)
    # Both branches share the same final shape normalization.
    return val_out.reshape((1,))
class PulsarPointsRenderer(nn.Module):
"""
This renderer is a PyTorch3D interface wrapper around the pulsar renderer.
@@ -36,6 +45,7 @@ class PulsarPointsRenderer(nn.Module):
compositor: Optional[Union[NormWeightedCompositor, AlphaCompositor]] = None,
n_channels: int = 3,
max_num_spheres: int = int(1e6), # noqa: B008
**kwargs,
):
"""
rasterizer (PointsRasterizer): An object encapsulating rasterization parameters.
@@ -43,6 +53,8 @@ class PulsarPointsRenderer(nn.Module):
n_channels (int): The number of channels of the resulting image. Default: 3.
max_num_spheres (int): The maximum number of spheres intended to render with
this renderer. Default: 1e6.
kwargs (Any): kwargs to pass on to the pulsar renderer.
See `pytorch3d.renderer.points.pulsar.renderer.Renderer` for all options.
"""
super().__init__()
self.rasterizer = rasterizer
@@ -87,6 +99,7 @@ class PulsarPointsRenderer(nn.Module):
orthogonal_projection=orthogonal_projection,
right_handed_system=True,
n_channels=n_channels,
**kwargs,
)
def _conf_check(self, point_clouds, kwargs: Dict[str, Any]) -> bool:
@@ -165,8 +178,8 @@ class PulsarPointsRenderer(nn.Module):
)
return orthogonal_projection
def _extract_intrinsics(
self, orthogonal_projection, kwargs, cloud_idx
def _extract_intrinsics( # noqa: C901
self, orthogonal_projection, kwargs, cloud_idx, device
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, float]:
"""
Translate the camera intrinsics from PyTorch3D format to pulsar format.
@@ -174,7 +187,7 @@ class PulsarPointsRenderer(nn.Module):
# Shorthand:
cameras = self.rasterizer.cameras
if orthogonal_projection:
focal_length = 0.0
focal_length = torch.zeros((1,), dtype=torch.float32)
if isinstance(cameras, FoVOrthographicCameras):
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `znear`.
znear = kwargs.get("znear", cameras.znear)[cloud_idx]
@@ -212,7 +225,10 @@ class PulsarPointsRenderer(nn.Module):
raise ValueError(
f"The orthographic camera must have positive size! Is: {sensor_width}." # noqa: B950
)
principal_point_x, principal_point_y = 0.0, 0.0
principal_point_x, principal_point_y = (
torch.zeros((1,), dtype=torch.float32),
torch.zeros((1,), dtype=torch.float32),
)
else:
# Currently, this means it must be an 'OrthographicCameras' object.
focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
@@ -276,7 +292,10 @@ class PulsarPointsRenderer(nn.Module):
"must agree with the resolution width / height ("
f"{self.renderer._renderer.width / self.renderer._renderer.height})." # noqa: B950
)
principal_point_x, principal_point_y = 0.0, 0.0
principal_point_x, principal_point_y = (
torch.zeros((1,), dtype=torch.float32),
torch.zeros((1,), dtype=torch.float32),
)
else:
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `focal_length`.
focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
@@ -308,7 +327,13 @@ class PulsarPointsRenderer(nn.Module):
"Focal length not parsable: %s." % (str(focal_length_conf))
)
focal_length_px = focal_length_conf
focal_length = znear - 1e-6
focal_length = torch.tensor(
[
znear - 1e-6,
],
dtype=torch.float32,
device=focal_length_px.device,
)
sensor_width = focal_length / focal_length_px * 2.0
principal_point_x = (
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
@@ -321,6 +346,12 @@ class PulsarPointsRenderer(nn.Module):
* 0.5
* self.renderer._renderer.height
)
focal_length = _ensure_float_tensor(focal_length, device)
sensor_width = _ensure_float_tensor(sensor_width, device)
principal_point_x = _ensure_float_tensor(principal_point_x, device)
principal_point_y = _ensure_float_tensor(principal_point_y, device)
znear = _ensure_float_tensor(znear, device)
zfar = _ensure_float_tensor(zfar, device)
return (
focal_length,
sensor_width,
@@ -338,11 +369,17 @@ class PulsarPointsRenderer(nn.Module):
R = kwargs.get("R", cameras.R)[cloud_idx]
T = kwargs.get("T", cameras.T)[cloud_idx]
norm_mat = torch.tensor(
[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
dtype=torch.float32,
device=R.device,
)
cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...])
cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...]).permute((0, 2, 1))
norm_mat = torch.tensor(
[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
dtype=torch.float32,
device=R.device,
)
cam_rot = torch.matmul(norm_mat, cam_rot)
cam_pos = torch.flatten(torch.matmul(cam_rot, T[..., None]))
cam_rot = torch.flatten(matrix_to_rotation_6d(cam_rot))
return cam_pos, cam_rot
@@ -374,7 +411,7 @@ class PulsarPointsRenderer(nn.Module):
)
else:
point_dists = torch.norm((vert_pos - cam_pos), p=2, dim=1, keepdim=False)
vert_rad = raster_rad / focal_length * point_dists
vert_rad = raster_rad / focal_length.to(vert_pos.device) * point_dists
if isinstance(self.rasterizer.cameras, PerspectiveCameras):
# NDC normalization happens through adjusted focal length.
pass
@@ -382,6 +419,7 @@ class PulsarPointsRenderer(nn.Module):
vert_rad = vert_rad / 2.0 # NDC normalization.
return vert_rad
# point_clouds is not typed to avoid a cyclic dependency.
def forward(self, point_clouds, **kwargs) -> torch.Tensor:
"""
Get the rendering of the provided `Pointclouds`.
@@ -439,6 +477,8 @@ class PulsarPointsRenderer(nn.Module):
for cloud_idx, (vert_pos, vert_col) in enumerate(
zip(position_list, features_list)
):
# Get extrinsics.
cam_pos, cam_rot = self._extract_extrinsics(kwargs, cloud_idx)
# Get intrinsics.
(
focal_length,
@@ -447,23 +487,21 @@ class PulsarPointsRenderer(nn.Module):
principal_point_y,
znear,
zfar,
) = self._extract_intrinsics(orthogonal_projection, kwargs, cloud_idx)
# Get extrinsics.
cam_pos, cam_rot = self._extract_extrinsics(kwargs, cloud_idx)
) = self._extract_intrinsics(
orthogonal_projection, kwargs, cloud_idx, cam_pos.device
)
# Put everything together.
cam_params = torch.cat(
(
cam_pos,
cam_rot,
torch.tensor(
cam_rot.to(cam_pos.device),
torch.cat(
[
focal_length,
sensor_width,
principal_point_x,
principal_point_y,
],
dtype=torch.float32,
device=cam_pos.device,
),
)
)