From 960fd6d8b6f55257dc1d205e8c8f3366202c23b7 Mon Sep 17 00:00:00 2001
From: Christoph Lassner
Date: Tue, 3 Nov 2020 13:05:02 -0800
Subject: [PATCH] pulsar interface unification.

Summary: This diff builds on top of the `pulsar integration` diff to provide a
unified interface for the existing PyTorch3D point renderer and Pulsar. For
more information about the pulsar backend, see the release notes and the paper
(https://arxiv.org/abs/2004.07484). For information on how to use the backend,
see the point cloud rendering notebook and the examples in the folder
docs/examples; two minimal usage sketches are also appended after the diff
below.

The unified interfaces are fully consistent: switching the render backend is as
easy as using `renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)`
instead of `renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)`
and adding the `gamma` parameter to the forward function. All PyTorch3D camera
types are supported as far as possible; keyword arguments are properly forwarded
to the camera. The `PerspectiveCameras` and `OrthographicCameras` require
`znear` and `zfar` as additional parameters for the forward pass.

Reviewed By: nikhilaravi

Differential Revision: D21421443

fbshipit-source-id: 4aa0a83a419592d9a0bb5d62486a1cdea9d73ce6
---
 pytorch3d/renderer/points/__init__.py         |   1 +
 pytorch3d/renderer/points/pulsar/renderer.py  |  17 +-
 pytorch3d/renderer/points/pulsar/unified.py   | 524 ++++++++++++++++++
 pytorch3d/renderer/points/renderer.py         |   4 +-
 pytorch3d/transforms/__init__.py              |   4 +-
 pytorch3d/transforms/external/__init__.py     |   1 -
 .../kornia_angle_axis_to_rotation_matrix.py   |  94 ----
 .../transforms/external/kornia_license.txt    | 201 -------
 setup.py                                      |  20 +-
 ...loud_sphere_azimuth0.0_fovorthographic.png | Bin 0 -> 1930 bytes
 ...cloud_sphere_azimuth0.0_fovperspective.png | Bin 0 -> 3373 bytes
 ...ntcloud_sphere_azimuth0.0_orthographic.png | Bin 0 -> 1930 bytes
 ...intcloud_sphere_azimuth0.0_perspective.png | Bin 0 -> 3070 bytes
 ...oud_sphere_azimuth90.0_fovorthographic.png | Bin 0 -> 2112 bytes
 ...loud_sphere_azimuth90.0_fovperspective.png | Bin 0 -> 3451 bytes
 ...tcloud_sphere_azimuth90.0_orthographic.png | Bin 0 -> 2112 bytes
 ...ntcloud_sphere_azimuth90.0_perspective.png | Bin 0 -> 3278 bytes
 tests/test_render_points.py                   | 142 +++++
 18 files changed, 695 insertions(+), 313 deletions(-)
 create mode 100644 pytorch3d/renderer/points/pulsar/unified.py
 delete mode 100644 pytorch3d/transforms/external/__init__.py
 delete mode 100644 pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py
 delete mode 100644 pytorch3d/transforms/external/kornia_license.txt
 create mode 100644 tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovorthographic.png
 create mode 100644 tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovperspective.png
 create mode 100644 tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_orthographic.png
 create mode 100644 tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_perspective.png
 create mode 100644 tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovorthographic.png
 create mode 100644 tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovperspective.png
 create mode 100644 tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_orthographic.png
 create mode 100644 tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_perspective.png

diff --git a/pytorch3d/renderer/points/__init__.py b/pytorch3d/renderer/points/__init__.py
index b334f4c5..f0ee9b02 100644
--- a/pytorch3d/renderer/points/__init__.py
+++ 
b/pytorch3d/renderer/points/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. from .compositor import AlphaCompositor, NormWeightedCompositor +from .pulsar.unified import PulsarPointsRenderer from .rasterize_points import rasterize_points from .rasterizer import PointsRasterizationSettings, PointsRasterizer from .renderer import PointsRenderer diff --git a/pytorch3d/renderer/points/pulsar/renderer.py b/pytorch3d/renderer/points/pulsar/renderer.py index 0e86d6aa..5f30f648 100644 --- a/pytorch3d/renderer/points/pulsar/renderer.py +++ b/pytorch3d/renderer/points/pulsar/renderer.py @@ -369,6 +369,7 @@ class Renderer(torch.nn.Module): height: int, orthogonal: bool, right_handed: bool, + first_R_then_T: bool = False, ) -> Tuple[ torch.Tensor, torch.Tensor, @@ -401,6 +402,8 @@ class Renderer(torch.nn.Module): (does not use focal length). * right_handed: bool, whether to use a right handed system (negative z in camera direction). + * first_R_then_T: bool, whether to first rotate, then translate + the camera (PyTorch3D convention). Returns: * pos_vec: the position vector in 3D, @@ -460,16 +463,18 @@ class Renderer(torch.nn.Module): # Always get quadratic pixels. pixel_size_x = sensor_size_x / float(width) sensor_size_y = height * pixel_size_x + if continuous_rep: + rot_mat = rotation_6d_to_matrix(rot_vec) + else: + rot_mat = axis_angle_to_matrix(rot_vec) + if first_R_then_T: + pos_vec = torch.matmul(rot_mat, pos_vec[..., None])[:, :, 0] LOGGER.debug( "Camera position: %s, rotation: %s. Focal length: %s.", str(pos_vec), str(rot_vec), str(focal_length), ) - if continuous_rep: - rot_mat = rotation_6d_to_matrix(rot_vec) - else: - rot_mat = axis_angle_to_matrix(rot_vec) sensor_dir_x = torch.matmul( rot_mat, torch.tensor( @@ -576,6 +581,7 @@ class Renderer(torch.nn.Module): max_n_hits: int = _C.MAX_UINT, mode: int = 0, return_forward_info: bool = False, + first_R_then_T: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: """ Rendering pass to create an image from the provided spheres and camera @@ -616,6 +622,8 @@ class Renderer(torch.nn.Module): the float encoded integer index of a sphere and its weight. They are the five spheres with the highest color contribution to this pixel color, ordered descending. Default: False. + * first_R_then_T: bool, whether to first apply rotation to the camera, + then translation (PyTorch3D convention). Default: False. Returns: * image: [Bx]HxWx3 float tensor with the resulting image. @@ -638,6 +646,7 @@ class Renderer(torch.nn.Module): self._renderer.height, self._renderer.orthogonal, self._renderer.right_handed, + first_R_then_T=first_R_then_T, ) if ( focal_lengths.min().item() > 0.0 diff --git a/pytorch3d/renderer/points/pulsar/unified.py b/pytorch3d/renderer/points/pulsar/unified.py new file mode 100644 index 00000000..08b10af6 --- /dev/null +++ b/pytorch3d/renderer/points/pulsar/unified.py @@ -0,0 +1,524 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+import math
+import warnings
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ....transforms import matrix_to_rotation_6d
+from ...cameras import (
+    FoVOrthographicCameras,
+    FoVPerspectiveCameras,
+    OrthographicCameras,
+    PerspectiveCameras,
+)
+from ..compositor import AlphaCompositor, NormWeightedCompositor
+from ..rasterizer import PointsRasterizer
+from .renderer import Renderer as PulsarRenderer
+
+
+class PulsarPointsRenderer(nn.Module):
+    """
+    This renderer is a PyTorch3D interface wrapper around the pulsar renderer.
+
+    It provides an interface consistent with PyTorch3D Pointcloud rendering.
+    It extracts all necessary information from the rasterizer and compositor
+    objects, converts it to the format pulsar requires, and then invokes
+    rendering in the pulsar renderer. All gradients are handled appropriately
+    through the wrapper, which should provide results equivalent to using the
+    pulsar renderer directly.
+    """
+
+    def __init__(
+        self,
+        rasterizer: PointsRasterizer,
+        compositor: Optional[Union[NormWeightedCompositor, AlphaCompositor]] = None,
+        n_channels: int = 3,
+        max_num_spheres: int = int(1e6),  # noqa: B008
+    ):
+        """
+        rasterizer (PointsRasterizer): An object encapsulating rasterization parameters.
+        compositor (ignored): Only kept for interface consistency. Default: None.
+        n_channels (int): The number of channels of the resulting image. Default: 3.
+        max_num_spheres (int): The maximum number of spheres intended to be
+            rendered with this renderer. Default: 1e6.
+        """
+        super().__init__()
+        self.rasterizer = rasterizer
+        if compositor is not None:
+            warnings.warn(
+                "Creating a `PulsarPointsRenderer` with a compositor object! "
+                "This object is ignored and just allowed as an argument for "
+                "interface compatibility."
+            )
+        # Initialize the pulsar renderer.
+        if not isinstance(
+            rasterizer.cameras,
+            (
+                FoVOrthographicCameras,
+                FoVPerspectiveCameras,
+                PerspectiveCameras,
+                OrthographicCameras,
+            ),
+        ):
+            raise ValueError(
+                "Only FoVPerspectiveCameras, PerspectiveCameras, "
+                "FoVOrthographicCameras and OrthographicCameras are supported "
+                "by the pulsar backend."
+            )
+        if isinstance(rasterizer.raster_settings.image_size, tuple):
+            width, height = rasterizer.raster_settings.image_size
+        else:
+            width = rasterizer.raster_settings.image_size
+            height = rasterizer.raster_settings.image_size
+        # Make sure integer types are used.
+        width = int(width)
+        height = int(height)
+        max_num_spheres = int(max_num_spheres)
+        orthogonal_projection = isinstance(
+            rasterizer.cameras, (FoVOrthographicCameras, OrthographicCameras)
+        )
+        n_channels = int(n_channels)
+        self.renderer = PulsarRenderer(
+            width=width,
+            height=height,
+            max_num_balls=max_num_spheres,
+            orthogonal_projection=orthogonal_projection,
+            right_handed_system=True,
+            n_channels=n_channels,
+        )
+
+    def _conf_check(self, point_clouds, kwargs: Dict[str, Any]) -> bool:
+        """
+        Verify the internal configuration state against kwargs and pointclouds.
+
+        This method raises a ValueError for any inconsistency found and returns
+        whether an orthogonal projection will be used.
+        """
+        if "gamma" not in kwargs.keys():
+            raise ValueError(
+                "gamma is a required keyword argument for the PulsarPointsRenderer!"
+            )
+        if (
+            len(point_clouds) != len(self.rasterizer.cameras)
+            and len(self.rasterizer.cameras) != 1
+        ):
+            raise ValueError(
+                (
+                    "len(point_clouds) must be equal to len(rasterizer.cameras), "
+                    "or only one camera must be used. len(point_clouds): %d, "
+                    "len(rasterizer.cameras): %d."
+                )
+                % (
+                    len(point_clouds),
+                    len(self.rasterizer.cameras),
+                )
+            )
+        # Make sure the rasterizer and cameras objects have no
+        # changes that can't be matched.
+        orthogonal_projection = isinstance(
+            self.rasterizer.cameras, (FoVOrthographicCameras, OrthographicCameras)
+        )
+        if orthogonal_projection != self.renderer._renderer.orthogonal:
+            raise ValueError(
+                (
+                    "The camera type cannot be changed after renderer initialization! "
+                    "Current camera orthogonal: %r. Original orthogonal: %r."
+                )
+                % (orthogonal_projection, self.renderer._renderer.orthogonal)
+            )
+        if (
+            isinstance(self.rasterizer.raster_settings.image_size, tuple)
+            and self.rasterizer.raster_settings.image_size[0]
+            != self.renderer._renderer.width
+        ) or (
+            not isinstance(self.rasterizer.raster_settings.image_size, tuple)
+            and self.rasterizer.raster_settings.image_size
+            != self.renderer._renderer.width
+        ):
+            raise ValueError(
+                (
+                    "The rasterizer width and height cannot be changed after renderer "
+                    "initialization! Current width: %d. Original width: %d."
+                )
+                % (
+                    self.rasterizer.raster_settings.image_size,
+                    self.renderer._renderer.width,
+                )
+            )
+        if (
+            isinstance(self.rasterizer.raster_settings.image_size, tuple)
+            and self.rasterizer.raster_settings.image_size[1]
+            != self.renderer._renderer.height
+        ) or (
+            not isinstance(self.rasterizer.raster_settings.image_size, tuple)
+            and self.rasterizer.raster_settings.image_size
+            != self.renderer._renderer.height
+        ):
+            raise ValueError(
+                (
+                    "The rasterizer width and height cannot be changed after renderer "
+                    "initialization! Current height: %d. Original height: %d."
+                )
+                % (
+                    self.rasterizer.raster_settings.image_size,
+                    self.renderer._renderer.height,
+                )
+            )
+        return orthogonal_projection
+
+    def _extract_intrinsics(
+        self, orthogonal_projection, kwargs, cloud_idx
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, float]:
+        """
+        Translate the camera intrinsics from PyTorch3D format to pulsar format.
+        """
+        # Shorthand:
+        cameras = self.rasterizer.cameras
+        if orthogonal_projection:
+            focal_length = 0.0
+            if isinstance(cameras, FoVOrthographicCameras):
+                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `znear`.
+                znear = kwargs.get("znear", cameras.znear)[cloud_idx]
+                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `zfar`.
+                zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx]
+                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `max_y`.
+                max_y = kwargs.get("max_y", cameras.max_y)[cloud_idx]
+                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `min_y`.
+                min_y = kwargs.get("min_y", cameras.min_y)[cloud_idx]
+                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `max_x`.
+                max_x = kwargs.get("max_x", cameras.max_x)[cloud_idx]
+                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `min_x`.
+                min_x = kwargs.get("min_x", cameras.min_x)[cloud_idx]
+                if max_y != -min_y:
+                    raise ValueError(
+                        "The orthographic camera must be centered around 0. "
+                        f"Max is {max_y} and min is {min_y}."
+                    )
+                if max_x != -min_x:
+                    raise ValueError(
+                        "The orthographic camera must be centered around 0. "
+                        f"Max is {max_x} and min is {min_x}."
+                    )
+                if not torch.all(
+                    # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `scale_xyz`.
+                    kwargs.get("scale_xyz", cameras.scale_xyz)[cloud_idx]
+                    == 1.0
+                ):
+                    raise ValueError(
+                        "The orthographic camera scale must be ((1.0, 1.0, 1.0),). "
+                        f"Provided: {kwargs.get('scale_xyz', cameras.scale_xyz)[cloud_idx]}."
+                    )
+            sensor_width = max_x - min_x
+            if not sensor_width > 0.0:
+                raise ValueError(
+                    f"The orthographic camera must have positive size! Is: {sensor_width}."  # noqa: B950
+                )
+            principal_point_x, principal_point_y = 0.0, 0.0
+        else:
+            # Currently, this means it must be an 'OrthographicCameras' object.
+            focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
+                cloud_idx
+            ]
+            if (
+                focal_length_conf.numel() == 2
+                and focal_length_conf[0] * self.renderer._renderer.width
+                - focal_length_conf[1] * self.renderer._renderer.height
+                > 1e-5
+            ):
+                raise ValueError(
+                    "Pulsar only supports a single focal length! "
+                    "Provided: %s." % (str(focal_length_conf))
+                )
+            if focal_length_conf.numel() == 2:
+                sensor_width = 2.0 / focal_length_conf[0]
+            else:
+                if focal_length_conf.numel() != 1:
+                    raise ValueError(
+                        "Focal length not parsable: %s." % (str(focal_length_conf))
+                    )
+                sensor_width = 2.0 / focal_length_conf
+            if "znear" not in kwargs.keys() or "zfar" not in kwargs.keys():
+                raise ValueError(
+                    "pulsar needs znear and zfar values for "
+                    "the OrthographicCameras. Please provide them as keyword "
+                    "arguments to the forward method."
+                )
+            znear = kwargs["znear"][cloud_idx]
+            zfar = kwargs["zfar"][cloud_idx]
+            principal_point_x = (
+                kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
+                * 0.5
+                * self.renderer._renderer.width
+            )
+            principal_point_y = (
+                kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1]
+                * 0.5
+                * self.renderer._renderer.height
+            )
+        else:
+            if not isinstance(cameras, PerspectiveCameras):
+                # Create a virtual focal length that is closer than znear.
+                znear = kwargs.get("znear", cameras.znear)[cloud_idx]
+                zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx]
+                focal_length = znear - 1e-6
+                # Create a sensor size that matches the expected fov, assuming this f.
+                afov = kwargs.get("fov", cameras.fov)[cloud_idx]
+                if kwargs.get("degrees", cameras.degrees):
+                    afov *= math.pi / 180.0
+                sensor_width = math.tan(afov / 2.0) * 2.0 * focal_length
+                if not (
+                    kwargs.get("aspect_ratio", cameras.aspect_ratio)[cloud_idx]
+                    - self.renderer._renderer.width / self.renderer._renderer.height
+                    < 1e-6
+                ):
+                    raise ValueError(
+                        "The aspect ratio ("
+                        f"{kwargs.get('aspect_ratio', cameras.aspect_ratio)[cloud_idx]}) "
+                        "must agree with the resolution width / height ("
+                        f"{self.renderer._renderer.width / self.renderer._renderer.height})."  # noqa: B950
+                    )
+                principal_point_x, principal_point_y = 0.0, 0.0
+            else:
+                # pyre-fixme[16]: `PerspectiveCameras` has no attribute `focal_length`.
+                focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
+                    cloud_idx
+                ]
+                if (
+                    focal_length_conf.numel() == 2
+                    and focal_length_conf[0] * self.renderer._renderer.width
+                    - focal_length_conf[1] * self.renderer._renderer.height
+                    > 1e-5
+                ):
+                    raise ValueError(
+                        "Pulsar only supports a single focal length! "
+                        "Provided: %s." % (str(focal_length_conf))
+                    )
+                if "znear" not in kwargs.keys() or "zfar" not in kwargs.keys():
+                    raise ValueError(
+                        "pulsar needs znear and zfar values for "
+                        "the PerspectiveCameras. Please provide them as keyword "
+                        "arguments to the forward method."
+                    )
+                znear = kwargs["znear"][cloud_idx]
+                zfar = kwargs["zfar"][cloud_idx]
+                if focal_length_conf.numel() == 2:
+                    focal_length_px = focal_length_conf[0]
+                else:
+                    if focal_length_conf.numel() != 1:
+                        raise ValueError(
+                            "Focal length not parsable: %s." % (str(focal_length_conf))
+                        )
+                    focal_length_px = focal_length_conf
+                focal_length = znear - 1e-6
+                sensor_width = focal_length / focal_length_px * 2.0
+                principal_point_x = (
+                    # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
+                    kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
+                    * 0.5
+                    * self.renderer._renderer.width
+                )
+                principal_point_y = (
+                    kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1]
+                    * 0.5
+                    * self.renderer._renderer.height
+                )
+        return (
+            focal_length,
+            sensor_width,
+            principal_point_x,
+            principal_point_y,
+            znear,
+            zfar,
+        )
+
+    def _extract_extrinsics(
+        self, kwargs, cloud_idx
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Shorthand:
+        cameras = self.rasterizer.cameras
+        R = kwargs.get("R", cameras.R)[cloud_idx]
+        T = kwargs.get("T", cameras.T)[cloud_idx]
+        norm_mat = torch.tensor(
+            [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
+            dtype=torch.float32,
+            device=R.device,
+        )
+        cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...])
+        cam_pos = torch.flatten(torch.matmul(cam_rot, T[..., None]))
+        cam_rot = torch.flatten(matrix_to_rotation_6d(cam_rot))
+        return cam_pos, cam_rot
+
+    def _get_vert_rad(
+        self, vert_pos, cam_pos, orthogonal_projection, focal_length, kwargs, cloud_idx
+    ) -> torch.Tensor:
+        """
+        Get the point radii.
+
+        These can depend on the camera position in the case of a perspective
+        transform.
+        """
+        # Normalize the point radii.
+        # `self.rasterizer.raster_settings.radius` can either be a float
+        # or itself a tensor.
+        raster_rad = self.rasterizer.raster_settings.radius
+        if kwargs.get("radius_world", False):
+            return raster_rad
+        if isinstance(raster_rad, torch.Tensor) and raster_rad.numel() > 1:
+            # In this case it must be a batched torch tensor.
+            raster_rad = raster_rad[cloud_idx]
+        if orthogonal_projection:
+            vert_rad = (
+                torch.ones(
+                    (vert_pos.shape[0],), dtype=torch.float32, device=vert_pos.device
+                )
+                * raster_rad
+            )
+        else:
+            point_dists = torch.norm((vert_pos - cam_pos), p=2, dim=1, keepdim=False)
+            vert_rad = raster_rad / focal_length * point_dists
+            if isinstance(self.rasterizer.cameras, PerspectiveCameras):
+                # NDC normalization happens through the adjusted focal length.
+                pass
+            else:
+                vert_rad = vert_rad / 2.0  # NDC normalization.
+        return vert_rad
+
+    def forward(self, point_clouds, **kwargs) -> torch.Tensor:
+        """
+        Get the rendering of the provided `Pointclouds`.
+
+        The number of point clouds in the `Pointclouds` object determines the
+        number of resulting images. The number of provided cameras can be either
+        1 or equal to the number of pointclouds (in the first case, the same
+        camera is used for all clouds; in the latter case, each point cloud is
+        rendered with the corresponding camera).
+
+        The following kwargs are supported from PyTorch3D (depending on the
+        selected camera model, potentially overriding camera parameters):
+            radius_world (bool): use the provided radii from the raster_settings
+                directly as radii in world space. Default: False.
+            znear (Iterable[float]): near geometry cutoff. Required for
+                OrthographicCameras and PerspectiveCameras.
+            zfar (Iterable[float]): far geometry cutoff. Required for
+                OrthographicCameras and PerspectiveCameras.
+            R (torch.Tensor): [Bx3x3] camera rotation matrices.
+            T (torch.Tensor): [Bx3] camera translation vectors.
+            principal_point (torch.Tensor): [Bx2] camera intrinsic principal
+                point offset vectors.
+            focal_length (torch.Tensor): [Bx1] camera intrinsic focal lengths.
+            aspect_ratio (Iterable[float]): camera aspect ratios.
+            fov (Iterable[float]): camera FOVs.
+            degrees (bool): whether FOVs are specified in degrees or
+                radians.
+            min_x (Iterable[float]): minimum x for the FoVOrthographicCameras.
+            max_x (Iterable[float]): maximum x for the FoVOrthographicCameras.
+            min_y (Iterable[float]): minimum y for the FoVOrthographicCameras.
+            max_y (Iterable[float]): maximum y for the FoVOrthographicCameras.
+
+        The following kwargs are supported from pulsar:
+            gamma (float): The gamma value to use. This defines the transparency
+                for differentiability (see the pulsar paper for details). Must be
+                in [1e-5, 1.], with 1.0 being mostly transparent. This keyword
+                argument is *required*!
+            bg_col (torch.Tensor): The background color. Must be a tensor on the same
+                device as the point clouds, with as many channels as features (no batch
+                dimension - it is the same for all images in the batch).
+                Default: 0.0 for all channels.
+            percent_allowed_difference (float): a value in [0., 1.) with the maximum
+                allowed difference in channel space. This is used to speed up the
+                computation. Default: 0.01.
+            max_n_hits (int): a hard limit on the number of sphere hits per ray.
+                Default: max int.
+            mode (int): render mode in {0, 1}. 0: render an image; 1: render the
+                hit map.
+        """
+        orthogonal_projection: bool = self._conf_check(point_clouds, kwargs)
+        # Get access to the inputs. We use the list accessors and process
+        # the clouds sequentially.
+        position_list = point_clouds.points_list()
+        features_list = point_clouds.features_list()
+        # Result list.
+        images = []
+        for cloud_idx, (vert_pos, vert_col) in enumerate(
+            zip(position_list, features_list)
+        ):
+            # Get the intrinsics.
+            (
+                focal_length,
+                sensor_width,
+                principal_point_x,
+                principal_point_y,
+                znear,
+                zfar,
+            ) = self._extract_intrinsics(orthogonal_projection, kwargs, cloud_idx)
+            # Get the extrinsics.
+            cam_pos, cam_rot = self._extract_extrinsics(kwargs, cloud_idx)
+            # Put everything together.
+            cam_params = torch.cat(
+                (
+                    cam_pos,
+                    cam_rot,
+                    torch.tensor(
+                        [
+                            focal_length,
+                            sensor_width,
+                            principal_point_x,
+                            principal_point_y,
+                        ],
+                        dtype=torch.float32,
+                        device=cam_pos.device,
+                    ),
+                )
+            )
+            # Get the point radii (these can depend on the camera position).
+            vert_rad = self._get_vert_rad(
+                vert_pos,
+                cam_pos,
+                orthogonal_projection,
+                focal_length,
+                kwargs,
+                cloud_idx,
+            )
+            # Clean the kwargs for passing on.
+            gamma = kwargs["gamma"][cloud_idx]
+            if "first_R_then_T" in kwargs.keys():
+                raise ValueError("`first_R_then_T` is not supported in this interface.")
+            otherargs = {
+                argn: argv
+                for argn, argv in kwargs.items()
+                if argn
+                not in [
+                    "radius_world",
+                    "gamma",
+                    "znear",
+                    "zfar",
+                    "R",
+                    "T",
+                    "principal_point",
+                    "focal_length",
+                    "aspect_ratio",
+                    "fov",
+                    "degrees",
+                    "min_x",
+                    "max_x",
+                    "min_y",
+                    "max_y",
+                ]
+            }
+            # Background color.
+            if "bg_col" not in otherargs:
+                bg_col = torch.zeros(
+                    vert_col.shape[1], device=cam_params.device, dtype=torch.float32
+                )
+                otherargs["bg_col"] = bg_col
+            # Go!
+            images.append(
+                self.renderer(
+                    vert_pos=vert_pos,
+                    vert_col=vert_col,
+                    vert_rad=vert_rad,
+                    cam_params=cam_params,
+                    gamma=gamma,
+                    max_depth=zfar,
+                    min_depth=znear,
+                    **otherargs,
+                )
+            )
+        return torch.stack(images, dim=0)
diff --git a/pytorch3d/renderer/points/renderer.py b/pytorch3d/renderer/points/renderer.py
index 4dc610a3..0f5f3120 100644
--- a/pytorch3d/renderer/points/renderer.py
+++ b/pytorch3d/renderer/points/renderer.py
@@ -1,7 +1,5 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc.
and its affiliates. All rights reserved. - - import torch import torch.nn as nn @@ -48,7 +46,7 @@ class PointsRenderer(nn.Module): fragments.idx.long().permute(0, 3, 1, 2), weights, point_clouds.features_packed().permute(1, 0), - **kwargs + **kwargs, ) # permute so image comes at the end diff --git a/pytorch3d/transforms/__init__.py b/pytorch3d/transforms/__init__.py index 2f0a0301..8769c16b 100644 --- a/pytorch3d/transforms/__init__.py +++ b/pytorch3d/transforms/__init__.py @@ -1,9 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .external.kornia_angle_axis_to_rotation_matrix import ( - angle_axis_to_rotation_matrix as axis_angle_to_matrix, -) from .rotation_conversions import ( + axis_angle_to_matrix, euler_angles_to_matrix, matrix_to_euler_angles, matrix_to_quaternion, diff --git a/pytorch3d/transforms/external/__init__.py b/pytorch3d/transforms/external/__init__.py deleted file mode 100644 index 40539064..00000000 --- a/pytorch3d/transforms/external/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. diff --git a/pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py b/pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py deleted file mode 100644 index 1269813a..00000000 --- a/pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py +++ /dev/null @@ -1,94 +0,0 @@ -#!/usr/bin/env python3 -""" -This file contains the great angle axis to rotation matrix conversion -from kornia (https://github.com/arraiyopensource/kornia). The license -can be found in kornia_license.txt. - -The method is used unchanged; the documentation has been adjusted -to match our doc format. -""" -import torch - - -def angle_axis_to_rotation_matrix(angle_axis): - """Convert 3d vector of axis-angle rotation to 4x4 rotation matrix - - Args: - angle_axis (Tensor): tensor of 3d vector of axis-angle rotations. - - Returns: - Tensor: tensor of 3x3 rotation matrix. - - Shape: - - Input: :math:`(N, 3)` - - Output: :math:`(N, 3, 3)` - - Example: - - ..code-block::python - - >>> input = torch.rand(1, 3) # Nx3 - >>> output = tgm.angle_axis_to_rotation_matrix(input) # Nx3x3 - >>> output = tgm.angle_axis_to_rotation_matrix(input) # Nx3x3 - """ - - def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6): - # We want to be careful to only evaluate the square root if the - # norm of the angle_axis vector is greater than zero. Otherwise - # we get a division by zero. 
- k_one = 1.0 - theta = torch.sqrt(theta2) - wxyz = angle_axis / (theta + eps) - wx, wy, wz = torch.chunk(wxyz, 3, dim=1) - cos_theta = torch.cos(theta) - sin_theta = torch.sin(theta) - - r00 = cos_theta + wx * wx * (k_one - cos_theta) - r10 = wz * sin_theta + wx * wy * (k_one - cos_theta) - r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta) - r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta - r11 = cos_theta + wy * wy * (k_one - cos_theta) - r21 = wx * sin_theta + wy * wz * (k_one - cos_theta) - r02 = wy * sin_theta + wx * wz * (k_one - cos_theta) - r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta) - r22 = cos_theta + wz * wz * (k_one - cos_theta) - rotation_matrix = torch.cat( - [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1 - ) - return rotation_matrix.view(-1, 3, 3) - - def _compute_rotation_matrix_taylor(angle_axis): - rx, ry, rz = torch.chunk(angle_axis, 3, dim=1) - k_one = torch.ones_like(rx) - rotation_matrix = torch.cat( - [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1 - ) - return rotation_matrix.view(-1, 3, 3) - - # stolen from ceres/rotation.h - - _angle_axis = torch.unsqueeze(angle_axis + 1e-6, dim=1) - # _angle_axis.register_hook(lambda grad: pdb.set_trace()) - # _angle_axis = 1e-6 - theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2)) - theta2 = torch.squeeze(theta2, dim=1) - - # compute rotation matrices - rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2) - rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis) - - # create mask to handle both cases - eps = 1e-6 - mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device) - mask_pos = (mask).type_as(theta2) - mask_neg = (mask == False).type_as(theta2) # noqa - - # create output pose matrix - batch_size = angle_axis.shape[0] - rotation_matrix = torch.eye(3).to(angle_axis.device).type_as(angle_axis) - rotation_matrix = rotation_matrix.view(1, 3, 3).repeat(batch_size, 1, 1) - # fill output matrix with masked values - rotation_matrix[..., :3, :3] = ( - mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor - ) - return rotation_matrix.to(angle_axis.device).type_as(angle_axis) # Nx4x4 diff --git a/pytorch3d/transforms/external/kornia_license.txt b/pytorch3d/transforms/external/kornia_license.txt deleted file mode 100644 index 261eeb9e..00000000 --- a/pytorch3d/transforms/external/kornia_license.txt +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. 
- - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. 
If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. 
Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
diff --git a/setup.py b/setup.py
index ee678ee1..cf041287 100755
--- a/setup.py
+++ b/setup.py
@@ -13,13 +13,8 @@ from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
 def get_extensions():
     this_dir = os.path.dirname(os.path.abspath(__file__))
     extensions_dir = os.path.join(this_dir, "pytorch3d", "csrc")
-
-    main_source = os.path.join(extensions_dir, "ext.cpp")
-    sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
-    source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu"))
-
-    sources = [main_source] + sources
-
+    sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)
+    source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu"), recursive=True)
     extension = CppExtension
 
     extra_compile_args = {"cxx": ["-std=c++14"]}
@@ -30,7 +25,18 @@ def get_extensions():
         extension = CUDAExtension
         sources += source_cuda
         define_macros += [("WITH_CUDA", None)]
+        cub_home = os.environ.get("CUB_HOME", None)
+        if cub_home is None:
+            raise Exception(
+                "The environment variable `CUB_HOME` was not found. "
+                "NVIDIA CUB is required for compilation and can be downloaded "
+                "from `https://github.com/NVIDIA/cub/releases`. You can unpack "
+                "it to a location of your choice and set the environment variable "
+                "`CUB_HOME` to the folder containing the `CMakeLists.txt` file."
+            )
         nvcc_args = [
+            "-I%s" % (os.path.realpath(cub_home).replace("\\ ", " ")),
+            "-std=c++14",
             "-DCUDA_HAS_FP16=1",
             "-D__CUDA_NO_HALF_OPERATORS__",
             "-D__CUDA_NO_HALF_CONVERSIONS__",
diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovorthographic.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovorthographic.png
new file mode 100644
GIT binary patch (literal 1930: base85-encoded PNG data omitted)
diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovperspective.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovperspective.png
new file mode 100644
GIT binary patch (literal 3373: base85-encoded PNG data omitted)
diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_orthographic.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_orthographic.png
new file mode 100644
GIT binary patch (literal 1930: base85-encoded PNG data omitted)
diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_perspective.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_perspective.png
new file mode 100644
GIT binary patch (literal 3070: base85-encoded PNG data omitted)
diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovorthographic.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovorthographic.png
new file mode 100644
GIT binary patch (literal 2112: base85-encoded PNG data omitted)
diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovperspective.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovperspective.png
new file mode 100644
GIT binary patch (literal 3451: base85-encoded PNG data omitted)
diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_orthographic.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_orthographic.png
new file mode 100644
GIT binary patch (literal 2112: base85-encoded PNG data omitted)
diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_perspective.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_perspective.png
new file mode 100644
GIT binary patch (literal 3278: base85-encoded PNG data omitted)
diff --git a/tests/test_render_points.py b/tests/test_render_points.py
index 81f35dee..e400745e 100644
--- a/tests/test_render_points.py
+++ b/tests/test_render_points.py
@@ -16,6 +16,8 @@ from PIL import Image
 from pytorch3d.renderer.cameras import (
     FoVOrthographicCameras,
     FoVPerspectiveCameras,
+    OrthographicCameras,
+    PerspectiveCameras,
     look_at_view_transform,
 )
 from pytorch3d.renderer.compositing import alpha_composite, norm_weighted_sum
@@ -25,6 +27,7 @@ from pytorch3d.renderer.points import (
     PointsRasterizationSettings,
     PointsRasterizer,
     PointsRenderer,
+    PulsarPointsRenderer,
 )
 from pytorch3d.structures.pointclouds import Pointclouds
 from pytorch3d.utils.ico_sphere import ico_sphere
@@ -72,6 +75,145 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
         )
         self.assertClose(rgb, image_ref)
 
+    def test_simple_sphere_pulsar(self):
+        for device in [torch.device("cpu"), torch.device("cuda")]:
+            sphere_mesh = ico_sphere(1, device)
+            verts_padded = sphere_mesh.verts_padded()
+            # Shift vertices to check coordinate frames are correct.
+            verts_padded[..., 1] += 0.2
+            verts_padded[..., 0] += 0.2
+            pointclouds = Pointclouds(
+                points=verts_padded, features=torch.ones_like(verts_padded)
+            )
+            for azimuth in [0.0, 90.0]:
+                R, T = look_at_view_transform(2.7, 0.0, azimuth)
+                for camera_name, cameras in [
+                    ("fovperspective", FoVPerspectiveCameras(device=device, R=R, T=T)),
+                    (
+                        "fovorthographic",
+                        FoVOrthographicCameras(device=device, R=R, T=T),
+                    ),
+                    ("perspective", PerspectiveCameras(device=device, R=R, T=T)),
+                    ("orthographic", OrthographicCameras(device=device, R=R, T=T)),
+                ]:
+                    raster_settings = PointsRasterizationSettings(
+                        image_size=256, radius=5e-2, points_per_pixel=1
+                    )
+                    rasterizer = PointsRasterizer(
+                        cameras=cameras, raster_settings=raster_settings
+                    )
+                    renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)
+                    # Load the reference image.
+                    filename = (
+                        "pulsar_simple_pointcloud_sphere_"
+                        f"azimuth{azimuth}_{camera_name}.png"
+                    )
+                    image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
+                    images = renderer(
+                        pointclouds, gamma=(1e-3,), znear=(1.0,), zfar=(100.0,)
+                    )
+                    rgb = images[0, ..., :3].squeeze().cpu()
+                    if DEBUG:
+                        filename = "DEBUG_%s" % filename
+                        Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
+                            DATA_DIR / filename
+                        )
+                    self.assertClose(rgb, image_ref, rtol=7e-3, atol=5e-3)
+
+    def test_unified_inputs_pulsar(self):
+        # Test data on different devices.
+        for device in [torch.device("cpu"), torch.device("cuda")]:
+            sphere_mesh = ico_sphere(1, device)
+            verts_padded = sphere_mesh.verts_padded()
+            pointclouds = Pointclouds(
+                points=verts_padded, features=torch.ones_like(verts_padded)
+            )
+            R, T = look_at_view_transform(2.7, 0.0, 0.0)
+            # Test the different camera types.
+            for _, cameras in [
+                ("fovperspective", FoVPerspectiveCameras(device=device, R=R, T=T)),
+                (
+                    "fovorthographic",
+                    FoVOrthographicCameras(device=device, R=R, T=T),
+                ),
+                ("perspective", PerspectiveCameras(device=device, R=R, T=T)),
+                ("orthographic", OrthographicCameras(device=device, R=R, T=T)),
+            ]:
+                # Test the different ways of specifying the image size.
+                for image_size in (256, (256, 256)):
+                    raster_settings = PointsRasterizationSettings(
+                        image_size=image_size, radius=5e-2, points_per_pixel=1
+                    )
+                    rasterizer = PointsRasterizer(
+                        cameras=cameras, raster_settings=raster_settings
+                    )
+                    # Test that a compositor can be provided. Its value is
+                    # ignored, so use a dummy.
+                    _ = PulsarPointsRenderer(rasterizer=rasterizer, compositor=1).to(
+                        device
+                    )
+                    # Constructor without compositor.
+                    _ = PulsarPointsRenderer(rasterizer=rasterizer).to(device)
+                    # Constructor with n_channels.
+                    _ = PulsarPointsRenderer(rasterizer=rasterizer, n_channels=3).to(
+                        device
+                    )
+                    # Constructor with max_num_spheres.
+                    renderer = PulsarPointsRenderer(
+                        rasterizer=rasterizer, max_num_spheres=1000
+                    ).to(device)
+                    # Test the forward function.
+                    if isinstance(cameras, (PerspectiveCameras, OrthographicCameras)):
+                        # znear and zfar are required in this case.
+                        self.assertRaises(
+                            ValueError,
+                            lambda: renderer.forward(
+                                point_clouds=pointclouds, gamma=(1e-4,)
+                            ),
+                        )
+                        renderer.forward(
+                            point_clouds=pointclouds,
+                            gamma=(1e-4,),
+                            znear=(1.0,),
+                            zfar=(2.0,),
+                        )
+                        # znear and zfar must be batched.
+                        self.assertRaises(
+                            TypeError,
+                            lambda: renderer.forward(
+                                point_clouds=pointclouds,
+                                gamma=(1e-4,),
+                                znear=1.0,
+                                zfar=(2.0,),
+                            ),
+                        )
+                        self.assertRaises(
+                            TypeError,
+                            lambda: renderer.forward(
+                                point_clouds=pointclouds,
+                                gamma=(1e-4,),
+                                znear=(1.0,),
+                                zfar=2.0,
+                            ),
+                        )
+                    else:
+                        # gamma must be batched.
+                        self.assertRaises(
+                            TypeError,
+                            lambda: renderer.forward(
+                                point_clouds=pointclouds, gamma=1e-4
+                            ),
+                        )
+                        renderer.forward(point_clouds=pointclouds, gamma=(1e-4,))
+                    # Changing the rasterizer width and height after
+                    # initialization must raise an error.
+                    renderer.rasterizer.raster_settings.image_size = 0
+                    self.assertRaises(
+                        ValueError,
+                        lambda: renderer.forward(
+                            point_clouds=pointclouds, gamma=(1e-4,)
+                        ),
+                    )
+
     def test_pointcloud_with_features(self):
         device = torch.device("cuda:0")
         file_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
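
Usage sketch (editorial addendum, not part of the diff): the snippet below shows
the backend switch described in the summary, mirroring the calls exercised in
`test_simple_sphere_pulsar` above. The scene setup (an ico-sphere point cloud
with constant features and an FoV camera) is an assumption for illustration;
with an FoV camera, only the batched `gamma` keyword argument is required.

import torch
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.points import (
    PointsRasterizationSettings,
    PointsRasterizer,
    PulsarPointsRenderer,
)
from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.utils.ico_sphere import ico_sphere

device = torch.device("cuda")
# Build a simple point cloud with constant features, as in the tests.
verts = ico_sphere(1, device).verts_padded()
pointclouds = Pointclouds(points=verts, features=torch.ones_like(verts))
# Camera and rasterization settings are shared with the PointsRenderer path.
R, T = look_at_view_transform(2.7, 0.0, 0.0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
raster_settings = PointsRasterizationSettings(
    image_size=256, radius=5e-2, points_per_pixel=1
)
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
# The pulsar backend needs no compositor; `gamma` (batched, one value per
# cloud) controls the transparency/differentiability trade-off.
renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)
images = renderer(pointclouds, gamma=(1e-4,))  # shape: [1, 256, 256, 3]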
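
For the non-FoV camera types the wrapper cannot infer depth bounds, so `znear`
and `zfar` must be passed explicitly and batched, as checked in
`test_unified_inputs_pulsar`. A second sketch, reusing the names defined in the
block above; the depth bound values are illustrative:

from pytorch3d.renderer.cameras import PerspectiveCameras

cameras = PerspectiveCameras(device=device, R=R, T=T)
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)
# Omitting znear/zfar here raises a ValueError; passing them unbatched
# (e.g., znear=1.0) raises a TypeError.
images = renderer(pointclouds, gamma=(1e-4,), znear=(1.0,), zfar=(100.0,))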