Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-12-23 07:40:34 +08:00
examples and docs.
Summary: This diff updates the documentation and tutorials with information about the new pulsar backend. For more information about the pulsar backend, see the release notes and the paper (https://arxiv.org/abs/2004.07484). For information on how to use the backend, see the point cloud rendering notebook and the examples in the folder docs/examples.

Reviewed By: nikhilaravi

Differential Revision: D24498129

fbshipit-source-id: e312b0169a72b13590df6e4db36bfe6190d742f9
committed by Facebook GitHub Bot
parent 960fd6d8b6
commit 039e02601d
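The summary above defers usage details to the point cloud rendering notebook and docs/examples. For quick orientation only, here is a minimal sketch of driving the unified pulsar wrapper touched by the diff below; the import path for PulsarPointsRenderer, the camera and rasterization settings, and the gamma value are illustrative assumptions rather than code from this commit.

import torch
from pytorch3d.renderer import (
    FoVOrthographicCameras,
    PointsRasterizationSettings,
    PointsRasterizer,
    PulsarPointsRenderer,  # assumed to be exported here; adjust to your install
    look_at_view_transform,
)
from pytorch3d.structures import Pointclouds

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# A random point cloud with per-point RGB features.
points = (torch.rand(1, 10000, 3, device=device) - 0.5) * 2.0
colors = torch.rand(1, 10000, 3, device=device)
point_cloud = Pointclouds(points=points, features=colors)

# Camera and rasterization settings are shared with the standard points renderer.
R, T = look_at_view_transform(dist=2.0, elev=10.0, azim=0.0)
cameras = FoVOrthographicCameras(device=device, R=R, T=T, znear=0.01)
raster_settings = PointsRasterizationSettings(image_size=256, radius=0.01)
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)

# The pulsar wrapper replaces the rasterizer + compositor pair with one module.
renderer = PulsarPointsRenderer(rasterizer=rasterizer, n_channels=3).to(device)

# pulsar requires a per-cloud blending parameter `gamma` at render time.
images = renderer(point_cloud, gamma=(1e-4,))
print(images.shape)  # e.g. torch.Size([1, 256, 256, 3])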
@@ -18,6 +18,15 @@ from ..rasterizer import PointsRasterizer
 from .renderer import Renderer as PulsarRenderer
 
 
+def _ensure_float_tensor(val_in, device):
+    """Make sure that the value provided is wrapped a PyTorch float tensor."""
+    if not isinstance(val_in, torch.Tensor):
+        val_out = torch.tensor(val_in, dtype=torch.float32, device=device).reshape((1,))
+    else:
+        val_out = val_in.to(torch.float32).to(device).reshape((1,))
+    return val_out
+
+
 class PulsarPointsRenderer(nn.Module):
     """
     This renderer is a PyTorch3D interface wrapper around the pulsar renderer.
@@ -36,6 +45,7 @@ class PulsarPointsRenderer(nn.Module):
         compositor: Optional[Union[NormWeightedCompositor, AlphaCompositor]] = None,
         n_channels: int = 3,
+        max_num_spheres: int = int(1e6),  # noqa: B008
         **kwargs,
     ):
         """
         rasterizer (PointsRasterizer): An object encapsulating rasterization parameters.
@@ -43,6 +53,8 @@ class PulsarPointsRenderer(nn.Module):
         n_channels (int): The number of channels of the resulting image. Default: 3.
+        max_num_spheres (int): The maximum number of spheres intended to render with
+            this renderer. Default: 1e6.
         kwargs (Any): kwargs to pass on to the pulsar renderer.
             See `pytorch3d.renderer.points.pulsar.renderer.Renderer` for all options.
         """
         super().__init__()
         self.rasterizer = rasterizer
@@ -87,6 +99,7 @@ class PulsarPointsRenderer(nn.Module):
             orthogonal_projection=orthogonal_projection,
             right_handed_system=True,
+            n_channels=n_channels,
             **kwargs,
         )
 
     def _conf_check(self, point_clouds, kwargs: Dict[str, Any]) -> bool:
@@ -165,8 +178,8 @@ class PulsarPointsRenderer(nn.Module):
             )
         return orthogonal_projection
 
-    def _extract_intrinsics(
-        self, orthogonal_projection, kwargs, cloud_idx
+    def _extract_intrinsics(  # noqa: C901
+        self, orthogonal_projection, kwargs, cloud_idx, device
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, float]:
         """
         Translate the camera intrinsics from PyTorch3D format to pulsar format.
@@ -174,7 +187,7 @@ class PulsarPointsRenderer(nn.Module):
         # Shorthand:
         cameras = self.rasterizer.cameras
         if orthogonal_projection:
-            focal_length = 0.0
+            focal_length = torch.zeros((1,), dtype=torch.float32)
             if isinstance(cameras, FoVOrthographicCameras):
                 # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `znear`.
                 znear = kwargs.get("znear", cameras.znear)[cloud_idx]
@@ -212,7 +225,10 @@ class PulsarPointsRenderer(nn.Module):
                     raise ValueError(
                         f"The orthographic camera must have positive size! Is: {sensor_width}."  # noqa: B950
                     )
-                principal_point_x, principal_point_y = 0.0, 0.0
+                principal_point_x, principal_point_y = (
+                    torch.zeros((1,), dtype=torch.float32),
+                    torch.zeros((1,), dtype=torch.float32),
+                )
             else:
                 # Currently, this means it must be an 'OrthographicCameras' object.
                 focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
@@ -276,7 +292,10 @@ class PulsarPointsRenderer(nn.Module):
                         "must agree with the resolution width / height ("
                         f"{self.renderer._renderer.width / self.renderer._renderer.height})."  # noqa: B950
                     )
-                principal_point_x, principal_point_y = 0.0, 0.0
+                principal_point_x, principal_point_y = (
+                    torch.zeros((1,), dtype=torch.float32),
+                    torch.zeros((1,), dtype=torch.float32),
+                )
         else:
             # pyre-fixme[16]: `PerspectiveCameras` has no attribute `focal_length`.
             focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
@@ -308,7 +327,13 @@ class PulsarPointsRenderer(nn.Module):
                     "Focal length not parsable: %s." % (str(focal_length_conf))
                 )
             focal_length_px = focal_length_conf
-            focal_length = znear - 1e-6
+            focal_length = torch.tensor(
+                [
+                    znear - 1e-6,
+                ],
+                dtype=torch.float32,
+                device=focal_length_px.device,
+            )
             sensor_width = focal_length / focal_length_px * 2.0
             principal_point_x = (
                 # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
@@ -321,6 +346,12 @@ class PulsarPointsRenderer(nn.Module):
                 * 0.5
                 * self.renderer._renderer.height
             )
+        focal_length = _ensure_float_tensor(focal_length, device)
+        sensor_width = _ensure_float_tensor(sensor_width, device)
+        principal_point_x = _ensure_float_tensor(principal_point_x, device)
+        principal_point_y = _ensure_float_tensor(principal_point_y, device)
+        znear = _ensure_float_tensor(znear, device)
+        zfar = _ensure_float_tensor(zfar, device)
         return (
             focal_length,
             sensor_width,
@@ -338,11 +369,17 @@ class PulsarPointsRenderer(nn.Module):
         R = kwargs.get("R", cameras.R)[cloud_idx]
         T = kwargs.get("T", cameras.T)[cloud_idx]
         norm_mat = torch.tensor(
-            [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
+            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
             dtype=torch.float32,
             device=R.device,
         )
-        cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...])
+        cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...]).permute((0, 2, 1))
+        norm_mat = torch.tensor(
+            [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
+            dtype=torch.float32,
+            device=R.device,
+        )
+        cam_rot = torch.matmul(norm_mat, cam_rot)
         cam_pos = torch.flatten(torch.matmul(cam_rot, T[..., None]))
         cam_rot = torch.flatten(matrix_to_rotation_6d(cam_rot))
         return cam_pos, cam_rot
@@ -374,7 +411,7 @@ class PulsarPointsRenderer(nn.Module):
             )
         else:
             point_dists = torch.norm((vert_pos - cam_pos), p=2, dim=1, keepdim=False)
-            vert_rad = raster_rad / focal_length * point_dists
+            vert_rad = raster_rad / focal_length.to(vert_pos.device) * point_dists
         if isinstance(self.rasterizer.cameras, PerspectiveCameras):
            # NDC normalization happens through adjusted focal length.
             pass
@@ -382,6 +419,7 @@ class PulsarPointsRenderer(nn.Module):
             vert_rad = vert_rad / 2.0  # NDC normalization.
         return vert_rad
 
+    # point_clouds is not typed to avoid a cyclic dependency.
     def forward(self, point_clouds, **kwargs) -> torch.Tensor:
         """
         Get the rendering of the provided `Pointclouds`.
@@ -439,6 +477,8 @@ class PulsarPointsRenderer(nn.Module):
         for cloud_idx, (vert_pos, vert_col) in enumerate(
             zip(position_list, features_list)
         ):
+            # Get extrinsics.
+            cam_pos, cam_rot = self._extract_extrinsics(kwargs, cloud_idx)
             # Get intrinsics.
             (
                 focal_length,
@@ -447,23 +487,21 @@ class PulsarPointsRenderer(nn.Module):
                 principal_point_y,
                 znear,
                 zfar,
-            ) = self._extract_intrinsics(orthogonal_projection, kwargs, cloud_idx)
-            # Get extrinsics.
-            cam_pos, cam_rot = self._extract_extrinsics(kwargs, cloud_idx)
+            ) = self._extract_intrinsics(
+                orthogonal_projection, kwargs, cloud_idx, cam_pos.device
+            )
             # Put everything together.
             cam_params = torch.cat(
                 (
                     cam_pos,
-                    cam_rot,
-                    torch.tensor(
+                    cam_rot.to(cam_pos.device),
+                    torch.cat(
                         [
                             focal_length,
                             sensor_width,
                             principal_point_x,
                             principal_point_y,
                         ],
-                        dtype=torch.float32,
-                        device=cam_pos.device,
                     ),
                 )
             )
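A unifying thread in the hunks above is that camera intrinsics which used to be plain Python floats (focal length, sensor width, principal point, znear, zfar) are now carried as one-element float32 tensors on the camera device, so they can be concatenated with the extrinsics into a single camera parameter tensor without device mismatches. The following standalone sketch, with made-up values rather than code from this commit, shows the conversion the new helper performs.

import torch


def ensure_float_tensor(val_in, device):
    # Same conversion as the _ensure_float_tensor helper added above: wrap Python
    # floats and existing tensors alike as shape-(1,) float32 tensors on `device`.
    if not isinstance(val_in, torch.Tensor):
        return torch.tensor(val_in, dtype=torch.float32, device=device).reshape((1,))
    return val_in.to(torch.float32).to(device).reshape((1,))


device = torch.device("cpu")  # hypothetical target device
focal_length = ensure_float_tensor(0.01, device)                # from a Python float
sensor_width = ensure_float_tensor(torch.tensor(0.02), device)  # from an existing tensor
# With every intrinsic a 1-element tensor on one device, they can be concatenated
# directly into the camera parameter vector that is handed to pulsar.
intrinsics = torch.cat([focal_length, sensor_width])
print(intrinsics)  # tensor([0.0100, 0.0200])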