diff --git a/pytorch3d/renderer/points/__init__.py b/pytorch3d/renderer/points/__init__.py
index b334f4c5..f0ee9b02 100644
--- a/pytorch3d/renderer/points/__init__.py
+++ b/pytorch3d/renderer/points/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
 from .compositor import AlphaCompositor, NormWeightedCompositor
+from .pulsar.unified import PulsarPointsRenderer
 from .rasterize_points import rasterize_points
 from .rasterizer import PointsRasterizationSettings, PointsRasterizer
 from .renderer import PointsRenderer
diff --git a/pytorch3d/renderer/points/pulsar/renderer.py b/pytorch3d/renderer/points/pulsar/renderer.py
index 0e86d6aa..5f30f648 100644
--- a/pytorch3d/renderer/points/pulsar/renderer.py
+++ b/pytorch3d/renderer/points/pulsar/renderer.py
@@ -369,6 +369,7 @@ class Renderer(torch.nn.Module):
         height: int,
         orthogonal: bool,
         right_handed: bool,
+        first_R_then_T: bool = False,
     ) -> Tuple[
         torch.Tensor,
         torch.Tensor,
@@ -401,6 +402,8 @@
                 (does not use focal length).
            * right_handed: bool, whether to use a right handed system
                 (negative z in camera direction).
+           * first_R_then_T: bool, whether to first rotate, then translate
+                the camera (PyTorch3D convention).
 
         Returns:
            * pos_vec: the position vector in 3D,
@@ -460,16 +463,18 @@
         # Always get quadratic pixels.
         pixel_size_x = sensor_size_x / float(width)
         sensor_size_y = height * pixel_size_x
+        if continuous_rep:
+            rot_mat = rotation_6d_to_matrix(rot_vec)
+        else:
+            rot_mat = axis_angle_to_matrix(rot_vec)
+        if first_R_then_T:
+            pos_vec = torch.matmul(rot_mat, pos_vec[..., None])[:, :, 0]
         LOGGER.debug(
             "Camera position: %s, rotation: %s. Focal length: %s.",
             str(pos_vec),
             str(rot_vec),
             str(focal_length),
         )
-        if continuous_rep:
-            rot_mat = rotation_6d_to_matrix(rot_vec)
-        else:
-            rot_mat = axis_angle_to_matrix(rot_vec)
         sensor_dir_x = torch.matmul(
             rot_mat,
             torch.tensor(
@@ -576,6 +581,7 @@ class Renderer(torch.nn.Module):
         max_n_hits: int = _C.MAX_UINT,
         mode: int = 0,
         return_forward_info: bool = False,
+        first_R_then_T: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """
        Rendering pass to create an image from the provided spheres and camera
@@ -616,6 +622,8 @@
                the float encoded integer index of a sphere and its weight.
                They are the five spheres with the highest color contribution
                to this pixel color, ordered descending. Default: False.
+            * first_R_then_T: bool, whether to first apply rotation to the camera,
+                then translation (PyTorch3D convention). Default: False.
 
        Returns:
            * image: [Bx]HxWx3 float tensor with the resulting image.
@@ -638,6 +646,7 @@
            self._renderer.height,
            self._renderer.orthogonal,
            self._renderer.right_handed,
+            first_R_then_T=first_R_then_T,
        )
        if (
            focal_lengths.min().item() > 0.0
diff --git a/pytorch3d/renderer/points/pulsar/unified.py b/pytorch3d/renderer/points/pulsar/unified.py
new file mode 100644
index 00000000..08b10af6
--- /dev/null
+++ b/pytorch3d/renderer/points/pulsar/unified.py
@@ -0,0 +1,524 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+import math +import warnings +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ....transforms import matrix_to_rotation_6d +from ...cameras import ( + FoVOrthographicCameras, + FoVPerspectiveCameras, + OrthographicCameras, + PerspectiveCameras, +) +from ..compositor import AlphaCompositor, NormWeightedCompositor +from ..rasterizer import PointsRasterizer +from .renderer import Renderer as PulsarRenderer + + +class PulsarPointsRenderer(nn.Module): + """ + This renderer is a PyTorch3D interface wrapper around the pulsar renderer. + + It provides an interface consistent with PyTorch3D Pointcloud rendering. + It will extract all necessary information from the rasterizer and compositor + objects and convert them to the pulsar required format, then invoke rendering + in the pulsar renderer. All gradients are handled appropriately through the + wrapper and the wrapper should provide equivalent results to using the pulsar + renderer directly. + """ + + def __init__( + self, + rasterizer: PointsRasterizer, + compositor: Optional[Union[NormWeightedCompositor, AlphaCompositor]] = None, + n_channels: int = 3, + max_num_spheres: int = int(1e6), # noqa: B008 + ): + """ + rasterizer (PointsRasterizer): An object encapsulating rasterization parameters. + compositor (ignored): Only keeping this for interface consistency. Default: None. + n_channels (int): The number of channels of the resulting image. Default: 3. + max_num_spheres (int): The maximum number of spheres intended to render with + this renderer. Default: 1e6. + """ + super().__init__() + self.rasterizer = rasterizer + if compositor is not None: + warnings.warn( + "Creating a `PulsarPointsRenderer` with a compositor object! " + "This object is ignored and just allowed as an argument for interface " + "compatibility." + ) + # Initialize the pulsar renderers. + if not isinstance( + rasterizer.cameras, + ( + FoVOrthographicCameras, + FoVPerspectiveCameras, + PerspectiveCameras, + OrthographicCameras, + ), + ): + raise ValueError( + "Only FoVPerspectiveCameras, PerspectiveCameras, " + "FoVOrthographicCameras and OrthographicCameras are supported " + "by the pulsar backend." + ) + if isinstance(rasterizer.raster_settings.image_size, tuple): + width, height = rasterizer.raster_settings.image_size + else: + width = rasterizer.raster_settings.image_size + height = rasterizer.raster_settings.image_size + # Making sure about integer types. + width = int(width) + height = int(height) + max_num_spheres = int(max_num_spheres) + orthogonal_projection = isinstance( + rasterizer.cameras, (FoVOrthographicCameras, OrthographicCameras) + ) + n_channels = int(n_channels) + self.renderer = PulsarRenderer( + width=width, + height=height, + max_num_balls=max_num_spheres, + orthogonal_projection=orthogonal_projection, + right_handed_system=True, + n_channels=n_channels, + ) + + def _conf_check(self, point_clouds, kwargs: Dict[str, Any]) -> bool: + """ + Verify internal configuration state with kwargs and pointclouds. + + This method will raise ValueError's for any inconsistencies found. It + returns whether an orthogonal projection will be used. + """ + if "gamma" not in kwargs.keys(): + raise ValueError( + "gamma is a required keyword argument for the PulsarPointsRenderer!" + ) + if ( + len(point_clouds) != len(self.rasterizer.cameras) + and len(self.rasterizer.cameras) != 1 + ): + raise ValueError( + ( + "The len(point_clouds) must either be equal to len(rasterizer.cameras) or " + "only one camera must be used. 
len(point_clouds): %d, " + "len(rasterizer.cameras): %d." + ) + % ( + len(point_clouds), + len(self.rasterizer.cameras), + ) + ) + # Make sure the rasterizer and cameras objects have no + # changes that can't be matched. + orthogonal_projection = isinstance( + self.rasterizer.cameras, (FoVOrthographicCameras, OrthographicCameras) + ) + if orthogonal_projection != self.renderer._renderer.orthogonal: + raise ValueError( + "The camera type can not be changed after renderer initialization! " + "Current camera orthogonal: %r. Original orthogonal: %r." + ) % (orthogonal_projection, self.renderer._renderer.orthogonal) + if ( + isinstance(self.rasterizer.raster_settings.image_size, tuple) + and self.rasterizer.raster_settings.image_size[0] + != self.renderer._renderer.width + ) or ( + not isinstance(self.rasterizer.raster_settings.image_size, tuple) + and self.rasterizer.raster_settings.image_size + != self.renderer._renderer.width + ): + raise ValueError( + ( + "The rasterizer width and height can not be changed after renderer " + "initialization! Current width: %d. Original width: %d." + ) + % ( + self.rasterizer.raster_settings.image_size, + self.renderer._renderer.width, + ) + ) + if ( + isinstance(self.rasterizer.raster_settings.image_size, tuple) + and self.rasterizer.raster_settings.image_size[1] + != self.renderer._renderer.height + ) or ( + not isinstance(self.rasterizer.raster_settings.image_size, tuple) + and self.rasterizer.raster_settings.image_size + != self.renderer._renderer.height + ): + raise ValueError( + ( + "The rasterizer width and height can not be changed after renderer " + "initialization! Current height: %d. Original height: %d." + ) + % ( + self.rasterizer.raster_settings.image_size, + self.renderer._renderer.height, + ) + ) + return orthogonal_projection + + def _extract_intrinsics( + self, orthogonal_projection, kwargs, cloud_idx + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, float]: + """ + Translate the camera intrinsics from PyTorch3D format to pulsar format. + """ + # Shorthand: + cameras = self.rasterizer.cameras + if orthogonal_projection: + focal_length = 0.0 + if isinstance(cameras, FoVOrthographicCameras): + # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `znear`. + znear = kwargs.get("znear", cameras.znear)[cloud_idx] + # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `zfar`. + zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx] + # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `max_y`. + max_y = kwargs.get("max_y", cameras.max_y)[cloud_idx] + # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `min_y`. + min_y = kwargs.get("min_y", cameras.min_y)[cloud_idx] + # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `max_x`. + max_x = kwargs.get("max_x", cameras.max_x)[cloud_idx] + # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `min_x`. + min_x = kwargs.get("min_x", cameras.min_x)[cloud_idx] + if max_y != -min_y: + raise ValueError( + "The orthographic camera must be centered around 0. " + f"Max is {max_y} and min is {min_y}." + ) + if max_x != -min_x: + raise ValueError( + "The orthographic camera must be centered around 0. " + f"Max is {max_x} and min is {min_x}." + ) + if not torch.all( + # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `scale_xyz`. + kwargs.get("scale_xyz", cameras.scale_xyz)[cloud_idx] + == 1.0 + ): + raise ValueError( + "The orthographic camera scale must be ((1.0, 1.0, 1.0),). " + f"{kwargs.get('scale_xyz', cameras.scale_xyz)[cloud_idx]}." 
+ ) + sensor_width = max_x - min_x + if not sensor_width > 0.0: + raise ValueError( + f"The orthographic camera must have positive size! Is: {sensor_width}." # noqa: B950 + ) + principal_point_x, principal_point_y = 0.0, 0.0 + else: + # Currently, this means it must be an 'OrthographicCameras' object. + focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[ + cloud_idx + ] + if ( + focal_length_conf.numel() == 2 + and focal_length_conf[0] * self.renderer._renderer.width + - focal_length_conf[1] * self.renderer._renderer.height + > 1e-5 + ): + raise ValueError( + "Pulsar only supports a single focal length! " + "Provided: %s." % (str(focal_length_conf)) + ) + if focal_length_conf.numel() == 2: + sensor_width = 2.0 / focal_length_conf[0] + else: + if focal_length_conf.numel() != 1: + raise ValueError( + "Focal length not parsable: %s." % (str(focal_length_conf)) + ) + sensor_width = 2.0 / focal_length_conf + if "znear" not in kwargs.keys() or "zfar" not in kwargs.keys(): + raise ValueError( + "pulsar needs znear and zfar values for " + "the OrthographicCameras. Please provide them as keyword " + "argument to the forward method." + ) + znear = kwargs["znear"][cloud_idx] + zfar = kwargs["zfar"][cloud_idx] + principal_point_x = ( + kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0] + * 0.5 + * self.renderer._renderer.width + ) + principal_point_y = ( + kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1] + * 0.5 + * self.renderer._renderer.height + ) + else: + if not isinstance(cameras, PerspectiveCameras): + # Create a virtual focal length that is closer than znear. + znear = kwargs.get("znear", cameras.znear)[cloud_idx] + zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx] + focal_length = znear - 1e-6 + # Create a sensor size that matches the expected fov assuming this f. + afov = kwargs.get("fov", cameras.fov)[cloud_idx] + if kwargs.get("degrees", cameras.degrees): + afov *= math.pi / 180.0 + sensor_width = math.tan(afov / 2.0) * 2.0 * focal_length + if not ( + kwargs.get("aspect_ratio", cameras.aspect_ratio)[cloud_idx] + - self.renderer._renderer.width / self.renderer._renderer.height + < 1e-6 + ): + raise ValueError( + "The aspect ratio (" + f"{kwargs.get('aspect_ratio', cameras.aspect_ratio)[cloud_idx]}) " + "must agree with the resolution width / height (" + f"{self.renderer._renderer.width / self.renderer._renderer.height})." # noqa: B950 + ) + principal_point_x, principal_point_y = 0.0, 0.0 + else: + # pyre-fixme[16]: `PerspectiveCameras` has no attribute `focal_length`. + focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[ + cloud_idx + ] + if ( + focal_length_conf.numel() == 2 + and focal_length_conf[0] * self.renderer._renderer.width + - focal_length_conf[1] * self.renderer._renderer.height + > 1e-5 + ): + raise ValueError( + "Pulsar only supports a single focal length! " + "Provided: %s." % (str(focal_length_conf)) + ) + if "znear" not in kwargs.keys() or "zfar" not in kwargs.keys(): + raise ValueError( + "pulsar needs znear and zfar values for " + "the PerspectiveCameras. Please provide them as keyword " + "argument to the forward method." + ) + znear = kwargs["znear"][cloud_idx] + zfar = kwargs["zfar"][cloud_idx] + if focal_length_conf.numel() == 2: + focal_length_px = focal_length_conf[0] + else: + if focal_length_conf.numel() != 1: + raise ValueError( + "Focal length not parsable: %s." 
% (str(focal_length_conf)) + ) + focal_length_px = focal_length_conf + focal_length = znear - 1e-6 + sensor_width = focal_length / focal_length_px * 2.0 + principal_point_x = ( + # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`. + kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0] + * 0.5 + * self.renderer._renderer.width + ) + principal_point_y = ( + kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1] + * 0.5 + * self.renderer._renderer.height + ) + return ( + focal_length, + sensor_width, + principal_point_x, + principal_point_y, + znear, + zfar, + ) + + def _extract_extrinsics( + self, kwargs, cloud_idx + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Shorthand: + cameras = self.rasterizer.cameras + R = kwargs.get("R", cameras.R)[cloud_idx] + T = kwargs.get("T", cameras.T)[cloud_idx] + norm_mat = torch.tensor( + [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]], + dtype=torch.float32, + device=R.device, + ) + cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...]) + cam_pos = torch.flatten(torch.matmul(cam_rot, T[..., None])) + cam_rot = torch.flatten(matrix_to_rotation_6d(cam_rot)) + return cam_pos, cam_rot + + def _get_vert_rad( + self, vert_pos, cam_pos, orthogonal_projection, focal_length, kwargs, cloud_idx + ) -> torch.Tensor: + """ + Get point radiuses. + + These can be depending on the camera position in case of a perspective + transform. + """ + # Normalize point radiuses. + # `self.rasterizer.raster_settings.radius` can either be a float + # or itself a tensor. + raster_rad = self.rasterizer.raster_settings.radius + if kwargs.get("radius_world", False): + return raster_rad + if isinstance(raster_rad, torch.Tensor) and raster_rad.numel() > 1: + # In this case it must be a batched torch tensor. + raster_rad = raster_rad[cloud_idx] + if orthogonal_projection: + vert_rad = ( + torch.ones( + (vert_pos.shape[0],), dtype=torch.float32, device=vert_pos.device + ) + * raster_rad + ) + else: + point_dists = torch.norm((vert_pos - cam_pos), p=2, dim=1, keepdim=False) + vert_rad = raster_rad / focal_length * point_dists + if isinstance(self.rasterizer.cameras, PerspectiveCameras): + # NDC normalization happens through adjusted focal length. + pass + else: + vert_rad = vert_rad / 2.0 # NDC normalization. + return vert_rad + + def forward(self, point_clouds, **kwargs) -> torch.Tensor: + """ + Get the rendering of the provided `Pointclouds`. + + The number of point clouds in the `Pointclouds` object determines the + number of resulting images. The provided cameras can be either 1 or equal + to the number of pointclouds (in the first case, the same camera will be + used for all clouds, in the latter case each point cloud will be rendered + with the corresponding camera). + + The following kwargs are support from PyTorch3D (depending on the selected + camera model potentially overriding camera parameters): + radius_world (bool): use the provided radiuses from the raster_settings + plain as radiuses in world space. Default: False. + znear (Iterable[float]): near geometry cutoff. Is required for + OrthographicCameras and PerspectiveCameras. + zfar (Iterable[float]): far geometry cutoff. Is required for + OrthographicCameras and PerspectiveCameras. + R (torch.Tensor): [Bx3x3] camera rotation matrices. + T (torch.Tensor): [Bx3] camera translation vectors. + principal_point (torch.Tensor): [Bx2] camera intrinsic principal + point offset vectors. + focal_length (torch.Tensor): [Bx1] camera intrinsic focal lengths. 
+ aspect_ratio (Iterable[float]): camera aspect ratios. + fov (Iterable[float]): camera FOVs. + degrees (bool): whether FOVs are specified in degrees or + radians. + min_x (Iterable[float]): minimum x for the FoVOrthographicCameras. + max_x (Iterable[float]): maximum x for the FoVOrthographicCameras. + min_y (Iterable[float]): minimum y for the FoVOrthographicCameras. + max_y (Iterable[float]): maximum y for the FoVOrthographicCameras. + + The following kwargs are supported from pulsar: + gamma (float): The gamma value to use. This defines the transparency for + differentiability (see pulsar paper for details). Must be in [1., 1e-5] + with 1.0 being mostly transparent. This keyword argument is *required*! + bg_col (torch.Tensor): The background color. Must be a tensor on the same + device as the point clouds, with as many channels as features (no batch + dimension - it is the same for all images in the batch). + Default: 0.0 for all channels. + percent_allowed_difference (float): a value in [0., 1.[ with the maximum + allowed difference in channel space. This is used to speed up the + computation. Default: 0.01. + max_n_hits (int): a hard limit on the number of sphere hits per ray. + Default: max int. + mode (int): render mode in {0, 1}. 0: render image; 1: render hit map. + """ + orthogonal_projection: bool = self._conf_check(point_clouds, kwargs) + # Get access to inputs. We're using the list accessor and process + # them sequentially. + position_list = point_clouds.points_list() + features_list = point_clouds.features_list() + # Result list. + images = [] + for cloud_idx, (vert_pos, vert_col) in enumerate( + zip(position_list, features_list) + ): + # Get intrinsics. + ( + focal_length, + sensor_width, + principal_point_x, + principal_point_y, + znear, + zfar, + ) = self._extract_intrinsics(orthogonal_projection, kwargs, cloud_idx) + # Get extrinsics. + cam_pos, cam_rot = self._extract_extrinsics(kwargs, cloud_idx) + # Put everything together. + cam_params = torch.cat( + ( + cam_pos, + cam_rot, + torch.tensor( + [ + focal_length, + sensor_width, + principal_point_x, + principal_point_y, + ], + dtype=torch.float32, + device=cam_pos.device, + ), + ) + ) + # Get point radiuses (can depend on camera position). + vert_rad = self._get_vert_rad( + vert_pos, + cam_pos, + orthogonal_projection, + focal_length, + kwargs, + cloud_idx, + ) + # Clean kwargs for passing on. + gamma = kwargs["gamma"][cloud_idx] + if "first_R_then_T" in kwargs.keys(): + raise ValueError("`first_R_then_T` is not supported in this interface.") + otherargs = { + argn: argv + for argn, argv in kwargs.items() + if argn + not in [ + "radius_world", + "gamma", + "znear", + "zfar", + "R", + "T", + "principal_point", + "focal_length", + "aspect_ratio", + "fov", + "degrees", + "min_x", + "max_x", + "min_y", + "max_y", + ] + } + # background color + if "bg_col" not in otherargs: + bg_col = torch.zeros( + vert_col.shape[1], device=cam_params.device, dtype=torch.float32 + ) + otherargs["bg_col"] = bg_col + # Go! + images.append( + self.renderer( + vert_pos=vert_pos, + vert_col=vert_col, + vert_rad=vert_rad, + cam_params=cam_params, + gamma=gamma, + max_depth=zfar, + min_depth=znear, + **otherargs, + ) + ) + return torch.stack(images, dim=0) diff --git a/pytorch3d/renderer/points/renderer.py b/pytorch3d/renderer/points/renderer.py index 4dc610a3..0f5f3120 100644 --- a/pytorch3d/renderer/points/renderer.py +++ b/pytorch3d/renderer/points/renderer.py @@ -1,7 +1,5 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. 
and its affiliates. All rights reserved. - - import torch import torch.nn as nn @@ -48,7 +46,7 @@ class PointsRenderer(nn.Module): fragments.idx.long().permute(0, 3, 1, 2), weights, point_clouds.features_packed().permute(1, 0), - **kwargs + **kwargs, ) # permute so image comes at the end diff --git a/pytorch3d/transforms/__init__.py b/pytorch3d/transforms/__init__.py index 2f0a0301..8769c16b 100644 --- a/pytorch3d/transforms/__init__.py +++ b/pytorch3d/transforms/__init__.py @@ -1,9 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .external.kornia_angle_axis_to_rotation_matrix import ( - angle_axis_to_rotation_matrix as axis_angle_to_matrix, -) from .rotation_conversions import ( + axis_angle_to_matrix, euler_angles_to_matrix, matrix_to_euler_angles, matrix_to_quaternion, diff --git a/pytorch3d/transforms/external/__init__.py b/pytorch3d/transforms/external/__init__.py deleted file mode 100644 index 40539064..00000000 --- a/pytorch3d/transforms/external/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. diff --git a/pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py b/pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py deleted file mode 100644 index 1269813a..00000000 --- a/pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py +++ /dev/null @@ -1,94 +0,0 @@ -#!/usr/bin/env python3 -""" -This file contains the great angle axis to rotation matrix conversion -from kornia (https://github.com/arraiyopensource/kornia). The license -can be found in kornia_license.txt. - -The method is used unchanged; the documentation has been adjusted -to match our doc format. -""" -import torch - - -def angle_axis_to_rotation_matrix(angle_axis): - """Convert 3d vector of axis-angle rotation to 4x4 rotation matrix - - Args: - angle_axis (Tensor): tensor of 3d vector of axis-angle rotations. - - Returns: - Tensor: tensor of 3x3 rotation matrix. - - Shape: - - Input: :math:`(N, 3)` - - Output: :math:`(N, 3, 3)` - - Example: - - ..code-block::python - - >>> input = torch.rand(1, 3) # Nx3 - >>> output = tgm.angle_axis_to_rotation_matrix(input) # Nx3x3 - >>> output = tgm.angle_axis_to_rotation_matrix(input) # Nx3x3 - """ - - def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6): - # We want to be careful to only evaluate the square root if the - # norm of the angle_axis vector is greater than zero. Otherwise - # we get a division by zero. 
- k_one = 1.0 - theta = torch.sqrt(theta2) - wxyz = angle_axis / (theta + eps) - wx, wy, wz = torch.chunk(wxyz, 3, dim=1) - cos_theta = torch.cos(theta) - sin_theta = torch.sin(theta) - - r00 = cos_theta + wx * wx * (k_one - cos_theta) - r10 = wz * sin_theta + wx * wy * (k_one - cos_theta) - r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta) - r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta - r11 = cos_theta + wy * wy * (k_one - cos_theta) - r21 = wx * sin_theta + wy * wz * (k_one - cos_theta) - r02 = wy * sin_theta + wx * wz * (k_one - cos_theta) - r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta) - r22 = cos_theta + wz * wz * (k_one - cos_theta) - rotation_matrix = torch.cat( - [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1 - ) - return rotation_matrix.view(-1, 3, 3) - - def _compute_rotation_matrix_taylor(angle_axis): - rx, ry, rz = torch.chunk(angle_axis, 3, dim=1) - k_one = torch.ones_like(rx) - rotation_matrix = torch.cat( - [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1 - ) - return rotation_matrix.view(-1, 3, 3) - - # stolen from ceres/rotation.h - - _angle_axis = torch.unsqueeze(angle_axis + 1e-6, dim=1) - # _angle_axis.register_hook(lambda grad: pdb.set_trace()) - # _angle_axis = 1e-6 - theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2)) - theta2 = torch.squeeze(theta2, dim=1) - - # compute rotation matrices - rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2) - rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis) - - # create mask to handle both cases - eps = 1e-6 - mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device) - mask_pos = (mask).type_as(theta2) - mask_neg = (mask == False).type_as(theta2) # noqa - - # create output pose matrix - batch_size = angle_axis.shape[0] - rotation_matrix = torch.eye(3).to(angle_axis.device).type_as(angle_axis) - rotation_matrix = rotation_matrix.view(1, 3, 3).repeat(batch_size, 1, 1) - # fill output matrix with masked values - rotation_matrix[..., :3, :3] = ( - mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor - ) - return rotation_matrix.to(angle_axis.device).type_as(angle_axis) # Nx4x4 diff --git a/pytorch3d/transforms/external/kornia_license.txt b/pytorch3d/transforms/external/kornia_license.txt deleted file mode 100644 index 261eeb9e..00000000 --- a/pytorch3d/transforms/external/kornia_license.txt +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. 
- - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. 
If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. 
Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
diff --git a/setup.py b/setup.py index ee678ee1..cf041287 100755 --- a/setup.py +++ b/setup.py @@ -13,13 +13,8 @@ from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) extensions_dir = os.path.join(this_dir, "pytorch3d", "csrc") - - main_source = os.path.join(extensions_dir, "ext.cpp") - sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp")) - source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) - - sources = [main_source] + sources - + sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True) + source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu"), recursive=True) extension = CppExtension extra_compile_args = {"cxx": ["-std=c++14"]} @@ -30,7 +25,18 @@ def get_extensions(): extension = CUDAExtension sources += source_cuda define_macros += [("WITH_CUDA", None)] + cub_home = os.environ.get("CUB_HOME", None) + if cub_home is None: + raise Exception( + "The environment variable `CUB_HOME` was not found. " + "NVIDIA CUB is required for compilation and can be downloaded " + "from `https://github.com/NVIDIA/cub/releases`. You can unpack " + "it to a location of your choice and set the environment variable " + "`CUB_HOME` to the folder containing the `CMakeListst.txt` file." + ) nvcc_args = [ + "-I%s" % (os.path.realpath(cub_home).replace("\\ ", " ")), + "-std=c++14", "-DCUDA_HAS_FP16=1", "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovorthographic.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovorthographic.png new file mode 100644 index 00000000..cb6b5cec Binary files /dev/null and b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovorthographic.png differ diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovperspective.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovperspective.png new file mode 100644 index 00000000..bbeb8807 Binary files /dev/null and b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_fovperspective.png differ diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_orthographic.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_orthographic.png new file mode 100644 index 00000000..cb6b5cec Binary files /dev/null and b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_orthographic.png differ diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_perspective.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_perspective.png new file mode 100644 index 00000000..3bfa4d14 Binary files /dev/null and b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth0.0_perspective.png differ diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovorthographic.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovorthographic.png new file mode 100644 index 00000000..507fc204 Binary files /dev/null and b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovorthographic.png differ diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovperspective.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovperspective.png new file mode 100644 index 00000000..f378283b Binary files /dev/null and b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_fovperspective.png differ diff --git 
a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_orthographic.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_orthographic.png new file mode 100644 index 00000000..507fc204 Binary files /dev/null and b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_orthographic.png differ diff --git a/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_perspective.png b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_perspective.png new file mode 100644 index 00000000..ab01389a Binary files /dev/null and b/tests/data/test_pulsar_simple_pointcloud_sphere_azimuth90.0_perspective.png differ diff --git a/tests/test_render_points.py b/tests/test_render_points.py index 81f35dee..e400745e 100644 --- a/tests/test_render_points.py +++ b/tests/test_render_points.py @@ -16,6 +16,8 @@ from PIL import Image from pytorch3d.renderer.cameras import ( FoVOrthographicCameras, FoVPerspectiveCameras, + OrthographicCameras, + PerspectiveCameras, look_at_view_transform, ) from pytorch3d.renderer.compositing import alpha_composite, norm_weighted_sum @@ -25,6 +27,7 @@ from pytorch3d.renderer.points import ( PointsRasterizationSettings, PointsRasterizer, PointsRenderer, + PulsarPointsRenderer, ) from pytorch3d.structures.pointclouds import Pointclouds from pytorch3d.utils.ico_sphere import ico_sphere @@ -72,6 +75,145 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase): ) self.assertClose(rgb, image_ref) + def test_simple_sphere_pulsar(self): + for device in [torch.device("cpu"), torch.device("cuda")]: + sphere_mesh = ico_sphere(1, device) + verts_padded = sphere_mesh.verts_padded() + # Shift vertices to check coordinate frames are correct. + verts_padded[..., 1] += 0.2 + verts_padded[..., 0] += 0.2 + pointclouds = Pointclouds( + points=verts_padded, features=torch.ones_like(verts_padded) + ) + for azimuth in [0.0, 90.0]: + R, T = look_at_view_transform(2.7, 0.0, azimuth) + for camera_name, cameras in [ + ("fovperspective", FoVPerspectiveCameras(device=device, R=R, T=T)), + ( + "fovorthographic", + FoVOrthographicCameras(device=device, R=R, T=T), + ), + ("perspective", PerspectiveCameras(device=device, R=R, T=T)), + ("orthographic", OrthographicCameras(device=device, R=R, T=T)), + ]: + raster_settings = PointsRasterizationSettings( + image_size=256, radius=5e-2, points_per_pixel=1 + ) + rasterizer = PointsRasterizer( + cameras=cameras, raster_settings=raster_settings + ) + renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device) + # Load reference image + filename = ( + "pulsar_simple_pointcloud_sphere_" + f"azimuth{azimuth}_{camera_name}.png" + ) + image_ref = load_rgb_image("test_%s" % filename, DATA_DIR) + images = renderer( + pointclouds, gamma=(1e-3,), znear=(1.0,), zfar=(100.0,) + ) + rgb = images[0, ..., :3].squeeze().cpu() + if DEBUG: + filename = "DEBUG_%s" % filename + Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / filename + ) + self.assertClose(rgb, image_ref, rtol=7e-3, atol=5e-3) + + def test_unified_inputs_pulsar(self): + # Test data on different devices. + for device in [torch.device("cpu"), torch.device("cuda")]: + sphere_mesh = ico_sphere(1, device) + verts_padded = sphere_mesh.verts_padded() + pointclouds = Pointclouds( + points=verts_padded, features=torch.ones_like(verts_padded) + ) + R, T = look_at_view_transform(2.7, 0.0, 0.0) + # Test the different camera types. 
+ for _, cameras in [ + ("fovperspective", FoVPerspectiveCameras(device=device, R=R, T=T)), + ( + "fovorthographic", + FoVOrthographicCameras(device=device, R=R, T=T), + ), + ("perspective", PerspectiveCameras(device=device, R=R, T=T)), + ("orthographic", OrthographicCameras(device=device, R=R, T=T)), + ]: + # Test different ways for image size specification. + for image_size in (256, (256, 256)): + raster_settings = PointsRasterizationSettings( + image_size=image_size, radius=5e-2, points_per_pixel=1 + ) + rasterizer = PointsRasterizer( + cameras=cameras, raster_settings=raster_settings + ) + # Test that the compositor can be provided. It's value is ignored + # so use a dummy. + _ = PulsarPointsRenderer(rasterizer=rasterizer, compositor=1).to( + device + ) + # Constructor without compositor. + _ = PulsarPointsRenderer(rasterizer=rasterizer).to(device) + # Constructor with n_channels. + _ = PulsarPointsRenderer(rasterizer=rasterizer, n_channels=3).to( + device + ) + # Constructor with max_num_spheres. + renderer = PulsarPointsRenderer( + rasterizer=rasterizer, max_num_spheres=1000 + ).to(device) + # Test the forward function. + if isinstance(cameras, (PerspectiveCameras, OrthographicCameras)): + # znear and zfar is required in this case. + self.assertRaises( + ValueError, + lambda: renderer.forward( + point_clouds=pointclouds, gamma=(1e-4,) + ), + ) + renderer.forward( + point_clouds=pointclouds, + gamma=(1e-4,), + znear=(1.0,), + zfar=(2.0,), + ) + # znear and zfar must be batched. + self.assertRaises( + TypeError, + lambda: renderer.forward( + point_clouds=pointclouds, + gamma=(1e-4,), + znear=1.0, + zfar=(2.0,), + ), + ) + self.assertRaises( + TypeError, + lambda: renderer.forward( + point_clouds=pointclouds, + gamma=(1e-4,), + znear=(1.0,), + zfar=2.0, + ), + ) + else: + # gamma must be batched. + self.assertRaises( + TypeError, + lambda: renderer.forward( + point_clouds=pointclouds, gamma=1e-4 + ), + ) + renderer.forward(point_clouds=pointclouds, gamma=(1e-4,)) + # rasterizer width and height change. + renderer.rasterizer.raster_settings.image_size = 0 + self.assertRaises( + ValueError, + lambda: renderer.forward( + point_clouds=pointclouds, gamma=(1e-4,) + ), + ) + def test_pointcloud_with_features(self): device = torch.device("cuda:0") file_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
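Below is a minimal usage sketch of the new PulsarPointsRenderer interface. It is not part of the patch; it is assembled from the values used in test_simple_sphere_pulsar above, so the camera, image size, radius and gamma settings are simply the test's values, not recommendations.

import torch
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.points import (
    PointsRasterizationSettings,
    PointsRasterizer,
    PulsarPointsRenderer,
)
from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.utils.ico_sphere import ico_sphere

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# An ico-sphere rendered as a point cloud with constant (white) features.
verts = ico_sphere(1, device).verts_padded()
point_cloud = Pointclouds(points=verts, features=torch.ones_like(verts))

# Camera and rasterization settings taken from the test above.
R, T = look_at_view_transform(dist=2.7, elev=0.0, azim=0.0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
raster_settings = PointsRasterizationSettings(
    image_size=256, radius=5e-2, points_per_pixel=1
)
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)

# gamma is required and must be batched (one value per point cloud).
# znear/zfar are only mandatory for PerspectiveCameras/OrthographicCameras,
# but can also be passed for the FoV cameras, as the test does.
images = renderer(point_cloud, gamma=(1e-3,), znear=(1.0,), zfar=(100.0,))
rgb = images[0, ..., :3].cpu()  # HxWx3 tensor

Note that a compositor argument is accepted by the constructor only for interface compatibility with PointsRenderer; as the warning in unified.py states, it is ignored.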