pulsar interface unification.
Summary: This diff builds on top of the `pulsar integration` diff to provide a unified interface for the existing PyTorch3D point renderer and Pulsar. For more information about the pulsar backend, see the release notes and the paper (https://arxiv.org/abs/2004.07484). For information on how to use the backend, see the point cloud rendering notebook and the examples in the folder docs/examples. The unified interface is fully consistent with the existing one: switching the render backend is as easy as using `renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)` instead of `renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)` and adding the `gamma` parameter to the forward function. All PyTorch3D camera types are supported as far as possible, and keyword arguments are properly forwarded to the cameras. The `PerspectiveCameras` and `OrthographicCameras` types additionally require `znear` and `zfar` as parameters for the forward pass.

Reviewed By: nikhilaravi

Differential Revision: D21421443

fbshipit-source-id: 4aa0a83a419592d9a0bb5d62486a1cdea9d73ce6
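A minimal sketch of the switch described above (the camera, image size, and `gamma`/`znear`/`zfar` values are illustrative; the imports mirror the test file in this diff):

    import torch
    from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
    from pytorch3d.renderer.points import (
        AlphaCompositor,
        PointsRasterizationSettings,
        PointsRasterizer,
        PointsRenderer,
        PulsarPointsRenderer,
    )

    device = torch.device("cuda")
    R, T = look_at_view_transform(2.7, 0.0, 0.0)
    cameras = PerspectiveCameras(device=device, R=R, T=T)
    raster_settings = PointsRasterizationSettings(image_size=256, radius=5e-2)
    rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)

    # Existing PyTorch3D point renderer:
    renderer = PointsRenderer(rasterizer=rasterizer, compositor=AlphaCompositor())
    # images = renderer(point_clouds)

    # Pulsar backend through the unified interface; `gamma` is required, and
    # PerspectiveCameras/OrthographicCameras additionally need znear/zfar:
    renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)
    # images = renderer(point_clouds, gamma=(1e-4,), znear=(1.0,), zfar=(100.0,))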
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

 from .compositor import AlphaCompositor, NormWeightedCompositor
+from .pulsar.unified import PulsarPointsRenderer
 from .rasterize_points import rasterize_points
 from .rasterizer import PointsRasterizationSettings, PointsRasterizer
 from .renderer import PointsRenderer
@@ -369,6 +369,7 @@ class Renderer(torch.nn.Module):
         height: int,
         orthogonal: bool,
         right_handed: bool,
+        first_R_then_T: bool = False,
     ) -> Tuple[
         torch.Tensor,
         torch.Tensor,
@@ -401,6 +402,8 @@ class Renderer(torch.nn.Module):
                 (does not use focal length).
             * right_handed: bool, whether to use a right handed system
                 (negative z in camera direction).
+            * first_R_then_T: bool, whether to first rotate, then translate
+                the camera (PyTorch3D convention).

         Returns:
             * pos_vec: the position vector in 3D,
@@ -460,16 +463,18 @@ class Renderer(torch.nn.Module):
         # Always get quadratic pixels.
         pixel_size_x = sensor_size_x / float(width)
         sensor_size_y = height * pixel_size_x
+        if continuous_rep:
+            rot_mat = rotation_6d_to_matrix(rot_vec)
+        else:
+            rot_mat = axis_angle_to_matrix(rot_vec)
+        if first_R_then_T:
+            pos_vec = torch.matmul(rot_mat, pos_vec[..., None])[:, :, 0]
         LOGGER.debug(
             "Camera position: %s, rotation: %s. Focal length: %s.",
             str(pos_vec),
             str(rot_vec),
             str(focal_length),
         )
-        if continuous_rep:
-            rot_mat = rotation_6d_to_matrix(rot_vec)
-        else:
-            rot_mat = axis_angle_to_matrix(rot_vec)
         sensor_dir_x = torch.matmul(
             rot_mat,
             torch.tensor(
@@ -576,6 +581,7 @@ class Renderer(torch.nn.Module):
         max_n_hits: int = _C.MAX_UINT,
         mode: int = 0,
         return_forward_info: bool = False,
+        first_R_then_T: bool = False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """
         Rendering pass to create an image from the provided spheres and camera
@@ -616,6 +622,8 @@ class Renderer(torch.nn.Module):
             the float encoded integer index of a sphere and its weight. They are the
             five spheres with the highest color contribution to this pixel color,
             ordered descending. Default: False.
+          * first_R_then_T: bool, whether to first apply rotation to the camera,
+            then translation (PyTorch3D convention). Default: False.

         Returns:
           * image: [Bx]HxWx3 float tensor with the resulting image.
@@ -638,6 +646,7 @@ class Renderer(torch.nn.Module):
             self._renderer.height,
             self._renderer.orthogonal,
             self._renderer.right_handed,
+            first_R_then_T=first_R_then_T,
         )
         if (
             focal_lengths.min().item() > 0.0
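For intuition on the new flag, a small hedged sketch of what it changes, mirroring the added `pos_vec` line above (shapes simplified to a single, unbatched camera; values illustrative):

    import torch

    rot_mat = torch.eye(3)                   # camera rotation
    pos_vec = torch.tensor([0.0, 0.0, 2.7])  # translation parameter

    # Default pulsar convention: pos_vec is used directly as the camera position.
    cam_pos = pos_vec
    # first_R_then_T=True (PyTorch3D convention): rotate first, then translate,
    # so the effective camera position becomes R @ pos_vec.
    cam_pos_p3d = torch.matmul(rot_mat, pos_vec[..., None])[:, 0]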
pytorch3d/renderer/points/pulsar/unified.py (new file, 524 lines)
@@ -0,0 +1,524 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+import math
+import warnings
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ....transforms import matrix_to_rotation_6d
+from ...cameras import (
+    FoVOrthographicCameras,
+    FoVPerspectiveCameras,
+    OrthographicCameras,
+    PerspectiveCameras,
+)
+from ..compositor import AlphaCompositor, NormWeightedCompositor
+from ..rasterizer import PointsRasterizer
+from .renderer import Renderer as PulsarRenderer
+
+
+class PulsarPointsRenderer(nn.Module):
+    """
+    This renderer is a PyTorch3D interface wrapper around the pulsar renderer.
+
+    It provides an interface consistent with PyTorch3D Pointcloud rendering.
+    It will extract all necessary information from the rasterizer and compositor
+    objects and convert them to the pulsar required format, then invoke rendering
+    in the pulsar renderer. All gradients are handled appropriately through the
+    wrapper and the wrapper should provide equivalent results to using the pulsar
+    renderer directly.
+    """
+
+    def __init__(
+        self,
+        rasterizer: PointsRasterizer,
+        compositor: Optional[Union[NormWeightedCompositor, AlphaCompositor]] = None,
+        n_channels: int = 3,
+        max_num_spheres: int = int(1e6),  # noqa: B008
+    ):
+        """
+        rasterizer (PointsRasterizer): An object encapsulating rasterization parameters.
+        compositor (ignored): Only keeping this for interface consistency. Default: None.
+        n_channels (int): The number of channels of the resulting image. Default: 3.
+        max_num_spheres (int): The maximum number of spheres intended to render with
+            this renderer. Default: 1e6.
+        """
+        super().__init__()
+        self.rasterizer = rasterizer
+        if compositor is not None:
+            warnings.warn(
+                "Creating a `PulsarPointsRenderer` with a compositor object! "
+                "This object is ignored and just allowed as an argument for interface "
+                "compatibility."
+            )
+        # Initialize the pulsar renderers.
+        if not isinstance(
+            rasterizer.cameras,
+            (
+                FoVOrthographicCameras,
+                FoVPerspectiveCameras,
+                PerspectiveCameras,
+                OrthographicCameras,
+            ),
+        ):
+            raise ValueError(
+                "Only FoVPerspectiveCameras, PerspectiveCameras, "
+                "FoVOrthographicCameras and OrthographicCameras are supported "
+                "by the pulsar backend."
+            )
+        if isinstance(rasterizer.raster_settings.image_size, tuple):
+            width, height = rasterizer.raster_settings.image_size
+        else:
+            width = rasterizer.raster_settings.image_size
+            height = rasterizer.raster_settings.image_size
+        # Making sure about integer types.
+        width = int(width)
+        height = int(height)
+        max_num_spheres = int(max_num_spheres)
+        orthogonal_projection = isinstance(
+            rasterizer.cameras, (FoVOrthographicCameras, OrthographicCameras)
+        )
+        n_channels = int(n_channels)
+        self.renderer = PulsarRenderer(
+            width=width,
+            height=height,
+            max_num_balls=max_num_spheres,
+            orthogonal_projection=orthogonal_projection,
+            right_handed_system=True,
+            n_channels=n_channels,
+        )
+
+    def _conf_check(self, point_clouds, kwargs: Dict[str, Any]) -> bool:
+        """
+        Verify internal configuration state with kwargs and pointclouds.
+
+        This method will raise ValueError's for any inconsistencies found. It
+        returns whether an orthogonal projection will be used.
+        """
+        if "gamma" not in kwargs.keys():
+            raise ValueError(
+                "gamma is a required keyword argument for the PulsarPointsRenderer!"
+            )
+        if (
+            len(point_clouds) != len(self.rasterizer.cameras)
+            and len(self.rasterizer.cameras) != 1
+        ):
+            raise ValueError(
+                (
+                    "The len(point_clouds) must either be equal to len(rasterizer.cameras) or "
+                    "only one camera must be used. len(point_clouds): %d, "
+                    "len(rasterizer.cameras): %d."
+                )
+                % (
+                    len(point_clouds),
+                    len(self.rasterizer.cameras),
+                )
+            )
+        # Make sure the rasterizer and cameras objects have no
+        # changes that can't be matched.
+        orthogonal_projection = isinstance(
+            self.rasterizer.cameras, (FoVOrthographicCameras, OrthographicCameras)
+        )
+        if orthogonal_projection != self.renderer._renderer.orthogonal:
+            raise ValueError(
+                "The camera type can not be changed after renderer initialization! "
+                "Current camera orthogonal: %r. Original orthogonal: %r."
+                % (orthogonal_projection, self.renderer._renderer.orthogonal)
+            )
+        if (
+            isinstance(self.rasterizer.raster_settings.image_size, tuple)
+            and self.rasterizer.raster_settings.image_size[0]
+            != self.renderer._renderer.width
+        ) or (
+            not isinstance(self.rasterizer.raster_settings.image_size, tuple)
+            and self.rasterizer.raster_settings.image_size
+            != self.renderer._renderer.width
+        ):
+            raise ValueError(
+                (
+                    "The rasterizer width and height can not be changed after renderer "
+                    "initialization! Current width: %d. Original width: %d."
+                )
+                % (
+                    self.rasterizer.raster_settings.image_size,
+                    self.renderer._renderer.width,
+                )
+            )
+        if (
+            isinstance(self.rasterizer.raster_settings.image_size, tuple)
+            and self.rasterizer.raster_settings.image_size[1]
+            != self.renderer._renderer.height
+        ) or (
+            not isinstance(self.rasterizer.raster_settings.image_size, tuple)
+            and self.rasterizer.raster_settings.image_size
+            != self.renderer._renderer.height
+        ):
+            raise ValueError(
+                (
+                    "The rasterizer width and height can not be changed after renderer "
+                    "initialization! Current height: %d. Original height: %d."
+                )
+                % (
+                    self.rasterizer.raster_settings.image_size,
+                    self.renderer._renderer.height,
+                )
+            )
+        return orthogonal_projection
+
+    def _extract_intrinsics(
+        self, orthogonal_projection, kwargs, cloud_idx
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, float]:
+        """
+        Translate the camera intrinsics from PyTorch3D format to pulsar format.
+        """
+        # Shorthand:
+        cameras = self.rasterizer.cameras
+        if orthogonal_projection:
+            focal_length = 0.0
+            if isinstance(cameras, FoVOrthographicCameras):
+                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `znear`.
+                znear = kwargs.get("znear", cameras.znear)[cloud_idx]
+                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `zfar`.
+                zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx]
+                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `max_y`.
+                max_y = kwargs.get("max_y", cameras.max_y)[cloud_idx]
+                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `min_y`.
+                min_y = kwargs.get("min_y", cameras.min_y)[cloud_idx]
+                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `max_x`.
+                max_x = kwargs.get("max_x", cameras.max_x)[cloud_idx]
+                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `min_x`.
+                min_x = kwargs.get("min_x", cameras.min_x)[cloud_idx]
+                if max_y != -min_y:
+                    raise ValueError(
+                        "The orthographic camera must be centered around 0. "
+                        f"Max is {max_y} and min is {min_y}."
+                    )
+                if max_x != -min_x:
+                    raise ValueError(
+                        "The orthographic camera must be centered around 0. "
+                        f"Max is {max_x} and min is {min_x}."
+                    )
+                if not torch.all(
+                    # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `scale_xyz`.
+                    kwargs.get("scale_xyz", cameras.scale_xyz)[cloud_idx]
+                    == 1.0
+                ):
+                    raise ValueError(
+                        "The orthographic camera scale must be ((1.0, 1.0, 1.0),). "
+                        f"{kwargs.get('scale_xyz', cameras.scale_xyz)[cloud_idx]}."
+                    )
+                sensor_width = max_x - min_x
+                if not sensor_width > 0.0:
+                    raise ValueError(
+                        f"The orthographic camera must have positive size! Is: {sensor_width}."  # noqa: B950
+                    )
+                principal_point_x, principal_point_y = 0.0, 0.0
+            else:
+                # Currently, this means it must be an 'OrthographicCameras' object.
+                focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
+                    cloud_idx
+                ]
+                if (
+                    focal_length_conf.numel() == 2
+                    and focal_length_conf[0] * self.renderer._renderer.width
+                    - focal_length_conf[1] * self.renderer._renderer.height
+                    > 1e-5
+                ):
+                    raise ValueError(
+                        "Pulsar only supports a single focal length! "
+                        "Provided: %s." % (str(focal_length_conf))
+                    )
+                if focal_length_conf.numel() == 2:
+                    sensor_width = 2.0 / focal_length_conf[0]
+                else:
+                    if focal_length_conf.numel() != 1:
+                        raise ValueError(
+                            "Focal length not parsable: %s." % (str(focal_length_conf))
+                        )
+                    sensor_width = 2.0 / focal_length_conf
+                if "znear" not in kwargs.keys() or "zfar" not in kwargs.keys():
+                    raise ValueError(
+                        "pulsar needs znear and zfar values for "
+                        "the OrthographicCameras. Please provide them as keyword "
+                        "arguments to the forward method."
+                    )
+                znear = kwargs["znear"][cloud_idx]
+                zfar = kwargs["zfar"][cloud_idx]
+                principal_point_x = (
+                    kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
+                    * 0.5
+                    * self.renderer._renderer.width
+                )
+                principal_point_y = (
+                    kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1]
+                    * 0.5
+                    * self.renderer._renderer.height
+                )
+        else:
+            if not isinstance(cameras, PerspectiveCameras):
+                # Create a virtual focal length that is closer than znear.
+                znear = kwargs.get("znear", cameras.znear)[cloud_idx]
+                zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx]
+                focal_length = znear - 1e-6
+                # Create a sensor size that matches the expected fov assuming this f.
+                afov = kwargs.get("fov", cameras.fov)[cloud_idx]
+                if kwargs.get("degrees", cameras.degrees):
+                    afov *= math.pi / 180.0
+                sensor_width = math.tan(afov / 2.0) * 2.0 * focal_length
+                if not (
+                    kwargs.get("aspect_ratio", cameras.aspect_ratio)[cloud_idx]
+                    - self.renderer._renderer.width / self.renderer._renderer.height
+                    < 1e-6
+                ):
+                    raise ValueError(
+                        "The aspect ratio ("
+                        f"{kwargs.get('aspect_ratio', cameras.aspect_ratio)[cloud_idx]}) "
+                        "must agree with the resolution width / height ("
+                        f"{self.renderer._renderer.width / self.renderer._renderer.height})."  # noqa: B950
+                    )
+                principal_point_x, principal_point_y = 0.0, 0.0
+            else:
+                # pyre-fixme[16]: `PerspectiveCameras` has no attribute `focal_length`.
+                focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
+                    cloud_idx
+                ]
+                if (
+                    focal_length_conf.numel() == 2
+                    and focal_length_conf[0] * self.renderer._renderer.width
+                    - focal_length_conf[1] * self.renderer._renderer.height
+                    > 1e-5
+                ):
+                    raise ValueError(
+                        "Pulsar only supports a single focal length! "
+                        "Provided: %s." % (str(focal_length_conf))
+                    )
+                if "znear" not in kwargs.keys() or "zfar" not in kwargs.keys():
+                    raise ValueError(
+                        "pulsar needs znear and zfar values for "
+                        "the PerspectiveCameras. Please provide them as keyword "
+                        "arguments to the forward method."
+                    )
+                znear = kwargs["znear"][cloud_idx]
+                zfar = kwargs["zfar"][cloud_idx]
+                if focal_length_conf.numel() == 2:
+                    focal_length_px = focal_length_conf[0]
+                else:
+                    if focal_length_conf.numel() != 1:
+                        raise ValueError(
+                            "Focal length not parsable: %s." % (str(focal_length_conf))
+                        )
+                    focal_length_px = focal_length_conf
+                focal_length = znear - 1e-6
+                sensor_width = focal_length / focal_length_px * 2.0
+                principal_point_x = (
+                    # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
+                    kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
+                    * 0.5
+                    * self.renderer._renderer.width
+                )
+                principal_point_y = (
+                    kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1]
+                    * 0.5
+                    * self.renderer._renderer.height
+                )
+        return (
+            focal_length,
+            sensor_width,
+            principal_point_x,
+            principal_point_y,
+            znear,
+            zfar,
+        )
+
+    def _extract_extrinsics(
+        self, kwargs, cloud_idx
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Shorthand:
+        cameras = self.rasterizer.cameras
+        R = kwargs.get("R", cameras.R)[cloud_idx]
+        T = kwargs.get("T", cameras.T)[cloud_idx]
+        norm_mat = torch.tensor(
+            [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
+            dtype=torch.float32,
+            device=R.device,
+        )
+        cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...])
+        cam_pos = torch.flatten(torch.matmul(cam_rot, T[..., None]))
+        cam_rot = torch.flatten(matrix_to_rotation_6d(cam_rot))
+        return cam_pos, cam_rot
+
+    def _get_vert_rad(
+        self, vert_pos, cam_pos, orthogonal_projection, focal_length, kwargs, cloud_idx
+    ) -> torch.Tensor:
+        """
+        Get point radiuses.
+
+        These can depend on the camera position in case of a perspective
+        transform.
+        """
+        # Normalize point radiuses.
+        # `self.rasterizer.raster_settings.radius` can either be a float
+        # or itself a tensor.
+        raster_rad = self.rasterizer.raster_settings.radius
+        if kwargs.get("radius_world", False):
+            return raster_rad
+        if isinstance(raster_rad, torch.Tensor) and raster_rad.numel() > 1:
+            # In this case it must be a batched torch tensor.
+            raster_rad = raster_rad[cloud_idx]
+        if orthogonal_projection:
+            vert_rad = (
+                torch.ones(
+                    (vert_pos.shape[0],), dtype=torch.float32, device=vert_pos.device
+                )
+                * raster_rad
+            )
+        else:
+            point_dists = torch.norm((vert_pos - cam_pos), p=2, dim=1, keepdim=False)
+            vert_rad = raster_rad / focal_length * point_dists
+            if isinstance(self.rasterizer.cameras, PerspectiveCameras):
+                # NDC normalization happens through adjusted focal length.
+                pass
+            else:
+                vert_rad = vert_rad / 2.0  # NDC normalization.
+        return vert_rad
+
+    def forward(self, point_clouds, **kwargs) -> torch.Tensor:
+        """
+        Get the rendering of the provided `Pointclouds`.
+
+        The number of point clouds in the `Pointclouds` object determines the
+        number of resulting images. The provided cameras can be either 1 or equal
+        to the number of pointclouds (in the first case, the same camera will be
+        used for all clouds, in the latter case each point cloud will be rendered
+        with the corresponding camera).
+
+        The following kwargs are supported from PyTorch3D (depending on the selected
+        camera model, potentially overriding camera parameters):
+            radius_world (bool): use the provided radiuses from the raster_settings
+                plain as radiuses in world space. Default: False.
+            znear (Iterable[float]): near geometry cutoff. Is required for
+                OrthographicCameras and PerspectiveCameras.
+            zfar (Iterable[float]): far geometry cutoff. Is required for
+                OrthographicCameras and PerspectiveCameras.
+            R (torch.Tensor): [Bx3x3] camera rotation matrices.
+            T (torch.Tensor): [Bx3] camera translation vectors.
+            principal_point (torch.Tensor): [Bx2] camera intrinsic principal
+                point offset vectors.
+            focal_length (torch.Tensor): [Bx1] camera intrinsic focal lengths.
+            aspect_ratio (Iterable[float]): camera aspect ratios.
+            fov (Iterable[float]): camera FOVs.
+            degrees (bool): whether FOVs are specified in degrees or
+                radians.
+            min_x (Iterable[float]): minimum x for the FoVOrthographicCameras.
+            max_x (Iterable[float]): maximum x for the FoVOrthographicCameras.
+            min_y (Iterable[float]): minimum y for the FoVOrthographicCameras.
+            max_y (Iterable[float]): maximum y for the FoVOrthographicCameras.
+
+        The following kwargs are supported from pulsar:
+            gamma (float): The gamma value to use. This defines the transparency for
+                differentiability (see pulsar paper for details). Must be in [1., 1e-5]
+                with 1.0 being mostly transparent. This keyword argument is *required*!
+            bg_col (torch.Tensor): The background color. Must be a tensor on the same
+                device as the point clouds, with as many channels as features (no batch
+                dimension - it is the same for all images in the batch).
+                Default: 0.0 for all channels.
+            percent_allowed_difference (float): a value in [0., 1.[ with the maximum
+                allowed difference in channel space. This is used to speed up the
+                computation. Default: 0.01.
+            max_n_hits (int): a hard limit on the number of sphere hits per ray.
+                Default: max int.
+            mode (int): render mode in {0, 1}. 0: render image; 1: render hit map.
+        """
+        orthogonal_projection: bool = self._conf_check(point_clouds, kwargs)
+        # Get access to inputs. We're using the list accessor and process
+        # them sequentially.
+        position_list = point_clouds.points_list()
+        features_list = point_clouds.features_list()
+        # Result list.
+        images = []
+        for cloud_idx, (vert_pos, vert_col) in enumerate(
+            zip(position_list, features_list)
+        ):
+            # Get intrinsics.
+            (
+                focal_length,
+                sensor_width,
+                principal_point_x,
+                principal_point_y,
+                znear,
+                zfar,
+            ) = self._extract_intrinsics(orthogonal_projection, kwargs, cloud_idx)
+            # Get extrinsics.
+            cam_pos, cam_rot = self._extract_extrinsics(kwargs, cloud_idx)
+            # Put everything together.
+            cam_params = torch.cat(
+                (
+                    cam_pos,
+                    cam_rot,
+                    torch.tensor(
+                        [
+                            focal_length,
+                            sensor_width,
+                            principal_point_x,
+                            principal_point_y,
+                        ],
+                        dtype=torch.float32,
+                        device=cam_pos.device,
+                    ),
+                )
+            )
+            # Get point radiuses (can depend on camera position).
+            vert_rad = self._get_vert_rad(
+                vert_pos,
+                cam_pos,
+                orthogonal_projection,
+                focal_length,
+                kwargs,
+                cloud_idx,
+            )
+            # Clean kwargs for passing on.
+            gamma = kwargs["gamma"][cloud_idx]
+            if "first_R_then_T" in kwargs.keys():
+                raise ValueError("`first_R_then_T` is not supported in this interface.")
+            otherargs = {
+                argn: argv
+                for argn, argv in kwargs.items()
+                if argn
+                not in [
+                    "radius_world",
+                    "gamma",
+                    "znear",
+                    "zfar",
+                    "R",
+                    "T",
+                    "principal_point",
+                    "focal_length",
+                    "aspect_ratio",
+                    "fov",
+                    "degrees",
+                    "min_x",
+                    "max_x",
+                    "min_y",
+                    "max_y",
+                ]
+            }
+            # background color
+            if "bg_col" not in otherargs:
+                bg_col = torch.zeros(
+                    vert_col.shape[1], device=cam_params.device, dtype=torch.float32
+                )
+                otherargs["bg_col"] = bg_col
+            # Go!
+            images.append(
+                self.renderer(
+                    vert_pos=vert_pos,
+                    vert_col=vert_col,
+                    vert_rad=vert_rad,
+                    cam_params=cam_params,
+                    gamma=gamma,
+                    max_depth=zfar,
+                    min_depth=znear,
+                    **otherargs,
+                )
+            )
+        return torch.stack(images, dim=0)
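As a worked sketch of the perspective intrinsics conversion implemented in `_extract_intrinsics` above (all values illustrative; `focal_length_ndc` stands in for the NDC focal length stored on `PerspectiveCameras`):

    # Hypothetical numbers, mirroring the PerspectiveCameras branch above.
    width, height = 256, 256
    focal_length_ndc = 2.0  # NDC-space focal length of the camera
    znear = 1.0

    # A virtual focal plane is placed just inside znear ...
    focal_length = znear - 1e-6
    # ... and the sensor width is chosen so the field of view matches:
    sensor_width = focal_length / focal_length_ndc * 2.0  # ~= 1.0
    # Principal point offsets are rescaled from NDC to pixels:
    principal_point_x = 0.1 * 0.5 * width  # NDC offset 0.1 -> 12.8 px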
@@ -1,7 +1,5 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-
-
 import torch
 import torch.nn as nn
 
@@ -48,7 +46,7 @@ class PointsRenderer(nn.Module):
             fragments.idx.long().permute(0, 3, 1, 2),
             weights,
             point_clouds.features_packed().permute(1, 0),
-            **kwargs
+            **kwargs,
         )
 
         # permute so image comes at the end
@@ -1,9 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-from .external.kornia_angle_axis_to_rotation_matrix import (
-    angle_axis_to_rotation_matrix as axis_angle_to_matrix,
-)
 from .rotation_conversions import (
+    axis_angle_to_matrix,
     euler_angles_to_matrix,
     matrix_to_euler_angles,
     matrix_to_quaternion,
pytorch3d/transforms/external/__init__.py (deleted, vendored, 1 line)
@@ -1 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py (deleted, 94 lines)
@@ -1,94 +0,0 @@
-#!/usr/bin/env python3
-"""
-This file contains the great angle axis to rotation matrix conversion
-from kornia (https://github.com/arraiyopensource/kornia). The license
-can be found in kornia_license.txt.
-
-The method is used unchanged; the documentation has been adjusted
-to match our doc format.
-"""
-import torch
-
-
-def angle_axis_to_rotation_matrix(angle_axis):
-    """Convert 3d vector of axis-angle rotation to 4x4 rotation matrix
-
-    Args:
-        angle_axis (Tensor): tensor of 3d vector of axis-angle rotations.
-
-    Returns:
-        Tensor: tensor of 3x3 rotation matrix.
-
-    Shape:
-        - Input: :math:`(N, 3)`
-        - Output: :math:`(N, 3, 3)`
-
-    Example:
-
-        ..code-block::python
-
-        >>> input = torch.rand(1, 3)  # Nx3
-        >>> output = tgm.angle_axis_to_rotation_matrix(input)  # Nx3x3
-        >>> output = tgm.angle_axis_to_rotation_matrix(input)  # Nx3x3
-    """
-
-    def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6):
-        # We want to be careful to only evaluate the square root if the
-        # norm of the angle_axis vector is greater than zero. Otherwise
-        # we get a division by zero.
-        k_one = 1.0
-        theta = torch.sqrt(theta2)
-        wxyz = angle_axis / (theta + eps)
-        wx, wy, wz = torch.chunk(wxyz, 3, dim=1)
-        cos_theta = torch.cos(theta)
-        sin_theta = torch.sin(theta)
-
-        r00 = cos_theta + wx * wx * (k_one - cos_theta)
-        r10 = wz * sin_theta + wx * wy * (k_one - cos_theta)
-        r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta)
-        r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta
-        r11 = cos_theta + wy * wy * (k_one - cos_theta)
-        r21 = wx * sin_theta + wy * wz * (k_one - cos_theta)
-        r02 = wy * sin_theta + wx * wz * (k_one - cos_theta)
-        r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta)
-        r22 = cos_theta + wz * wz * (k_one - cos_theta)
-        rotation_matrix = torch.cat(
-            [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1
-        )
-        return rotation_matrix.view(-1, 3, 3)
-
-    def _compute_rotation_matrix_taylor(angle_axis):
-        rx, ry, rz = torch.chunk(angle_axis, 3, dim=1)
-        k_one = torch.ones_like(rx)
-        rotation_matrix = torch.cat(
-            [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1
-        )
-        return rotation_matrix.view(-1, 3, 3)
-
-    # stolen from ceres/rotation.h
-
-    _angle_axis = torch.unsqueeze(angle_axis + 1e-6, dim=1)
-    # _angle_axis.register_hook(lambda grad: pdb.set_trace())
-    # _angle_axis = 1e-6
-    theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2))
-    theta2 = torch.squeeze(theta2, dim=1)
-
-    # compute rotation matrices
-    rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2)
-    rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis)
-
-    # create mask to handle both cases
-    eps = 1e-6
-    mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device)
-    mask_pos = (mask).type_as(theta2)
-    mask_neg = (mask == False).type_as(theta2)  # noqa
-
-    # create output pose matrix
-    batch_size = angle_axis.shape[0]
-    rotation_matrix = torch.eye(3).to(angle_axis.device).type_as(angle_axis)
-    rotation_matrix = rotation_matrix.view(1, 3, 3).repeat(batch_size, 1, 1)
-    # fill output matrix with masked values
-    rotation_matrix[..., :3, :3] = (
-        mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor
-    )
-    return rotation_matrix.to(angle_axis.device).type_as(angle_axis)  # Nx4x4
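The vendored conversion above is superseded by PyTorch3D's own `axis_angle_to_matrix` (now exported from `pytorch3d.transforms`, as the `__init__.py` change above shows). A quick hedged usage sketch:

    import math

    import torch
    from pytorch3d.transforms import axis_angle_to_matrix

    axis_angle = torch.tensor([[0.0, 0.0, math.pi / 2.0]])  # 90 deg about z
    R = axis_angle_to_matrix(axis_angle)  # (1, 3, 3) rotation matrices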
pytorch3d/transforms/external/kornia_license.txt (deleted, vendored, 201 lines)
@@ -1,201 +0,0 @@
-(standard Apache License, Version 2.0 text, removed together with the vendored kornia file)
setup.py (20 lines changed)
@@ -13,13 +13,8 @@ from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
 def get_extensions():
     this_dir = os.path.dirname(os.path.abspath(__file__))
     extensions_dir = os.path.join(this_dir, "pytorch3d", "csrc")
-    main_source = os.path.join(extensions_dir, "ext.cpp")
-    sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
-    source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu"))
-
-    sources = [main_source] + sources
-
+    sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)
+    source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu"), recursive=True)
     extension = CppExtension
 
     extra_compile_args = {"cxx": ["-std=c++14"]}
@@ -30,7 +25,18 @@ def get_extensions():
         extension = CUDAExtension
         sources += source_cuda
         define_macros += [("WITH_CUDA", None)]
+        cub_home = os.environ.get("CUB_HOME", None)
+        if cub_home is None:
+            raise Exception(
+                "The environment variable `CUB_HOME` was not found. "
+                "NVIDIA CUB is required for compilation and can be downloaded "
+                "from `https://github.com/NVIDIA/cub/releases`. You can unpack "
+                "it to a location of your choice and set the environment variable "
+                "`CUB_HOME` to the folder containing the `CMakeLists.txt` file."
+            )
         nvcc_args = [
+            "-I%s" % (os.path.realpath(cub_home).replace("\\ ", " ")),
+            "-std=c++14",
             "-DCUDA_HAS_FP16=1",
             "-D__CUDA_NO_HALF_OPERATORS__",
             "-D__CUDA_NO_HALF_CONVERSIONS__",
(8 binary PNG reference images added to the test data directory for the new pulsar tests)
@@ -16,6 +16,8 @@ from PIL import Image
 from pytorch3d.renderer.cameras import (
     FoVOrthographicCameras,
     FoVPerspectiveCameras,
+    OrthographicCameras,
+    PerspectiveCameras,
     look_at_view_transform,
 )
 from pytorch3d.renderer.compositing import alpha_composite, norm_weighted_sum
@@ -25,6 +27,7 @@ from pytorch3d.renderer.points import (
     PointsRasterizationSettings,
     PointsRasterizer,
     PointsRenderer,
+    PulsarPointsRenderer,
 )
 from pytorch3d.structures.pointclouds import Pointclouds
 from pytorch3d.utils.ico_sphere import ico_sphere
@@ -72,6 +75,145 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
         )
         self.assertClose(rgb, image_ref)
 
+    def test_simple_sphere_pulsar(self):
+        for device in [torch.device("cpu"), torch.device("cuda")]:
+            sphere_mesh = ico_sphere(1, device)
+            verts_padded = sphere_mesh.verts_padded()
+            # Shift vertices to check coordinate frames are correct.
+            verts_padded[..., 1] += 0.2
+            verts_padded[..., 0] += 0.2
+            pointclouds = Pointclouds(
+                points=verts_padded, features=torch.ones_like(verts_padded)
+            )
+            for azimuth in [0.0, 90.0]:
+                R, T = look_at_view_transform(2.7, 0.0, azimuth)
+                for camera_name, cameras in [
+                    ("fovperspective", FoVPerspectiveCameras(device=device, R=R, T=T)),
+                    (
+                        "fovorthographic",
+                        FoVOrthographicCameras(device=device, R=R, T=T),
+                    ),
+                    ("perspective", PerspectiveCameras(device=device, R=R, T=T)),
+                    ("orthographic", OrthographicCameras(device=device, R=R, T=T)),
+                ]:
+                    raster_settings = PointsRasterizationSettings(
+                        image_size=256, radius=5e-2, points_per_pixel=1
+                    )
+                    rasterizer = PointsRasterizer(
+                        cameras=cameras, raster_settings=raster_settings
+                    )
+                    renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)
+                    # Load reference image
+                    filename = (
+                        "pulsar_simple_pointcloud_sphere_"
+                        f"azimuth{azimuth}_{camera_name}.png"
+                    )
+                    image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
+                    images = renderer(
+                        pointclouds, gamma=(1e-3,), znear=(1.0,), zfar=(100.0,)
+                    )
+                    rgb = images[0, ..., :3].squeeze().cpu()
+                    if DEBUG:
+                        filename = "DEBUG_%s" % filename
+                        Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
+                            DATA_DIR / filename
+                        )
+                    self.assertClose(rgb, image_ref, rtol=7e-3, atol=5e-3)
+
+    def test_unified_inputs_pulsar(self):
+        # Test data on different devices.
+        for device in [torch.device("cpu"), torch.device("cuda")]:
+            sphere_mesh = ico_sphere(1, device)
+            verts_padded = sphere_mesh.verts_padded()
+            pointclouds = Pointclouds(
+                points=verts_padded, features=torch.ones_like(verts_padded)
+            )
+            R, T = look_at_view_transform(2.7, 0.0, 0.0)
+            # Test the different camera types.
+            for _, cameras in [
+                ("fovperspective", FoVPerspectiveCameras(device=device, R=R, T=T)),
+                (
+                    "fovorthographic",
+                    FoVOrthographicCameras(device=device, R=R, T=T),
+                ),
+                ("perspective", PerspectiveCameras(device=device, R=R, T=T)),
+                ("orthographic", OrthographicCameras(device=device, R=R, T=T)),
+            ]:
+                # Test different ways for image size specification.
+                for image_size in (256, (256, 256)):
+                    raster_settings = PointsRasterizationSettings(
+                        image_size=image_size, radius=5e-2, points_per_pixel=1
+                    )
+                    rasterizer = PointsRasterizer(
+                        cameras=cameras, raster_settings=raster_settings
+                    )
+                    # Test that the compositor can be provided. Its value is ignored
+                    # so use a dummy.
+                    _ = PulsarPointsRenderer(rasterizer=rasterizer, compositor=1).to(
+                        device
+                    )
+                    # Constructor without compositor.
+                    _ = PulsarPointsRenderer(rasterizer=rasterizer).to(device)
+                    # Constructor with n_channels.
+                    _ = PulsarPointsRenderer(rasterizer=rasterizer, n_channels=3).to(
+                        device
+                    )
+                    # Constructor with max_num_spheres.
+                    renderer = PulsarPointsRenderer(
+                        rasterizer=rasterizer, max_num_spheres=1000
+                    ).to(device)
+                    # Test the forward function.
+                    if isinstance(cameras, (PerspectiveCameras, OrthographicCameras)):
+                        # znear and zfar are required in this case.
+                        self.assertRaises(
+                            ValueError,
+                            lambda: renderer.forward(
+                                point_clouds=pointclouds, gamma=(1e-4,)
+                            ),
+                        )
+                        renderer.forward(
+                            point_clouds=pointclouds,
+                            gamma=(1e-4,),
+                            znear=(1.0,),
+                            zfar=(2.0,),
+                        )
+                        # znear and zfar must be batched.
+                        self.assertRaises(
+                            TypeError,
+                            lambda: renderer.forward(
+                                point_clouds=pointclouds,
+                                gamma=(1e-4,),
+                                znear=1.0,
+                                zfar=(2.0,),
+                            ),
+                        )
+                        self.assertRaises(
+                            TypeError,
+                            lambda: renderer.forward(
+                                point_clouds=pointclouds,
+                                gamma=(1e-4,),
+                                znear=(1.0,),
+                                zfar=2.0,
+                            ),
+                        )
+                    else:
+                        # gamma must be batched.
+                        self.assertRaises(
+                            TypeError,
+                            lambda: renderer.forward(
+                                point_clouds=pointclouds, gamma=1e-4
+                            ),
+                        )
+                        renderer.forward(point_clouds=pointclouds, gamma=(1e-4,))
+                    # rasterizer width and height change.
+                    renderer.rasterizer.raster_settings.image_size = 0
+                    self.assertRaises(
+                        ValueError,
+                        lambda: renderer.forward(
+                            point_clouds=pointclouds, gamma=(1e-4,)
+                        ),
+                    )
+
     def test_pointcloud_with_features(self):
         device = torch.device("cuda:0")
         file_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"