pulsar interface unification.

Summary:
This diff builds on top of the `pulsar integration` diff to provide a unified interface for the existing PyTorch3D point renderer and pulsar. For more information about the pulsar backend, see the release notes and the paper (https://arxiv.org/abs/2004.07484). For information on how to use the backend, see the point cloud rendering notebook and the examples in the folder docs/examples.

The unified interface is fully consistent with the existing one: switching the render backend is as easy as using `renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)` instead of `renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)` and adding the `gamma` parameter to the forward function. All PyTorch3D camera types are supported as far as possible, and keyword arguments are properly forwarded to the camera. The `PerspectiveCameras` and `OrthographicCameras` require `znear` and `zfar` as additional parameters for the forward pass.

Reviewed By: nikhilaravi

Differential Revision: D21421443

fbshipit-source-id: 4aa0a83a419592d9a0bb5d62486a1cdea9d73ce6
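For illustration, a minimal sketch of the switch described above. This snippet is not part of the diff; the toy data, the FoV camera, and the image size are assumptions made for the example, and only the renderer construction and the required `gamma` argument differ between the two backends:

```python
import torch

from pytorch3d.renderer import FoVPerspectiveCameras
from pytorch3d.renderer.points import (
    AlphaCompositor,
    PointsRasterizationSettings,
    PointsRasterizer,
    PointsRenderer,
    PulsarPointsRenderer,
)
from pytorch3d.structures import Pointclouds

device = torch.device("cuda")
# Toy input: one cloud with 1000 random points and per-point RGB features.
point_cloud = Pointclouds(
    points=torch.rand((1, 1000, 3), device=device) * 2.0 - 1.0,
    features=torch.rand((1, 1000, 3), device=device),
)
cameras = FoVPerspectiveCameras(device=device)
rasterizer = PointsRasterizer(
    cameras=cameras,
    raster_settings=PointsRasterizationSettings(image_size=256, radius=0.01),
)

# Existing PyTorch3D point renderer:
renderer = PointsRenderer(rasterizer=rasterizer, compositor=AlphaCompositor())
images = renderer(point_cloud)

# Pulsar backend through the unified interface; `gamma` (one value per
# cloud) is the only additional, required argument:
renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)
images = renderer(point_cloud, gamma=(1e-4,))
```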

pytorch3d/renderer/points/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
 from .compositor import AlphaCompositor, NormWeightedCompositor
+from .pulsar.unified import PulsarPointsRenderer
 from .rasterize_points import rasterize_points
 from .rasterizer import PointsRasterizationSettings, PointsRasterizer
 from .renderer import PointsRenderer

pytorch3d/renderer/points/pulsar/renderer.py
@@ -369,6 +369,7 @@ class Renderer(torch.nn.Module):
         height: int,
         orthogonal: bool,
         right_handed: bool,
+        first_R_then_T: bool = False,
     ) -> Tuple[
         torch.Tensor,
         torch.Tensor,
@@ -401,6 +402,8 @@ class Renderer(torch.nn.Module):
                   (does not use focal length).
             * right_handed: bool, whether to use a right handed system
                   (negative z in camera direction).
+            * first_R_then_T: bool, whether to first rotate, then translate
+                  the camera (PyTorch3D convention).
 
         Returns:
             * pos_vec: the position vector in 3D,
@@ -460,16 +463,18 @@ class Renderer(torch.nn.Module):
         # Always get quadratic pixels.
         pixel_size_x = sensor_size_x / float(width)
         sensor_size_y = height * pixel_size_x
+        if continuous_rep:
+            rot_mat = rotation_6d_to_matrix(rot_vec)
+        else:
+            rot_mat = axis_angle_to_matrix(rot_vec)
+        if first_R_then_T:
+            pos_vec = torch.matmul(rot_mat, pos_vec[..., None])[:, :, 0]
         LOGGER.debug(
             "Camera position: %s, rotation: %s. Focal length: %s.",
             str(pos_vec),
             str(rot_vec),
             str(focal_length),
         )
-        if continuous_rep:
-            rot_mat = rotation_6d_to_matrix(rot_vec)
-        else:
-            rot_mat = axis_angle_to_matrix(rot_vec)
         sensor_dir_x = torch.matmul(
             rot_mat,
             torch.tensor(
@@ -576,6 +581,7 @@ class Renderer(torch.nn.Module):
         max_n_hits: int = _C.MAX_UINT,
         mode: int = 0,
         return_forward_info: bool = False,
+        first_R_then_T: bool = False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """
         Rendering pass to create an image from the provided spheres and camera
@@ -616,6 +622,8 @@ class Renderer(torch.nn.Module):
                 the float encoded integer index of a sphere and its weight. They are the
                 five spheres with the highest color contribution to this pixel color,
                 ordered descending. Default: False.
+            * first_R_then_T: bool, whether to first apply rotation to the camera,
+                then translation (PyTorch3D convention). Default: False.
 
         Returns:
             * image: [Bx]HxWx3 float tensor with the resulting image.
@@ -638,6 +646,7 @@ class Renderer(torch.nn.Module):
             self._renderer.height,
             self._renderer.orthogonal,
             self._renderer.right_handed,
+            first_R_then_T=first_R_then_T,
         )
         if (
             focal_lengths.min().item() > 0.0
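The `first_R_then_T` flag only changes how the camera position is assembled from rotation and translation; a small standalone sketch of the two conventions, mirroring the `pos_vec` line in the hunk above (the helper name is hypothetical):

```python
import torch


def effective_cam_pos(
    rot_mat: torch.Tensor, pos_vec: torch.Tensor, first_R_then_T: bool = False
) -> torch.Tensor:
    # With the PyTorch3D convention, the camera is first rotated and then
    # translated, so the translation has to be rotated before it can be
    # used as the camera position; pulsar's native convention uses it as-is.
    if first_R_then_T:
        return torch.matmul(rot_mat, pos_vec[..., None])[:, :, 0]
    return pos_vec


# Batch of one: an identity rotation leaves the position unchanged either way.
R = torch.eye(3)[None]
T = torch.tensor([[0.0, 0.0, 5.0]])
print(effective_cam_pos(R, T, first_R_then_T=True))
```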

pytorch3d/renderer/points/pulsar/unified.py (new file, 524 lines)
@@ -0,0 +1,524 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import math
import warnings
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from ....transforms import matrix_to_rotation_6d
from ...cameras import (
    FoVOrthographicCameras,
    FoVPerspectiveCameras,
    OrthographicCameras,
    PerspectiveCameras,
)
from ..compositor import AlphaCompositor, NormWeightedCompositor
from ..rasterizer import PointsRasterizer
from .renderer import Renderer as PulsarRenderer


class PulsarPointsRenderer(nn.Module):
    """
    This renderer is a PyTorch3D interface wrapper around the pulsar renderer.

    It provides an interface consistent with PyTorch3D Pointcloud rendering.
    It will extract all necessary information from the rasterizer and compositor
    objects and convert them to the pulsar required format, then invoke rendering
    in the pulsar renderer. All gradients are handled appropriately through the
    wrapper and the wrapper should provide equivalent results to using the pulsar
    renderer directly.
    """

    def __init__(
        self,
        rasterizer: PointsRasterizer,
        compositor: Optional[Union[NormWeightedCompositor, AlphaCompositor]] = None,
        n_channels: int = 3,
        max_num_spheres: int = int(1e6),  # noqa: B008
    ):
        """
        rasterizer (PointsRasterizer): An object encapsulating rasterization parameters.
        compositor (ignored): Only keeping this for interface consistency. Default: None.
        n_channels (int): The number of channels of the resulting image. Default: 3.
        max_num_spheres (int): The maximum number of spheres intended to render with
            this renderer. Default: 1e6.
        """
        super().__init__()
        self.rasterizer = rasterizer
        if compositor is not None:
            warnings.warn(
                "Creating a `PulsarPointsRenderer` with a compositor object! "
                "This object is ignored and just allowed as an argument for interface "
                "compatibility."
            )
        # Initialize the pulsar renderers.
        if not isinstance(
            rasterizer.cameras,
            (
                FoVOrthographicCameras,
                FoVPerspectiveCameras,
                PerspectiveCameras,
                OrthographicCameras,
            ),
        ):
            raise ValueError(
                "Only FoVPerspectiveCameras, PerspectiveCameras, "
                "FoVOrthographicCameras and OrthographicCameras are supported "
                "by the pulsar backend."
            )
        if isinstance(rasterizer.raster_settings.image_size, tuple):
            width, height = rasterizer.raster_settings.image_size
        else:
            width = rasterizer.raster_settings.image_size
            height = rasterizer.raster_settings.image_size
        # Making sure about integer types.
        width = int(width)
        height = int(height)
        max_num_spheres = int(max_num_spheres)
        orthogonal_projection = isinstance(
            rasterizer.cameras, (FoVOrthographicCameras, OrthographicCameras)
        )
        n_channels = int(n_channels)
        self.renderer = PulsarRenderer(
            width=width,
            height=height,
            max_num_balls=max_num_spheres,
            orthogonal_projection=orthogonal_projection,
            right_handed_system=True,
            n_channels=n_channels,
        )

    def _conf_check(self, point_clouds, kwargs: Dict[str, Any]) -> bool:
        """
        Verify internal configuration state with kwargs and pointclouds.

        This method will raise ValueErrors for any inconsistencies found. It
        returns whether an orthogonal projection will be used.
        """
        if "gamma" not in kwargs.keys():
            raise ValueError(
                "gamma is a required keyword argument for the PulsarPointsRenderer!"
            )
        if (
            len(point_clouds) != len(self.rasterizer.cameras)
            and len(self.rasterizer.cameras) != 1
        ):
            raise ValueError(
                (
                    "len(point_clouds) must either be equal to len(rasterizer.cameras) "
                    "or only one camera must be used. len(point_clouds): %d, "
                    "len(rasterizer.cameras): %d."
                )
                % (
                    len(point_clouds),
                    len(self.rasterizer.cameras),
                )
            )
        # Make sure the rasterizer and cameras objects have no
        # changes that can't be matched.
        orthogonal_projection = isinstance(
            self.rasterizer.cameras, (FoVOrthographicCameras, OrthographicCameras)
        )
        if orthogonal_projection != self.renderer._renderer.orthogonal:
            raise ValueError(
                "The camera type can not be changed after renderer initialization! "
                "Current camera orthogonal: %r. Original orthogonal: %r."
                % (orthogonal_projection, self.renderer._renderer.orthogonal)
            )
        if (
            isinstance(self.rasterizer.raster_settings.image_size, tuple)
            and self.rasterizer.raster_settings.image_size[0]
            != self.renderer._renderer.width
        ) or (
            not isinstance(self.rasterizer.raster_settings.image_size, tuple)
            and self.rasterizer.raster_settings.image_size
            != self.renderer._renderer.width
        ):
            raise ValueError(
                (
                    "The rasterizer width and height can not be changed after renderer "
                    "initialization! Current width: %d. Original width: %d."
                )
                % (
                    self.rasterizer.raster_settings.image_size,
                    self.renderer._renderer.width,
                )
            )
        if (
            isinstance(self.rasterizer.raster_settings.image_size, tuple)
            and self.rasterizer.raster_settings.image_size[1]
            != self.renderer._renderer.height
        ) or (
            not isinstance(self.rasterizer.raster_settings.image_size, tuple)
            and self.rasterizer.raster_settings.image_size
            != self.renderer._renderer.height
        ):
            raise ValueError(
                (
                    "The rasterizer width and height can not be changed after renderer "
                    "initialization! Current height: %d. Original height: %d."
                )
                % (
                    self.rasterizer.raster_settings.image_size,
                    self.renderer._renderer.height,
                )
            )
        return orthogonal_projection

    def _extract_intrinsics(
        self, orthogonal_projection, kwargs, cloud_idx
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, float]:
        """
        Translate the camera intrinsics from PyTorch3D format to pulsar format.
        """
        # Shorthand:
        cameras = self.rasterizer.cameras
        if orthogonal_projection:
            focal_length = 0.0
            if isinstance(cameras, FoVOrthographicCameras):
                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `znear`.
                znear = kwargs.get("znear", cameras.znear)[cloud_idx]
                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `zfar`.
                zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx]
                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `max_y`.
                max_y = kwargs.get("max_y", cameras.max_y)[cloud_idx]
                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `min_y`.
                min_y = kwargs.get("min_y", cameras.min_y)[cloud_idx]
                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `max_x`.
                max_x = kwargs.get("max_x", cameras.max_x)[cloud_idx]
                # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `min_x`.
                min_x = kwargs.get("min_x", cameras.min_x)[cloud_idx]
                if max_y != -min_y:
                    raise ValueError(
                        "The orthographic camera must be centered around 0. "
                        f"Max is {max_y} and min is {min_y}."
                    )
                if max_x != -min_x:
                    raise ValueError(
                        "The orthographic camera must be centered around 0. "
                        f"Max is {max_x} and min is {min_x}."
                    )
                if not torch.all(
                    # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `scale_xyz`.
                    kwargs.get("scale_xyz", cameras.scale_xyz)[cloud_idx]
                    == 1.0
                ):
                    raise ValueError(
                        "The orthographic camera scale must be ((1.0, 1.0, 1.0),). "
                        f"Is: {kwargs.get('scale_xyz', cameras.scale_xyz)[cloud_idx]}."
                    )
                sensor_width = max_x - min_x
                if not sensor_width > 0.0:
                    raise ValueError(
                        f"The orthographic camera must have positive size! Is: {sensor_width}."  # noqa: B950
                    )
                principal_point_x, principal_point_y = 0.0, 0.0
            else:
                # Currently, this means it must be an `OrthographicCameras` object.
                focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
                    cloud_idx
                ]
                if (
                    focal_length_conf.numel() == 2
                    and focal_length_conf[0] * self.renderer._renderer.width
                    - focal_length_conf[1] * self.renderer._renderer.height
                    > 1e-5
                ):
                    raise ValueError(
                        "Pulsar only supports a single focal length! "
                        "Provided: %s." % (str(focal_length_conf))
                    )
                if focal_length_conf.numel() == 2:
                    sensor_width = 2.0 / focal_length_conf[0]
                else:
                    if focal_length_conf.numel() != 1:
                        raise ValueError(
                            "Focal length not parsable: %s." % (str(focal_length_conf))
                        )
                    sensor_width = 2.0 / focal_length_conf
                if "znear" not in kwargs.keys() or "zfar" not in kwargs.keys():
                    raise ValueError(
                        "pulsar needs znear and zfar values for "
                        "the OrthographicCameras. Please provide them as keyword "
                        "arguments to the forward method."
                    )
                znear = kwargs["znear"][cloud_idx]
                zfar = kwargs["zfar"][cloud_idx]
                principal_point_x = (
                    kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
                    * 0.5
                    * self.renderer._renderer.width
                )
                principal_point_y = (
                    kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1]
                    * 0.5
                    * self.renderer._renderer.height
                )
        else:
            if not isinstance(cameras, PerspectiveCameras):
                # Create a virtual focal length that is closer than znear.
                znear = kwargs.get("znear", cameras.znear)[cloud_idx]
                zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx]
                focal_length = znear - 1e-6
                # Create a sensor size that matches the expected fov assuming this f.
                afov = kwargs.get("fov", cameras.fov)[cloud_idx]
                if kwargs.get("degrees", cameras.degrees):
                    afov *= math.pi / 180.0
                sensor_width = math.tan(afov / 2.0) * 2.0 * focal_length
                if not (
                    kwargs.get("aspect_ratio", cameras.aspect_ratio)[cloud_idx]
                    - self.renderer._renderer.width / self.renderer._renderer.height
                    < 1e-6
                ):
                    raise ValueError(
                        "The aspect ratio ("
                        f"{kwargs.get('aspect_ratio', cameras.aspect_ratio)[cloud_idx]}) "
                        "must agree with the resolution width / height ("
                        f"{self.renderer._renderer.width / self.renderer._renderer.height})."  # noqa: B950
                    )
                principal_point_x, principal_point_y = 0.0, 0.0
            else:
                # pyre-fixme[16]: `PerspectiveCameras` has no attribute `focal_length`.
                focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
                    cloud_idx
                ]
                if (
                    focal_length_conf.numel() == 2
                    and focal_length_conf[0] * self.renderer._renderer.width
                    - focal_length_conf[1] * self.renderer._renderer.height
                    > 1e-5
                ):
                    raise ValueError(
                        "Pulsar only supports a single focal length! "
                        "Provided: %s." % (str(focal_length_conf))
                    )
                if "znear" not in kwargs.keys() or "zfar" not in kwargs.keys():
                    raise ValueError(
                        "pulsar needs znear and zfar values for "
                        "the PerspectiveCameras. Please provide them as keyword "
                        "arguments to the forward method."
                    )
                znear = kwargs["znear"][cloud_idx]
                zfar = kwargs["zfar"][cloud_idx]
                if focal_length_conf.numel() == 2:
                    focal_length_px = focal_length_conf[0]
                else:
                    if focal_length_conf.numel() != 1:
                        raise ValueError(
                            "Focal length not parsable: %s." % (str(focal_length_conf))
                        )
                    focal_length_px = focal_length_conf
                focal_length = znear - 1e-6
                sensor_width = focal_length / focal_length_px * 2.0
                principal_point_x = (
                    # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
                    kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
                    * 0.5
                    * self.renderer._renderer.width
                )
                principal_point_y = (
                    kwargs.get("principal_point", cameras.principal_point)[cloud_idx][1]
                    * 0.5
                    * self.renderer._renderer.height
                )
        return (
            focal_length,
            sensor_width,
            principal_point_x,
            principal_point_y,
            znear,
            zfar,
        )

    def _extract_extrinsics(
        self, kwargs, cloud_idx
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Shorthand:
        cameras = self.rasterizer.cameras
        R = kwargs.get("R", cameras.R)[cloud_idx]
        T = kwargs.get("T", cameras.T)[cloud_idx]
        norm_mat = torch.tensor(
            [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
            dtype=torch.float32,
            device=R.device,
        )
        cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...])
        cam_pos = torch.flatten(torch.matmul(cam_rot, T[..., None]))
        cam_rot = torch.flatten(matrix_to_rotation_6d(cam_rot))
        return cam_pos, cam_rot

    def _get_vert_rad(
        self, vert_pos, cam_pos, orthogonal_projection, focal_length, kwargs, cloud_idx
    ) -> torch.Tensor:
        """
        Get point radii.

        These can depend on the camera position in the case of a perspective
        transform.
        """
        # Normalize point radii.
        # `self.rasterizer.raster_settings.radius` can either be a float
        # or itself a tensor.
        raster_rad = self.rasterizer.raster_settings.radius
        if kwargs.get("radius_world", False):
            return raster_rad
        if isinstance(raster_rad, torch.Tensor) and raster_rad.numel() > 1:
            # In this case it must be a batched torch tensor.
            raster_rad = raster_rad[cloud_idx]
        if orthogonal_projection:
            vert_rad = (
                torch.ones(
                    (vert_pos.shape[0],), dtype=torch.float32, device=vert_pos.device
                )
                * raster_rad
            )
        else:
            point_dists = torch.norm((vert_pos - cam_pos), p=2, dim=1, keepdim=False)
            vert_rad = raster_rad / focal_length * point_dists
            if isinstance(self.rasterizer.cameras, PerspectiveCameras):
                # NDC normalization happens through the adjusted focal length.
                pass
            else:
                vert_rad = vert_rad / 2.0  # NDC normalization.
        return vert_rad

    def forward(self, point_clouds, **kwargs) -> torch.Tensor:
        """
        Get the rendering of the provided `Pointclouds`.

        The number of point clouds in the `Pointclouds` object determines the
        number of resulting images. The number of provided cameras can be either
        1 or equal to the number of point clouds (in the first case, the same
        camera is used for all clouds; in the latter case, each point cloud is
        rendered with the corresponding camera).

        The following kwargs are supported from PyTorch3D (depending on the
        selected camera model, potentially overriding camera parameters):
            radius_world (bool): use the provided radii from the raster_settings
              directly as radii in world space. Default: False.
            znear (Iterable[float]): near geometry cutoff. Required for
              OrthographicCameras and PerspectiveCameras.
            zfar (Iterable[float]): far geometry cutoff. Required for
              OrthographicCameras and PerspectiveCameras.
            R (torch.Tensor): [Bx3x3] camera rotation matrices.
            T (torch.Tensor): [Bx3] camera translation vectors.
            principal_point (torch.Tensor): [Bx2] camera intrinsic principal
              point offset vectors.
            focal_length (torch.Tensor): [Bx1] camera intrinsic focal lengths.
            aspect_ratio (Iterable[float]): camera aspect ratios.
            fov (Iterable[float]): camera FOVs.
            degrees (bool): whether FOVs are specified in degrees or
              radians.
            min_x (Iterable[float]): minimum x for the FoVOrthographicCameras.
            max_x (Iterable[float]): maximum x for the FoVOrthographicCameras.
            min_y (Iterable[float]): minimum y for the FoVOrthographicCameras.
            max_y (Iterable[float]): maximum y for the FoVOrthographicCameras.

        The following kwargs are supported from pulsar:
            gamma (float): The gamma value to use. This defines the transparency for
                differentiability (see the pulsar paper for details). Must be in
                [1e-5, 1.0], with 1.0 being mostly transparent. This keyword
                argument is *required*!
            bg_col (torch.Tensor): The background color. Must be a tensor on the same
                device as the point clouds, with as many channels as features (no batch
                dimension - it is the same for all images in the batch).
                Default: 0.0 for all channels.
            percent_allowed_difference (float): a value in [0., 1.) with the maximum
                allowed difference in channel space. This is used to speed up the
                computation. Default: 0.01.
            max_n_hits (int): a hard limit on the number of sphere hits per ray.
                Default: max int.
            mode (int): render mode in {0, 1}. 0: render image; 1: render hit map.
        """
        orthogonal_projection: bool = self._conf_check(point_clouds, kwargs)
        # Get access to the inputs. We use the list accessors and process
        # the clouds sequentially.
        position_list = point_clouds.points_list()
        features_list = point_clouds.features_list()
        # Result list.
        images = []
        for cloud_idx, (vert_pos, vert_col) in enumerate(
            zip(position_list, features_list)
        ):
            # Get intrinsics.
            (
                focal_length,
                sensor_width,
                principal_point_x,
                principal_point_y,
                znear,
                zfar,
            ) = self._extract_intrinsics(orthogonal_projection, kwargs, cloud_idx)
            # Get extrinsics.
            cam_pos, cam_rot = self._extract_extrinsics(kwargs, cloud_idx)
            # Put everything together.
            cam_params = torch.cat(
                (
                    cam_pos,
                    cam_rot,
                    torch.tensor(
                        [
                            focal_length,
                            sensor_width,
                            principal_point_x,
                            principal_point_y,
                        ],
                        dtype=torch.float32,
                        device=cam_pos.device,
                    ),
                )
            )
            # Get point radii (these can depend on the camera position).
            vert_rad = self._get_vert_rad(
                vert_pos,
                cam_pos,
                orthogonal_projection,
                focal_length,
                kwargs,
                cloud_idx,
            )
            # Clean kwargs for passing on.
            gamma = kwargs["gamma"][cloud_idx]
            if "first_R_then_T" in kwargs.keys():
                raise ValueError("`first_R_then_T` is not supported in this interface.")
            otherargs = {
                argn: argv
                for argn, argv in kwargs.items()
                if argn
                not in [
                    "radius_world",
                    "gamma",
                    "znear",
                    "zfar",
                    "R",
                    "T",
                    "principal_point",
                    "focal_length",
                    "aspect_ratio",
                    "fov",
                    "degrees",
                    "min_x",
                    "max_x",
                    "min_y",
                    "max_y",
                ]
            }
            # Background color.
            if "bg_col" not in otherargs:
                bg_col = torch.zeros(
                    vert_col.shape[1], device=cam_params.device, dtype=torch.float32
                )
                otherargs["bg_col"] = bg_col
            # Go!
            images.append(
                self.renderer(
                    vert_pos=vert_pos,
                    vert_col=vert_col,
                    vert_rad=vert_rad,
                    cam_params=cam_params,
                    gamma=gamma,
                    max_depth=zfar,
                    min_depth=znear,
                    **otherargs,
                )
            )
        return torch.stack(images, dim=0)
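
As the docstring above notes, the non-FoV camera types carry no clipping planes, so `znear` and `zfar` have to be passed to `forward`. A minimal sketch of such a call; all concrete values are illustrative:

```python
import torch

from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.renderer.points import (
    PointsRasterizationSettings,
    PointsRasterizer,
    PulsarPointsRenderer,
)
from pytorch3d.structures import Pointclouds

device = torch.device("cuda")
point_cloud = Pointclouds(
    points=torch.rand((1, 500, 3), device=device),
    features=torch.rand((1, 500, 3), device=device),
)
cameras = PerspectiveCameras(focal_length=2.0, device=device)
rasterizer = PointsRasterizer(
    cameras=cameras,
    raster_settings=PointsRasterizationSettings(image_size=256, radius=0.01),
)
renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)
# znear/zfar are required here (and for OrthographicCameras), one value
# per point cloud, exactly like gamma:
images = renderer(point_cloud, gamma=(1e-4,), znear=(0.1,), zfar=(10.0,))
```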

pytorch3d/renderer/points/renderer.py
@@ -1,7 +1,5 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-
-
 import torch
 import torch.nn as nn
 
@@ -48,7 +46,7 @@ class PointsRenderer(nn.Module):
             fragments.idx.long().permute(0, 3, 1, 2),
             weights,
             point_clouds.features_packed().permute(1, 0),
-            **kwargs
+            **kwargs,
         )
 
         # permute so image comes at the end

pytorch3d/transforms/__init__.py
@@ -1,9 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-from .external.kornia_angle_axis_to_rotation_matrix import (
-    angle_axis_to_rotation_matrix as axis_angle_to_matrix,
-)
 from .rotation_conversions import (
+    axis_angle_to_matrix,
     euler_angles_to_matrix,
     matrix_to_euler_angles,
     matrix_to_quaternion,

pytorch3d/transforms/external/__init__.py (vendored, deleted)
@@ -1 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py (vendored, deleted)
@@ -1,94 +0,0 @@
#!/usr/bin/env python3
"""
This file contains the great angle axis to rotation matrix conversion
from kornia (https://github.com/arraiyopensource/kornia). The license
can be found in kornia_license.txt.

The method is used unchanged; the documentation has been adjusted
to match our doc format.
"""
import torch


def angle_axis_to_rotation_matrix(angle_axis):
    """Convert 3d vectors of axis-angle rotations to 3x3 rotation matrices.

    Args:
        angle_axis (Tensor): tensor of 3d vectors of axis-angle rotations.

    Returns:
        Tensor: tensor of 3x3 rotation matrices.

    Shape:
        - Input: :math:`(N, 3)`
        - Output: :math:`(N, 3, 3)`

    Example:

    .. code-block:: python

        >>> input = torch.rand(1, 3)  # Nx3
        >>> output = tgm.angle_axis_to_rotation_matrix(input)  # Nx3x3
    """

    def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6):
        # We want to be careful to only evaluate the square root if the
        # norm of the angle_axis vector is greater than zero. Otherwise
        # we get a division by zero.
        k_one = 1.0
        theta = torch.sqrt(theta2)
        wxyz = angle_axis / (theta + eps)
        wx, wy, wz = torch.chunk(wxyz, 3, dim=1)
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)

        r00 = cos_theta + wx * wx * (k_one - cos_theta)
        r10 = wz * sin_theta + wx * wy * (k_one - cos_theta)
        r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta)
        r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta
        r11 = cos_theta + wy * wy * (k_one - cos_theta)
        r21 = wx * sin_theta + wy * wz * (k_one - cos_theta)
        r02 = wy * sin_theta + wx * wz * (k_one - cos_theta)
        r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta)
        r22 = cos_theta + wz * wz * (k_one - cos_theta)
        rotation_matrix = torch.cat(
            [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1
        )
        return rotation_matrix.view(-1, 3, 3)

    def _compute_rotation_matrix_taylor(angle_axis):
        rx, ry, rz = torch.chunk(angle_axis, 3, dim=1)
        k_one = torch.ones_like(rx)
        rotation_matrix = torch.cat(
            [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1
        )
        return rotation_matrix.view(-1, 3, 3)

    # stolen from ceres/rotation.h

    _angle_axis = torch.unsqueeze(angle_axis + 1e-6, dim=1)
    # _angle_axis.register_hook(lambda grad: pdb.set_trace())
    # _angle_axis = 1e-6
    theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2))
    theta2 = torch.squeeze(theta2, dim=1)

    # compute rotation matrices
    rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2)
    rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis)

    # create mask to handle both cases
    eps = 1e-6
    mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device)
    mask_pos = (mask).type_as(theta2)
    mask_neg = (mask == False).type_as(theta2)  # noqa

    # create output rotation matrix
    batch_size = angle_axis.shape[0]
    rotation_matrix = torch.eye(3).to(angle_axis.device).type_as(angle_axis)
    rotation_matrix = rotation_matrix.view(1, 3, 3).repeat(batch_size, 1, 1)
    # fill output matrix with masked values
    rotation_matrix[..., :3, :3] = (
        mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor
    )
    return rotation_matrix.to(angle_axis.device).type_as(angle_axis)  # Nx3x3
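
With the vendored copy removed, the same conversion is provided by PyTorch3D's own `rotation_conversions` module, re-exported as shown in the `pytorch3d/transforms/__init__.py` hunk above. A quick sketch of the equivalent call:

```python
import torch
from pytorch3d.transforms import axis_angle_to_matrix

# Two axis-angle vectors: the zero rotation and a rotation of ~90 degrees
# around the z axis (the angle is encoded as the norm of the vector).
angle_axis = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.5708]])
rot = axis_angle_to_matrix(angle_axis)  # -> (2, 3, 3) rotation matrices
```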
 | 
			
		||||
							
								
								
									
										201
									
								
								pytorch3d/transforms/external/kornia_license.txt
									
									
									
									
										vendored
									
									
								
							
							
						
						@ -1,201 +0,0 @@
 | 
			
		||||
                                 Apache License
 | 
			
		||||
                           Version 2.0, January 2004
 | 
			
		||||
                        http://www.apache.org/licenses/
 | 
			
		||||
 | 
			
		||||
   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
 | 
			
		||||
 | 
			
		||||
   1. Definitions.
 | 
			
		||||
 | 
			
		||||
      "License" shall mean the terms and conditions for use, reproduction,
 | 
			
		||||
      and distribution as defined by Sections 1 through 9 of this document.
 | 
			
		||||
 | 
			
		||||
      "Licensor" shall mean the copyright owner or entity authorized by
 | 
			
		||||
      the copyright owner that is granting the License.
 | 
			
		||||
 | 
			
		||||
      "Legal Entity" shall mean the union of the acting entity and all
 | 
			
		||||
      other entities that control, are controlled by, or are under common
 | 
			
		||||
      control with that entity. For the purposes of this definition,
 | 
			
		||||
      "control" means (i) the power, direct or indirect, to cause the
 | 
			
		||||
      direction or management of such entity, whether by contract or
 | 
			
		||||
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
 | 
			
		||||
      outstanding shares, or (iii) beneficial ownership of such entity.
 | 
			
		||||
 | 
			
		||||
      "You" (or "Your") shall mean an individual or Legal Entity
 | 
			
		||||
      exercising permissions granted by this License.
 | 
			
		||||
 | 
			
		||||
      "Source" form shall mean the preferred form for making modifications,
 | 
			
		||||
      including but not limited to software source code, documentation
 | 
			
		||||
      source, and configuration files.
 | 
			
		||||
 | 
			
		||||
      "Object" form shall mean any form resulting from mechanical
 | 
			
		||||
      transformation or translation of a Source form, including but
 | 
			
		||||
      not limited to compiled object code, generated documentation,
 | 
			
		||||
      and conversions to other media types.
 | 
			
		||||
 | 
			
		||||
      "Work" shall mean the work of authorship, whether in Source or
 | 
			
		||||
      Object form, made available under the License, as indicated by a
 | 
			
		||||
      copyright notice that is included in or attached to the work
 | 
			
		||||
      (an example is provided in the Appendix below).
 | 
			
		||||
 | 
			
		||||
      "Derivative Works" shall mean any work, whether in Source or Object
 | 
			
		||||
      form, that is based on (or derived from) the Work and for which the
 | 
			
		||||
      editorial revisions, annotations, elaborations, or other modifications
 | 
			
		||||
      represent, as a whole, an original work of authorship. For the purposes
 | 
			
		||||
      of this License, Derivative Works shall not include works that remain
 | 
			
		||||
      separable from, or merely link (or bind by name) to the interfaces of,
 | 
			
		||||
      the Work and Derivative Works thereof.
 | 
			
		||||
 | 
			
		||||
      "Contribution" shall mean any work of authorship, including
 | 
			
		||||
      the original version of the Work and any modifications or additions
 | 
			
		||||
      to that Work or Derivative Works thereof, that is intentionally
 | 
			
		||||
      submitted to Licensor for inclusion in the Work by the copyright owner
 | 
			
		||||
      or by an individual or Legal Entity authorized to submit on behalf of
 | 
			
		||||
      the copyright owner. For the purposes of this definition, "submitted"
 | 
			
		||||
      means any form of electronic, verbal, or written communication sent
 | 
			
		||||
      to the Licensor or its representatives, including but not limited to
 | 
			
		||||
      communication on electronic mailing lists, source code control systems,
 | 
			
		||||
      and issue tracking systems that are managed by, or on behalf of, the
 | 
			
		||||
      Licensor for the purpose of discussing and improving the Work, but
 | 
			
		||||
      excluding communication that is conspicuously marked or otherwise
 | 
			
		||||
      designated in writing by the copyright owner as "Not a Contribution."
 | 
			
		||||
 | 
			
		||||
      "Contributor" shall mean Licensor and any individual or Legal Entity
 | 
			
		||||
      on behalf of whom a Contribution has been received by Licensor and
 | 
			
		||||
      subsequently incorporated within the Work.
 | 
			
		||||
 | 
			
		||||
   2. Grant of Copyright License. Subject to the terms and conditions of
 | 
			
		||||
      this License, each Contributor hereby grants to You a perpetual,
 | 
			
		||||
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 | 
			
		||||
      copyright license to reproduce, prepare Derivative Works of,
 | 
			
		||||
      publicly display, publicly perform, sublicense, and distribute the
 | 
			
		||||
      Work and such Derivative Works in Source or Object form.
 | 
			
		||||
 | 
			
		||||
   3. Grant of Patent License. Subject to the terms and conditions of
 | 
			
		||||
      this License, each Contributor hereby grants to You a perpetual,
 | 
			
		||||
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 | 
			
		||||
      (except as stated in this section) patent license to make, have made,
 | 
			
		||||
      use, offer to sell, sell, import, and otherwise transfer the Work,
 | 
			
		||||
      where such license applies only to those patent claims licensable
 | 
			
		||||
      by such Contributor that are necessarily infringed by their
 | 
			
		||||
      Contribution(s) alone or by combination of their Contribution(s)
 | 
			
		||||
      with the Work to which such Contribution(s) was submitted. If You
 | 
			
		||||
      institute patent litigation against any entity (including a
 | 
			
		||||
      cross-claim or counterclaim in a lawsuit) alleging that the Work
 | 
			
		||||
      or a Contribution incorporated within the Work constitutes direct
 | 
			
		||||
      or contributory patent infringement, then any patent licenses
 | 
			
		||||
      granted to You under this License for that Work shall terminate
 | 
			
		||||
      as of the date such litigation is filed.
 | 
			
		||||
 | 
			
		||||
   4. Redistribution. You may reproduce and distribute copies of the
 | 
			
		||||
      Work or Derivative Works thereof in any medium, with or without
 | 
			
		||||
      modifications, and in Source or Object form, provided that You
 | 
			
		||||
      meet the following conditions:
 | 
			
		||||
 | 
			
		||||
      (a) You must give any other recipients of the Work or
 | 
			
		||||
          Derivative Works a copy of this License; and
 | 
			
		||||
 | 
			
		||||
      (b) You must cause any modified files to carry prominent notices
 | 
			
		||||
          stating that You changed the files; and
 | 
			
		||||
 | 
			
		||||
      (c) You must retain, in the Source form of any Derivative Works
 | 
			
		||||
          that You distribute, all copyright, patent, trademark, and
 | 
			
		||||
          attribution notices from the Source form of the Work,
 | 
			
		||||
          excluding those notices that do not pertain to any part of
 | 
			
		||||
          the Derivative Works; and
 | 
			
		||||
 | 
			
		||||
      (d) If the Work includes a "NOTICE" text file as part of its
 | 
			
		||||
          distribution, then any Derivative Works that You distribute must
 | 
			
		||||
          include a readable copy of the attribution notices contained
 | 
			
		||||
          within such NOTICE file, excluding those notices that do not
 | 
			
		||||
          pertain to any part of the Derivative Works, in at least one
 | 
			
		||||
          of the following places: within a NOTICE text file distributed
 | 
			
		||||
          as part of the Derivative Works; within the Source form or
 | 
			
		||||
          documentation, if provided along with the Derivative Works; or,
 | 
			
		||||
          within a display generated by the Derivative Works, if and
 | 
			
		||||
          wherever such third-party notices normally appear. The contents
 | 
			
		||||
          of the NOTICE file are for informational purposes only and
 | 
			
		||||
          do not modify the License. You may add Your own attribution
 | 
			
		||||
          notices within Derivative Works that You distribute, alongside
 | 
			
		||||
          or as an addendum to the NOTICE text from the Work, provided
 | 
			
		||||
          that such additional attribution notices cannot be construed
 | 
			
		||||
          as modifying the License.
 | 
			
		||||
 | 
			
		||||
      You may add Your own copyright statement to Your modifications and
 | 
			
		||||
      may provide additional or different license terms and conditions
 | 
			
		||||
      for use, reproduction, or distribution of Your modifications, or
 | 
			
		||||
      for any such Derivative Works as a whole, provided Your use,
 | 
			
		||||
      reproduction, and distribution of the Work otherwise complies with
 | 
			
		||||
      the conditions stated in this License.
 | 
			
		||||
 | 
			
		||||
   5. Submission of Contributions. Unless You explicitly state otherwise,
 | 
			
		||||
      any Contribution intentionally submitted for inclusion in the Work
 | 
			
		||||
      by You to the Licensor shall be under the terms and conditions of
 | 
			
		||||
      this License, without any additional terms or conditions.
 | 
			
		||||
      Notwithstanding the above, nothing herein shall supersede or modify
 | 
			
		||||
      the terms of any separate license agreement you may have executed
 | 
			
		||||
      with Licensor regarding such Contributions.
 | 
			
		||||
 | 
			
		||||
   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
setup.py (20 changed lines)
@ -13,13 +13,8 @@ from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "pytorch3d", "csrc")

    main_source = os.path.join(extensions_dir, "ext.cpp")
    sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
    source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu"))

    sources = [main_source] + sources

    sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)
    source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu"), recursive=True)
    extension = CppExtension

    extra_compile_args = {"cxx": ["-std=c++14"]}
@ -30,7 +25,18 @@ def get_extensions():
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [("WITH_CUDA", None)]
        cub_home = os.environ.get("CUB_HOME", None)
        if cub_home is None:
            raise Exception(
                "The environment variable `CUB_HOME` was not found. "
                "NVIDIA CUB is required for compilation and can be downloaded "
                "from `https://github.com/NVIDIA/cub/releases`. You can unpack "
                "it to a location of your choice and set the environment variable "
                "`CUB_HOME` to the folder containing the `CMakeLists.txt` file."
            )
        nvcc_args = [
            "-I%s" % (os.path.realpath(cub_home).replace("\\ ", " ")),
            "-std=c++14",
            "-DCUDA_HAS_FP16=1",
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
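An aside on the `glob` change in this hunk: Python's `glob` only expands `**`
across directory levels when `recursive=True` is passed; without it, `**`
behaves like a single `*`, so the old calls matched sources exactly one level
below `csrc`, missing both deeper nesting and the top-level `ext.cpp` (hence
the separate `main_source` workaround that this diff removes). A minimal
sketch illustrating the difference (hypothetical directory layout, standard
library only):

    import glob
    import os

    pattern = os.path.join("pytorch3d", "csrc", "**", "*.cpp")
    # Non-recursive: "**" matches exactly one directory level.
    shallow = glob.glob(pattern)
    # Recursive: "**" matches zero or more levels, including files directly
    # in csrc (which is why main_source no longer needs special handling).
    deep = glob.glob(pattern, recursive=True)
    assert set(shallow) <= set(deep)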
[8 binary reference images added: test data PNGs, 1.9-3.4 KiB each]
@ -16,6 +16,8 @@ from PIL import Image
from pytorch3d.renderer.cameras import (
    FoVOrthographicCameras,
    FoVPerspectiveCameras,
    OrthographicCameras,
    PerspectiveCameras,
    look_at_view_transform,
)
from pytorch3d.renderer.compositing import alpha_composite, norm_weighted_sum
@ -25,6 +27,7 @@ from pytorch3d.renderer.points import (
    PointsRasterizationSettings,
    PointsRasterizer,
    PointsRenderer,
    PulsarPointsRenderer,
)
from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.utils.ico_sphere import ico_sphere
@ -72,6 +75,145 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
                )
            self.assertClose(rgb, image_ref)

    def test_simple_sphere_pulsar(self):
        for device in [torch.device("cpu"), torch.device("cuda")]:
            sphere_mesh = ico_sphere(1, device)
            verts_padded = sphere_mesh.verts_padded()
            # Shift vertices to check coordinate frames are correct.
            verts_padded[..., 1] += 0.2
            verts_padded[..., 0] += 0.2
            pointclouds = Pointclouds(
                points=verts_padded, features=torch.ones_like(verts_padded)
            )
            for azimuth in [0.0, 90.0]:
                R, T = look_at_view_transform(2.7, 0.0, azimuth)
                for camera_name, cameras in [
                    ("fovperspective", FoVPerspectiveCameras(device=device, R=R, T=T)),
                    (
                        "fovorthographic",
                        FoVOrthographicCameras(device=device, R=R, T=T),
                    ),
                    ("perspective", PerspectiveCameras(device=device, R=R, T=T)),
                    ("orthographic", OrthographicCameras(device=device, R=R, T=T)),
                ]:
                    raster_settings = PointsRasterizationSettings(
                        image_size=256, radius=5e-2, points_per_pixel=1
                    )
                    rasterizer = PointsRasterizer(
                        cameras=cameras, raster_settings=raster_settings
                    )
                    renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)
                    # Load the reference image.
                    filename = (
                        "pulsar_simple_pointcloud_sphere_"
                        f"azimuth{azimuth}_{camera_name}.png"
                    )
                    image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
                    images = renderer(
                        pointclouds, gamma=(1e-3,), znear=(1.0,), zfar=(100.0,)
                    )
                    rgb = images[0, ..., :3].squeeze().cpu()
                    if DEBUG:
                        filename = "DEBUG_%s" % filename
                        Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                            DATA_DIR / filename
                        )
                    self.assertClose(rgb, image_ref, rtol=7e-3, atol=5e-3)

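Aside: the `gamma` value in the forward call above is the pulsar blending
parameter from the paper; as far as we can tell from the backend, smaller
values give harder, more opaque sphere boundaries and larger values give
softer blending. A condensed sketch of the call this test exercises (same
`rasterizer` and `pointclouds` as above; the gamma value is illustrative):

    renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device)
    # One gamma/znear/zfar entry per element of the point cloud batch.
    images = renderer(pointclouds, gamma=(1e-3,), znear=(1.0,), zfar=(100.0,))
    rgb = images[0, ..., :3].cpu()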
    def test_unified_inputs_pulsar(self):
        # Test data on different devices.
        for device in [torch.device("cpu"), torch.device("cuda")]:
            sphere_mesh = ico_sphere(1, device)
            verts_padded = sphere_mesh.verts_padded()
            pointclouds = Pointclouds(
                points=verts_padded, features=torch.ones_like(verts_padded)
            )
            R, T = look_at_view_transform(2.7, 0.0, 0.0)
            # Test the different camera types.
            for _, cameras in [
                ("fovperspective", FoVPerspectiveCameras(device=device, R=R, T=T)),
                (
                    "fovorthographic",
                    FoVOrthographicCameras(device=device, R=R, T=T),
                ),
                ("perspective", PerspectiveCameras(device=device, R=R, T=T)),
                ("orthographic", OrthographicCameras(device=device, R=R, T=T)),
            ]:
                # Test different ways of specifying the image size.
                for image_size in (256, (256, 256)):
                    raster_settings = PointsRasterizationSettings(
                        image_size=image_size, radius=5e-2, points_per_pixel=1
                    )
                    rasterizer = PointsRasterizer(
                        cameras=cameras, raster_settings=raster_settings
                    )
                    # Test that the compositor can be provided. Its value is
                    # ignored, so use a dummy.
                    _ = PulsarPointsRenderer(rasterizer=rasterizer, compositor=1).to(
                        device
                    )
                    # Constructor without compositor.
                    _ = PulsarPointsRenderer(rasterizer=rasterizer).to(device)
                    # Constructor with n_channels.
                    _ = PulsarPointsRenderer(rasterizer=rasterizer, n_channels=3).to(
                        device
                    )
                    # Constructor with max_num_spheres.
                    renderer = PulsarPointsRenderer(
                        rasterizer=rasterizer, max_num_spheres=1000
                    ).to(device)
                    # Test the forward function.
                    if isinstance(cameras, (PerspectiveCameras, OrthographicCameras)):
                        # znear and zfar are required in this case.
                        self.assertRaises(
                            ValueError,
                            lambda: renderer.forward(
                                point_clouds=pointclouds, gamma=(1e-4,)
                            ),
                        )
                        renderer.forward(
                            point_clouds=pointclouds,
                            gamma=(1e-4,),
                            znear=(1.0,),
                            zfar=(2.0,),
                        )
                        # znear and zfar must be batched.
                        self.assertRaises(
                            TypeError,
                            lambda: renderer.forward(
                                point_clouds=pointclouds,
                                gamma=(1e-4,),
                                znear=1.0,
                                zfar=(2.0,),
                            ),
                        )
                        self.assertRaises(
                            TypeError,
                            lambda: renderer.forward(
                                point_clouds=pointclouds,
                                gamma=(1e-4,),
                                znear=(1.0,),
                                zfar=2.0,
                            ),
                        )
                    else:
                        # gamma must be batched.
                        self.assertRaises(
                            TypeError,
                            lambda: renderer.forward(
                                point_clouds=pointclouds, gamma=1e-4
                            ),
                        )
                        renderer.forward(point_clouds=pointclouds, gamma=(1e-4,))
                        # Changing the rasterizer image size after construction
                        # must be detected.
                        renderer.rasterizer.raster_settings.image_size = 0
                        self.assertRaises(
                            ValueError,
                            lambda: renderer.forward(
                                point_clouds=pointclouds, gamma=(1e-4,)
                            ),
                        )

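The TypeError checks above encode a convention that is easy to miss: `gamma`,
`znear` and `zfar` are passed as tuples with one entry per batch element, not
as scalars. In plain form (a sketch; the comments state the behavior asserted
by the test above):

    renderer.forward(point_clouds=pointclouds, gamma=1e-4)     # TypeError: not batched
    renderer.forward(point_clouds=pointclouds, gamma=(1e-4,))  # ok for FoV cameras,
                                                               # ValueError otherwise
    renderer.forward(
        point_clouds=pointclouds, gamma=(1e-4,), znear=(1.0,), zfar=(2.0,)
    )                                                          # ok for all camera types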
    def test_pointcloud_with_features(self):
        device = torch.device("cuda:0")
        file_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"