diff --git a/INSTALL.md b/INSTALL.md index fa87746d..069bf4e1 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -14,18 +14,20 @@ The core library is written in PyTorch. Several components have underlying imple - gcc & g++ ≥ 4.9 - [fvcore](https://github.com/facebookresearch/fvcore) - If CUDA is to be used, use at least version 9.2. +- If CUDA is to be used, the CUB library must be available. Starting from CUDA 11, CUB is part of CUDA. If you're using an earlier CUDA version and are not using conda, download the CUB library from https://github.com/NVIDIA/cub/releases and unpack it to a folder of your choice. Define the environment variable CUB_HOME before building and point it to the directory that contains `CMakeLists.txt` for CUB. -These can be installed by running: +The dependencies can be installed by running: ``` conda create -n pytorch3d python=3.8 conda activate pytorch3d conda install -c pytorch pytorch=1.6.0 torchvision cudatoolkit=10.2 conda install -c conda-forge -c fvcore fvcore +conda install -c cub ``` ### Tests/Linting and Demos -For developing on top of PyTorch3D or contributing, you will need to run the linter and tests. If you want to run any of the notebook tutorials as `docs/tutorials` you will also need matplotlib. +For developing on top of PyTorch3D or contributing, you will need to run the linter and tests. If you want to run any of the notebook tutorials as `docs/tutorials` or the examples in `docs/examples` you will also need matplotlib and OpenCV. - scikit-image - black - isort @@ -35,12 +37,13 @@ For developing on top of PyTorch3D or contributing, you will need to run the lin - jupyter - imageio - plotly +- opencv-python These can be installed by running: ``` -# Demos +# Demos and examples conda install jupyter -pip install scikit-image matplotlib imageio plotly +pip install scikit-image matplotlib imageio plotly opencv-python # Tests/Linting pip install black 'isort<5' flake8 flake8-bugbear flake8-comprehensions @@ -81,6 +84,8 @@ To install using the code of the released version instead of from the main branc pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable' ``` +For CUDA builds with versions earlier than CUDA 11, set `CUB_HOME` before building as described above. + **Install from Github on macOS:** Some environment variables should be provided, like this. ``` @@ -92,7 +97,7 @@ MACOSX_DEPLOYMENT_TARGET=10.14 CC=clang CXX=clang++ pip install 'git+https://git git clone https://github.com/facebookresearch/pytorch3d.git cd pytorch3d && pip install -e . ``` -To rebuild after installing from a local clone run, `rm -rf build/ **/*.so` then `pip install -e .`. You often need to rebuild pytorch3d after reinstalling PyTorch. +To rebuild after installing from a local clone run, `rm -rf build/ **/*.so` then `pip install -e .`. You often need to rebuild pytorch3d after reinstalling PyTorch. For CUDA builds with versions earlier than CUDA 11, set `CUB_HOME` before building as described above. **Install from local clone on macOS:** ``` diff --git a/README.md b/README.md index 7bfde938..eac683f6 100644 --- a/README.md +++ b/README.md @@ -106,11 +106,23 @@ If you find PyTorch3D useful in your research, please cite our tech report: } ``` +If you are using the pulsar backend for sphere-rendering (the `PulsarPointRenderer` or `pytorch3d.renderer.points.pulsar.Renderer`), please cite the tech report: + +```bibtex +@article{lassner2020pulsar, + author = {Christoph Lassner}, + title = {Fast Differentiable Raycasting for Neural Rendering using Sphere-based Representations}, + journal = {arXiv:2004.07484}, + year = {2020}, +} +``` ## News Please see below for a timeline of the codebase updates in reverse chronological order. We are sharing updates on the releases as well as research projects which are built with PyTorch3D. The changelogs for the releases are available under [`Releases`](https://github.com/facebookresearch/pytorch3d/releases), and the builds can be installed using `conda` as per the instructions in [INSTALL.md](INSTALL.md). +**[November 2nd 2020]:** PyTorch3D v0.3 released, integrating the pulsar backend. + **[Aug 28th 2020]:** PyTorch3D v0.2.5 released **[July 17th 2020]:** PyTorch3D tech report published on ArXiv: https://arxiv.org/abs/2007.08501 diff --git a/docs/examples/pulsar_basic.py b/docs/examples/pulsar_basic.py index 62dccd03..50f88202 100755 --- a/docs/examples/pulsar_basic.py +++ b/docs/examples/pulsar_basic.py @@ -5,6 +5,7 @@ This example demonstrates the most trivial, direct interface of the pulsar sphere renderer. It renders and saves an image with 10 random spheres. Output: basic.png. """ +import math from os import path import imageio @@ -12,11 +13,15 @@ import torch from pytorch3d.renderer.points.pulsar import Renderer +torch.manual_seed(1) + n_points = 10 width = 1_000 height = 1_000 device = torch.device("cuda") -renderer = Renderer(width, height, n_points).to(device) +# The PyTorch3D system is right handed; in pulsar you can choose the handedness. +# For easy reproducibility we use a right handed coordinate system here. +renderer = Renderer(width, height, n_points, right_handed_system=True).to(device) # Generate sample data. vert_pos = torch.rand(n_points, 3, dtype=torch.float32, device=device) * 10.0 vert_pos[:, 2] += 25.0 @@ -29,7 +34,7 @@ cam_params = torch.tensor( 0.0, 0.0, # Position 0, 0, 0 (x, y, z). 0.0, - 0.0, + math.pi, # Because of the right handed system, the camera must look 'back'. 0.0, # Rotation 0, 0, 0 (in axis-angle format). 5.0, # Focal length in world size. 2.0, # Sensor size in world size. diff --git a/docs/examples/pulsar_basic_unified.py b/docs/examples/pulsar_basic_unified.py new file mode 100755 index 00000000..50efeb31 --- /dev/null +++ b/docs/examples/pulsar_basic_unified.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +""" +This example demonstrates the most trivial use of the pulsar PyTorch3D +interface for sphere renderering. It renders and saves an image with +10 random spheres. +Output: basic-pt3d.png. +""" +from os import path + +import imageio +import torch +from pytorch3d.renderer import PerspectiveCameras # , look_at_view_transform +from pytorch3d.renderer import ( + PointsRasterizationSettings, + PointsRasterizer, + PulsarPointsRenderer, +) +from pytorch3d.structures import Pointclouds + + +torch.manual_seed(1) + +n_points = 10 +width = 1_000 +height = 1_000 +device = torch.device("cuda") + +# Generate sample data. +vert_pos = torch.rand(n_points, 3, dtype=torch.float32, device=device) * 10.0 +vert_pos[:, 2] += 25.0 +vert_pos[:, :2] -= 5.0 +vert_col = torch.rand(n_points, 3, dtype=torch.float32, device=device) +pcl = Pointclouds(points=vert_pos[None, ...], features=vert_col[None, ...]) +# Alternatively, you can also use the look_at_view_transform to get R and T: +# R, T = look_at_view_transform( +# dist=30.0, elev=0.0, azim=180.0, at=((0.0, 0.0, 30.0),), up=((0, 1, 0),), +# ) +cameras = PerspectiveCameras( + # The focal length must be double the size for PyTorch3D because of the NDC + # coordinates spanning a range of two - and they must be normalized by the + # sensor width (see the pulsar example). This means we need here + # 5.0 * 2.0 / 2.0 to get the equivalent results as in pulsar. + focal_length=(5.0 * 2.0 / 2.0,), + R=torch.eye(3, dtype=torch.float32, device=device)[None, ...], + T=torch.zeros((1, 3), dtype=torch.float32, device=device), + image_size=((width, height),), + device=device, +) +vert_rad = torch.rand(n_points, dtype=torch.float32, device=device) +raster_settings = PointsRasterizationSettings( + image_size=(width, height), + radius=vert_rad, +) +rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings) +renderer = PulsarPointsRenderer(rasterizer=rasterizer).to(device) +# Render. +image = renderer( + pcl, + gamma=(1.0e-1,), # Renderer blending parameter gamma, in [1., 1e-5]. + znear=(1.0,), + zfar=(45.0,), + radius_world=True, + bg_col=torch.ones((3,), dtype=torch.float32, device=device), +)[0] +print("Writing image to `%s`." % (path.abspath("basic-pt3d.png"))) +imageio.imsave("basic-pt3d.png", (image.cpu().detach() * 255.0).to(torch.uint8).numpy()) diff --git a/docs/examples/pulsar_cam.py b/docs/examples/pulsar_cam.py index dcc08759..12d26a81 100755 --- a/docs/examples/pulsar_cam.py +++ b/docs/examples/pulsar_cam.py @@ -7,7 +7,9 @@ pulsar interface. For this, a reference image has been pre-generated The same scene parameterization is loaded and the camera parameters distorted. Gradient-based optimization is used to converge towards the original camera parameters. +Output: cam.gif. """ +import math from os import path import cv2 @@ -15,6 +17,7 @@ import imageio import numpy as np import torch from pytorch3d.renderer.points.pulsar import Renderer +from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_rotation_6d from torch import nn, optim @@ -66,19 +69,18 @@ class SceneModel(nn.Module): ) self.register_parameter( "cam_rot", + # We're using the 6D rot. representation for better gradients. nn.Parameter( - torch.tensor( - [ - # We're using the 6D rot. representation for better gradients. - 0.9995, - 0.0300445, - -0.0098482, - -0.0299445, - 0.9995, - 0.0101482, - ], - dtype=torch.float32, - ), + matrix_to_rotation_6d( + axis_angle_to_matrix( + torch.tensor( + [ + [0.02, math.pi + 0.02, 0.01], + ], + dtype=torch.float32, + ) + ) + )[0], requires_grad=True, ), ) @@ -88,7 +90,7 @@ class SceneModel(nn.Module): torch.tensor([4.8, 1.8], dtype=torch.float32), requires_grad=True ), ) - self.renderer = Renderer(width, height, n_points) + self.renderer = Renderer(width, height, n_points, right_handed_system=True) def forward(self): return self.renderer.forward( @@ -106,7 +108,7 @@ ref = ( torch.from_numpy( imageio.imread( "../../tests/pulsar/reference/examples_TestRenderer_test_cam.png" - ) + )[:, ::-1, :].copy() ).to(torch.float32) / 255.0 ).to(device) diff --git a/docs/examples/pulsar_cam_unified.py b/docs/examples/pulsar_cam_unified.py new file mode 100755 index 00000000..cca44dff --- /dev/null +++ b/docs/examples/pulsar_cam_unified.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +""" +This example demonstrates camera parameter optimization with the pulsar +PyTorch3D interface. For this, a reference image has been pre-generated +(you can find it at `../../tests/pulsar/reference/examples_TestRenderer_test_cam.png`). +The same scene parameterization is loaded and the camera parameters +distorted. Gradient-based optimization is used to converge towards the +original camera parameters. +Output: cam-pt3d.gif +""" +from os import path + +import cv2 +import imageio +import numpy as np +import torch +from pytorch3d.renderer.cameras import PerspectiveCameras # , look_at_view_transform +from pytorch3d.renderer.points import ( + PointsRasterizationSettings, + PointsRasterizer, + PulsarPointsRenderer, +) +from pytorch3d.structures.pointclouds import Pointclouds +from pytorch3d.transforms import axis_angle_to_matrix +from torch import nn, optim + + +n_points = 20 +width = 1_000 +height = 1_000 +device = torch.device("cuda") + + +class SceneModel(nn.Module): + """ + A simple scene model to demonstrate use of pulsar in PyTorch modules. + + The scene model is parameterized with sphere locations (vert_pos), + channel content (vert_col), radiuses (vert_rad), camera position (cam_pos), + camera rotation (cam_rot) and sensor focal length and width (cam_sensor). + + The forward method of the model renders this scene description. Any + of these parameters could instead be passed as inputs to the forward + method and come from a different model. + """ + + def __init__(self): + super(SceneModel, self).__init__() + self.gamma = 0.1 + # Points. + torch.manual_seed(1) + vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0 + vert_pos[:, 2] += 25.0 + vert_pos[:, :2] -= 5.0 + self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=False)) + self.register_parameter( + "vert_col", + nn.Parameter( + torch.rand(n_points, 3, dtype=torch.float32), + requires_grad=False, + ), + ) + self.register_parameter( + "vert_rad", + nn.Parameter( + torch.rand(n_points, dtype=torch.float32), + requires_grad=False, + ), + ) + self.register_parameter( + "cam_pos", + nn.Parameter( + torch.tensor([0.1, 0.1, 0.0], dtype=torch.float32), + requires_grad=True, + ), + ) + self.register_parameter( + "cam_rot", + # We're using the 6D rot. representation for better gradients. + nn.Parameter( + axis_angle_to_matrix( + torch.tensor( + [ + [0.02, 0.02, 0.01], + ], + dtype=torch.float32, + ) + )[0], + requires_grad=True, + ), + ) + self.register_parameter( + "focal_length", + nn.Parameter( + torch.tensor( + [ + 4.8 * 2.0 / 2.0, + ], + dtype=torch.float32, + ), + requires_grad=True, + ), + ) + self.cameras = PerspectiveCameras( + # The focal length must be double the size for PyTorch3D because of the NDC + # coordinates spanning a range of two - and they must be normalized by the + # sensor width (see the pulsar example). This means we need here + # 5.0 * 2.0 / 2.0 to get the equivalent results as in pulsar. + # + # R, T and f are provided here, but will be provided again + # at every call to the forward method. The reason are problems + # with PyTorch which makes device placement for gradients problematic + # for tensors which are themselves on a 'gradient path' but not + # leafs in the calculation tree. This will be addressed by an architectural + # change in PyTorch3D in the future. Until then, this workaround is + # recommended. + focal_length=self.focal_length, + R=self.cam_rot[None, ...], + T=self.cam_pos[None, ...], + image_size=((width, height),), + device=device, + ) + raster_settings = PointsRasterizationSettings( + image_size=(width, height), + radius=self.vert_rad, + ) + rasterizer = PointsRasterizer( + cameras=self.cameras, raster_settings=raster_settings + ) + self.renderer = PulsarPointsRenderer(rasterizer=rasterizer) + + def forward(self): + # The Pointclouds object creates copies of it's arguments - that's why + # we have to create a new object in every forward step. + pcl = Pointclouds( + points=self.vert_pos[None, ...], features=self.vert_col[None, ...] + ) + return self.renderer( + pcl, + gamma=(self.gamma,), + zfar=(45.0,), + znear=(1.0,), + radius_world=True, + bg_col=torch.ones((3,), dtype=torch.float32, device=device), + # As mentioned above: workaround for device placement of gradients for + # camera parameters. + focal_length=self.focal_length, + R=self.cam_rot[None, ...], + T=self.cam_pos[None, ...], + )[0] + + +# Load reference. +ref = ( + torch.from_numpy( + imageio.imread( + "../../tests/pulsar/reference/examples_TestRenderer_test_cam.png" + )[:, ::-1, :].copy() + ).to(torch.float32) + / 255.0 +).to(device) +# Set up model. +model = SceneModel().to(device) +# Optimizer. +optimizer = optim.SGD( + [ + {"params": [model.cam_pos], "lr": 1e-4}, + {"params": [model.cam_rot], "lr": 5e-6}, + # Using a higher lr for the focal length here, because + # the sensor width can not be optimized directly. + {"params": [model.focal_length], "lr": 1e-3}, + ] +) + +print("Writing video to `%s`." % (path.abspath("cam-pt3d.gif"))) +writer = imageio.get_writer("cam-pt3d.gif", format="gif", fps=25) + +# Optimize. +for i in range(300): + optimizer.zero_grad() + result = model() + # Visualize. + result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8) + cv2.imshow("opt", result_im[:, :, ::-1]) + writer.append_data(result_im) + overlay_img = np.ascontiguousarray( + ((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[ + :, :, ::-1 + ] + ) + overlay_img = cv2.putText( + overlay_img, + "Step %d" % (i), + (10, 40), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (0, 0, 0), + 2, + cv2.LINE_AA, + False, + ) + cv2.imshow("overlay", overlay_img) + cv2.waitKey(1) + # Update. + loss = ((result - ref) ** 2).sum() + print("loss {}: {}".format(i, loss.item())) + loss.backward() + optimizer.step() +writer.close() diff --git a/docs/examples/pulsar_multiview.py b/docs/examples/pulsar_multiview.py index 4be9af72..9b816b31 100755 --- a/docs/examples/pulsar_multiview.py +++ b/docs/examples/pulsar_multiview.py @@ -7,7 +7,12 @@ pulsar interface. For this, reference images have been pre-generated The camera parameters are assumed given. The scene is initialized with random spheres. Gradient-based optimization is used to optimize sphere parameters and prune spheres to converge to a 3D representation. + +This example is not available yet through the 'unified' interface, +because opacity support has not landed in PyTorch3D for general data +structures yet. """ +import math from os import path import cv2 @@ -77,7 +82,7 @@ class SceneModel(nn.Module): 0.0, 30.0 - np.cos(angle) * 35.0, 0.0, - -angle, + -angle + math.pi, 0.0, 5.0, 2.0, @@ -87,7 +92,7 @@ class SceneModel(nn.Module): dtype=torch.float32, ), ) - self.renderer = Renderer(width, height, n_points) + self.renderer = Renderer(width, height, n_points, right_handed_system=True) def forward(self, cam=None): if cam is None: @@ -184,7 +189,7 @@ for i in range(300): 0.0, 30.0 - np.cos(angle) * 35.0, 0.0, - -angle, + -angle + math.pi, 0.0, 5.0, 2.0, diff --git a/docs/examples/pulsar_optimization.py b/docs/examples/pulsar_optimization.py index 67b2f81b..ded9a61b 100755 --- a/docs/examples/pulsar_optimization.py +++ b/docs/examples/pulsar_optimization.py @@ -8,6 +8,8 @@ The scene is initialized with random spheres. Gradient-based optimization is used to converge towards a faithful scene representation. """ +import math + import cv2 import imageio import numpy as np @@ -58,11 +60,15 @@ class SceneModel(nn.Module): ) self.register_buffer( "cam_params", - torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32), + torch.tensor( + [0.0, 0.0, 0.0, 0.0, math.pi, 0.0, 5.0, 2.0], dtype=torch.float32 + ), ) # The volumetric optimization works better with a higher number of tracked # intersections per ray. - self.renderer = Renderer(width, height, n_points, n_track=32) + self.renderer = Renderer( + width, height, n_points, n_track=32, right_handed_system=True + ) def forward(self): return self.renderer.forward( @@ -81,7 +87,7 @@ ref = ( torch.from_numpy( imageio.imread( "../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png" - ) + )[:, ::-1, :].copy() ).to(torch.float32) / 255.0 ).to(device) diff --git a/docs/examples/pulsar_optimization_unified.py b/docs/examples/pulsar_optimization_unified.py new file mode 100755 index 00000000..59ac72fb --- /dev/null +++ b/docs/examples/pulsar_optimization_unified.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +""" +This example demonstrates scene optimization with the PyTorch3D +pulsar interface. For this, a reference image has been pre-generated +(you can find it at `../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png`). +The scene is initialized with random spheres. Gradient-based +optimization is used to converge towards a faithful +scene representation. +""" +import math + +import cv2 +import imageio +import numpy as np +import torch +from pytorch3d.renderer.cameras import PerspectiveCameras # , look_at_view_transform +from pytorch3d.renderer.points import ( + PointsRasterizationSettings, + PointsRasterizer, + PulsarPointsRenderer, +) +from pytorch3d.structures.pointclouds import Pointclouds +from torch import nn, optim + + +n_points = 10_000 +width = 1_000 +height = 1_000 +device = torch.device("cuda") + + +class SceneModel(nn.Module): + """ + A simple scene model to demonstrate use of pulsar in PyTorch modules. + + The scene model is parameterized with sphere locations (vert_pos), + channel content (vert_col), radiuses (vert_rad), camera position (cam_pos), + camera rotation (cam_rot) and sensor focal length and width (cam_sensor). + + The forward method of the model renders this scene description. Any + of these parameters could instead be passed as inputs to the forward + method and come from a different model. + """ + + def __init__(self): + super(SceneModel, self).__init__() + self.gamma = 1.0 + # Points. + torch.manual_seed(1) + vert_pos = torch.rand(n_points, 3, dtype=torch.float32, device=device) * 10.0 + vert_pos[:, 2] += 25.0 + vert_pos[:, :2] -= 5.0 + self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True)) + self.register_parameter( + "vert_col", + nn.Parameter( + torch.ones(n_points, 3, dtype=torch.float32, device=device) * 0.5, + requires_grad=True, + ), + ) + self.register_parameter( + "vert_rad", + nn.Parameter( + torch.ones(n_points, dtype=torch.float32) * 0.3, requires_grad=True + ), + ) + self.register_buffer( + "cam_params", + torch.tensor( + [0.0, 0.0, 0.0, 0.0, math.pi, 0.0, 5.0, 2.0], dtype=torch.float32 + ), + ) + self.cameras = PerspectiveCameras( + # The focal length must be double the size for PyTorch3D because of the NDC + # coordinates spanning a range of two - and they must be normalized by the + # sensor width (see the pulsar example). This means we need here + # 5.0 * 2.0 / 2.0 to get the equivalent results as in pulsar. + focal_length=5.0, + R=torch.eye(3, dtype=torch.float32, device=device)[None, ...], + T=torch.zeros((1, 3), dtype=torch.float32, device=device), + image_size=((width, height),), + device=device, + ) + raster_settings = PointsRasterizationSettings( + image_size=(width, height), + radius=self.vert_rad, + ) + rasterizer = PointsRasterizer( + cameras=self.cameras, raster_settings=raster_settings + ) + self.renderer = PulsarPointsRenderer(rasterizer=rasterizer, n_track=32) + + def forward(self): + # The Pointclouds object creates copies of it's arguments - that's why + # we have to create a new object in every forward step. + pcl = Pointclouds( + points=self.vert_pos[None, ...], features=self.vert_col[None, ...] + ) + return self.renderer( + pcl, + gamma=(self.gamma,), + zfar=(45.0,), + znear=(1.0,), + radius_world=True, + bg_col=torch.ones((3,), dtype=torch.float32, device=device), + )[0] + + +# Load reference. +ref = ( + torch.from_numpy( + imageio.imread( + "../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png" + )[:, ::-1, :].copy() + ).to(torch.float32) + / 255.0 +).to(device) +# Set up model. +model = SceneModel().to(device) +# Optimizer. +optimizer = optim.SGD( + [ + {"params": [model.vert_col], "lr": 1e0}, + {"params": [model.vert_rad], "lr": 5e-3}, + {"params": [model.vert_pos], "lr": 1e-2}, + ] +) + +# Optimize. +for i in range(500): + optimizer.zero_grad() + result = model() + # Visualize. + result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8) + cv2.imshow("opt", result_im[:, :, ::-1]) + overlay_img = np.ascontiguousarray( + ((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[ + :, :, ::-1 + ] + ) + overlay_img = cv2.putText( + overlay_img, + "Step %d" % (i), + (10, 40), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (0, 0, 0), + 2, + cv2.LINE_AA, + False, + ) + cv2.imshow("overlay", overlay_img) + cv2.waitKey(1) + # Update. + loss = ((result - ref) ** 2).sum() + print("loss {}: {}".format(i, loss.item())) + loss.backward() + optimizer.step() + # Cleanup. + with torch.no_grad(): + model.vert_col.data = torch.clamp(model.vert_col.data, 0.0, 1.0) + # Remove points. + model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0 + model.vert_rad.data[model.vert_rad < 0.001] = 0.0001 + vd = ( + (model.vert_col - torch.ones(3, dtype=torch.float32).to(device)) + .abs() + .sum(dim=1) + ) + model.vert_pos.data[vd <= 0.2] = -1000.0 diff --git a/docs/notes/assets/pulsar_bm.png b/docs/notes/assets/pulsar_bm.png new file mode 100644 index 00000000..83bfddaf Binary files /dev/null and b/docs/notes/assets/pulsar_bm.png differ diff --git a/docs/notes/renderer.md b/docs/notes/renderer.md index 1c492bf3..9e5d7402 100644 --- a/docs/notes/renderer.md +++ b/docs/notes/renderer.md @@ -25,7 +25,7 @@ To learn about more the implementation and start using the renderer refer to [ge ## Tech Report -For an in depth explanation of the renderer design, key features and benchmarks please refer to the PyTorch3D Technical Report on ArXiv: [Accelerating 3D Deep Learning with PyTorch3D](https://arxiv.org/abs/2007.08501) +For an in depth explanation of the renderer design, key features and benchmarks please refer to the PyTorch3D Technical Report on ArXiv: [Accelerating 3D Deep Learning with PyTorch3D](https://arxiv.org/abs/2007.08501), for the pulsar backend see here: [Fast Differentiable Raycasting for Neural Rendering using Sphere-based Representations](https://arxiv.org/abs/2004.07484). --- diff --git a/docs/notes/renderer_getting_started.md b/docs/notes/renderer_getting_started.md index 542ff9ec..8b4aca5b 100644 --- a/docs/notes/renderer_getting_started.md +++ b/docs/notes/renderer_getting_started.md @@ -55,6 +55,16 @@ While we tried to emulate several aspects of OpenGL, there are differences in th --- +### The pulsar backend + +Since v0.3, [pulsar](https://arxiv.org/abs/2004.07484) can be used as a backend for point-rendering. It has a focus on efficiency, which comes with pros and cons: it is highly optimized and all rendering stages are integrated in the CUDA kernels. This leads to significantly higher speed and better scaling behavior. We use it at Facebook Reality Labs to render and optimize scenes with millions of spheres in resolutions up to 4K. You can find a runtime comparison plot below (settings: `bin_size=None`, `points_per_pixel=5`, `image_size=1024`, `radius=1e-2`, `composite_params.radius=1e-4`; benchmarked on an RTX 2070 GPU). + + + +Pulsar's processing steps are tightly integrated CUDA kernels and do not work with custom `rasterizer` and `compositor` components. We provide two ways to use Pulsar: (1) there is a unified interface to match the PyTorch3D calling convention seamlessly. This is, for example, illustrated in the [point cloud tutorial](https://github.com/facebookresearch/pytorch3d/blob/master/docs/tutorials/render_colored_points.ipynb). (2) There is a direct interface available to the pulsar backend, which exposes the full functionality of the backend (including opacity, which is not yet available in PyTorch3D). Examples showing its use as well as the matching PyTorch3D interface code are available in [this folder](https://github.com/facebookresearch/pytorch3d/tree/master/docs/examples). + +--- + ### Texturing options For mesh texturing we offer several options (in `pytorch3d/renderer/mesh/texturing.py`): diff --git a/docs/tutorials/render_colored_points.ipynb b/docs/tutorials/render_colored_points.ipynb index 25b223e8..7f13afe5 100644 --- a/docs/tutorials/render_colored_points.ipynb +++ b/docs/tutorials/render_colored_points.ipynb @@ -73,6 +73,7 @@ " FoVOrthographicCameras, \n", " PointsRasterizationSettings,\n", " PointsRenderer,\n", + " PulsarPointsRenderer,\n", " PointsRasterizer,\n", " AlphaCompositor,\n", " NormWeightedCompositor\n", @@ -169,10 +170,11 @@ "\n", "# Create a points renderer by compositing points using an alpha compositor (nearer points\n", "# are weighted more heavily). See [1] for an explanation.\n", + "rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)\n", "renderer = PointsRenderer(\n", - " rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings),\n", + " rasterizer=rasterizer,\n", " compositor=AlphaCompositor()\n", - ")\n" + ")" ] }, { @@ -202,12 +204,13 @@ "outputs": [], "source": [ "renderer = PointsRenderer(\n", - " rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings),\n", + " rasterizer=rasterizer,\n", " # Pass in background_color to the alpha compositor, setting the background color \n", " # to the 3 item tuple, representing rgb on a scale of 0 -> 1, in this case blue\n", " compositor=AlphaCompositor(background_color=(0, 0, 1))\n", ")\n", "images = renderer(point_cloud)\n", + "\n", "plt.figure(figsize=(10, 10))\n", "plt.imshow(images[0, ..., :3].cpu().numpy())\n", "plt.grid(\"off\")\n", @@ -288,6 +291,39 @@ "plt.axis(\"off\");" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Using the pulsar backend\n", + "\n", + "Switching to the pulsar backend is easy! The pulsar backend has a compositor built-in, so the `compositor` argument is not required when creating it (a warning will be displayed if you provide it nevertheless). It pre-allocates memory on the rendering device, that's why it needs the `n_channels` at construction time.\n", + "\n", + "All parameters for the renderer forward function are batch-wise except the background color (in this example, `gamma`) and you have to provide as many values as you have examples in your batch. The background color is optional and by default set to all zeros. You can find a detailed explanation of how gamma influences the rendering function here in the paper [Fast Differentiable Raycasting for Neural Rendering using\n", + "Sphere-based Representations](https://arxiv.org/pdf/2004.07484.pdf).\n", + "\n", + "You can also use the `native` backend for the pulsar backend which already provides access to point opacity. The native backend can be imported from `pytorch3d.renderer.points.pulsar`; you can find examples for this in the folder `docs/examples`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "renderer = PulsarPointsRenderer(\n", + " rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings),\n", + " n_channels=4\n", + ").to(device)\n", + "\n", + "images = renderer(point_cloud, gamma=(1e-4,),\n", + " bg_col=torch.tensor([0.0, 1.0, 0.0, 1.0], dtype=torch.float32, device=device))\n", + "plt.figure(figsize=(10, 10))\n", + "plt.imshow(images[0, ..., :3].cpu().numpy())\n", + "plt.grid(\"off\")\n", + "plt.axis(\"off\");" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -412,9 +448,9 @@ "bento/extensions/theme/main.css": true }, "kernelspec": { - "display_name": "pytorch3d_etc (local)", + "display_name": "Python 3", "language": "python", - "name": "pytorch3d_etc_local" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -426,7 +462,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.5+" + "version": "3.6.8" } }, "nbformat": 4, diff --git a/pytorch3d/csrc/pulsar/global.h b/pytorch3d/csrc/pulsar/global.h index 6c0e94c1..88257d67 100644 --- a/pytorch3d/csrc/pulsar/global.h +++ b/pytorch3d/csrc/pulsar/global.h @@ -38,9 +38,27 @@ // Don't care about pytorch warnings; they shouldn't clutter our warnings. #pragma clang diagnostic push #pragma clang diagnostic ignored "-Weverything" -#include #include #pragma clang diagnostic pop +#ifdef WITH_CUDA +#include +#else +#ifndef cudaStream_t +typedef void* cudaStream_t; +#endif +struct int2 { + int x, y; +}; +struct ushort2 { + unsigned short x, y; +}; +struct float2 { + float x, y; +}; +struct float3 { + float x, y, z; +}; +#endif namespace py = pybind11; inline float3 make_float3(const float& x, const float& y, const float& z) { float3 res; diff --git a/pytorch3d/csrc/pulsar/pytorch/renderer.cpp b/pytorch3d/csrc/pulsar/pytorch/renderer.cpp index b44f6dc5..d58c1489 100644 --- a/pytorch3d/csrc/pulsar/pytorch/renderer.cpp +++ b/pytorch3d/csrc/pulsar/pytorch/renderer.cpp @@ -5,8 +5,10 @@ #include "./util.h" #include +#ifdef WITH_CUDA #include #include +#endif namespace PRE = ::pulsar::Renderer; @@ -58,10 +60,13 @@ Renderer::Renderer( Renderer::~Renderer() { if (this->device_type == c10::DeviceType::CUDA) { +// Can't happen in the case that not compiled with CUDA. +#ifdef WITH_CUDA at::cuda::CUDAGuard device_guard(this->device_tracker.device()); for (auto nrend : this->renderer_vec) { PRE::destruct(&nrend); } +#endif } else { for (auto nrend : this->renderer_vec) { PRE::destruct(&nrend); @@ -87,6 +92,7 @@ void Renderer::ensure_on_device(torch::Device device, bool /*non_blocking*/) { "Only CPU and CUDA device types are supported."); if (device.type() != this->device_type || device.index() != this->device_index) { +#ifdef WITH_CUDA LOG_IF(INFO, PULSAR_LOG_INIT) << "Transferring render buffers between devices."; int prev_active; @@ -136,6 +142,11 @@ void Renderer::ensure_on_device(torch::Device device, bool /*non_blocking*/) { cudaSetDevice(prev_active); this->device_type = device.type(); this->device_index = device.index(); +#else + throw std::runtime_error( + "pulsar was built without CUDA " + "but a device move to a CUDA device was initiated."); +#endif } }; @@ -148,6 +159,7 @@ void Renderer::ensure_n_renderers_gte(const size_t& batch_size) { for (ptrdiff_t i = 0; i < diff; ++i) { this->renderer_vec.emplace_back(); if (this->device_type == c10::DeviceType::CUDA) { +#ifdef WITH_CUDA PRE::construct( &this->renderer_vec[this->renderer_vec.size() - 1], this->max_num_balls(), @@ -158,6 +170,7 @@ void Renderer::ensure_n_renderers_gte(const size_t& batch_size) { this->renderer_vec[0].cam.background_normalization_depth, this->renderer_vec[0].cam.n_channels, this->n_track()); +#endif } else { PRE::construct( &this->renderer_vec[this->renderer_vec.size() - 1], @@ -708,6 +721,10 @@ std::tuple Renderer::forward( opacity_ptr = opacity_contiguous.data_ptr(); } if (this->device_type == c10::DeviceType::CUDA) { +// No else check necessary - if not compiled with CUDA +// we can't even reach this code (the renderer can't be +// moved to a CUDA device). +#ifdef WITH_CUDA int prev_active; cudaGetDevice(&prev_active); cudaSetDevice(this->device_index); @@ -756,6 +773,7 @@ std::tuple Renderer::forward( << time_ms / static_cast(batch_size) << "ms" << std::endl; #endif cudaSetDevice(prev_active); +#endif } else { #ifdef PULSAR_TIMINGS_BATCHED_ENABLED START_TIME(batch_forward); @@ -816,7 +834,11 @@ std::tuple Renderer::forward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); if (mode == 1) results[batch_i] = results[batch_i].slice(2, 0, 1, 1); @@ -829,7 +851,11 @@ std::tuple Renderer::forward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); } LOG_IF(INFO, PULSAR_LOG_FORWARD) << "Forward render complete."; @@ -1048,6 +1074,9 @@ Renderer::backward( opacity_ptr = opacity_contiguous.data_ptr(); } if (this->device_type == c10::DeviceType::CUDA) { +// No else check necessary - it's not possible to move +// the renderer to a CUDA device if not built with CUDA. +#ifdef WITH_CUDA int prev_active; cudaGetDevice(&prev_active); cudaSetDevice(this->device_index); @@ -1162,6 +1191,7 @@ Renderer::backward( std::cout << "Backward render batched time per example: " << time_ms / static_cast(batch_size) << "ms" << std::endl; #endif +#endif // WITH_CUDA } else { #ifdef PULSAR_TIMINGS_BATCHED_ENABLED START_TIME(batch_backward); @@ -1285,7 +1315,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); } std::get<0>(ret) = torch::stack(results); @@ -1297,7 +1331,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); } } @@ -1313,7 +1351,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); } std::get<1>(ret) = torch::stack(results); @@ -1326,7 +1368,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); } } @@ -1341,7 +1387,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); } std::get<2>(ret) = torch::stack(results); @@ -1353,7 +1403,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); } } @@ -1371,7 +1425,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); res_p2[batch_i] = from_blob( reinterpret_cast( @@ -1381,7 +1439,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); res_p3[batch_i] = from_blob( reinterpret_cast( @@ -1391,7 +1453,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); res_p4[batch_i] = from_blob( reinterpret_cast( @@ -1401,7 +1467,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); } std::get<3>(ret) = torch::stack(res_p1); @@ -1416,7 +1486,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); std::get<4>(ret) = from_blob( reinterpret_cast(this->renderer_vec[0].grad_cam_d + 3), @@ -1425,7 +1499,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); std::get<5>(ret) = from_blob( reinterpret_cast(this->renderer_vec[0].grad_cam_d + 6), @@ -1434,7 +1512,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); std::get<6>(ret) = from_blob( reinterpret_cast(this->renderer_vec[0].grad_cam_d + 9), @@ -1443,7 +1525,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); } } @@ -1458,7 +1544,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); } std::get<7>(ret) = torch::stack(results); @@ -1470,7 +1560,11 @@ Renderer::backward( this->device_index, torch::kFloat, this->device_type == c10::DeviceType::CUDA +#ifdef WITH_CUDA ? at::cuda::getCurrentCUDAStream() +#else + ? (cudaStream_t) nullptr +#endif : (cudaStream_t) nullptr); } } diff --git a/pytorch3d/csrc/pulsar/pytorch/tensor_util.cpp b/pytorch3d/csrc/pulsar/pytorch/tensor_util.cpp index c1e9e108..f4f7dcf5 100644 --- a/pytorch3d/csrc/pulsar/pytorch/tensor_util.cpp +++ b/pytorch3d/csrc/pulsar/pytorch/tensor_util.cpp @@ -1,6 +1,8 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifdef WITH_CUDA #include #include +#endif #include #include "./tensor_util.h" @@ -23,6 +25,7 @@ torch::Tensor sphere_ids_from_result_info_nograd( /*dim=*/3, /*start=*/3, /*end=*/forw_info.size(3), /*step=*/2) .contiguous(); if (forw_info.device().type() == c10::DeviceType::CUDA) { +#ifdef WITH_CUDA cudaMemcpyAsync( result.data_ptr(), tmp.data_ptr(), @@ -30,6 +33,11 @@ torch::Tensor sphere_ids_from_result_info_nograd( tmp.size(3), cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream()); +#else + throw std::runtime_error( + "Copy on CUDA device initiated but built " + "without CUDA support."); +#endif } else { memcpy( result.data_ptr(), diff --git a/pytorch3d/csrc/pulsar/pytorch/util.cpp b/pytorch3d/csrc/pulsar/pytorch/util.cpp index 847e697e..0b80a538 100644 --- a/pytorch3d/csrc/pulsar/pytorch/util.cpp +++ b/pytorch3d/csrc/pulsar/pytorch/util.cpp @@ -1,4 +1,5 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifdef WITH_CUDA #include namespace pulsar { @@ -22,3 +23,4 @@ void cudaDevToHost( } // namespace pytorch } // namespace pulsar +#endif diff --git a/pytorch3d/csrc/pulsar/pytorch/util.h b/pytorch3d/csrc/pulsar/pytorch/util.h index bab41678..4aa0a654 100644 --- a/pytorch3d/csrc/pulsar/pytorch/util.h +++ b/pytorch3d/csrc/pulsar/pytorch/util.h @@ -41,11 +41,16 @@ torch::Tensor from_blob( const int num_elements = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies{}); if (device_type == c10::DeviceType::CUDA) { +#ifdef WITH_CUDA cudaDevToDev( ret.data_ptr(), static_cast(ptr), sizeof(T) * num_elements, stream); +#else + throw std::runtime_error( + "Initiating devToDev copy on a build without CUDA."); +#endif // TODO: check for synchronization. } else { memcpy(ret.data_ptr(), ptr, sizeof(T) * num_elements); diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index 72a01122..f383f6f7 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -46,6 +46,7 @@ from .points import ( PointsRasterizationSettings, PointsRasterizer, PointsRenderer, + PulsarPointsRenderer, rasterize_points, ) from .utils import TensorProperties, convert_to_tensors_and_broadcast diff --git a/pytorch3d/renderer/points/pulsar/unified.py b/pytorch3d/renderer/points/pulsar/unified.py index 08b10af6..06d27c79 100644 --- a/pytorch3d/renderer/points/pulsar/unified.py +++ b/pytorch3d/renderer/points/pulsar/unified.py @@ -18,6 +18,15 @@ from ..rasterizer import PointsRasterizer from .renderer import Renderer as PulsarRenderer +def _ensure_float_tensor(val_in, device): + """Make sure that the value provided is wrapped a PyTorch float tensor.""" + if not isinstance(val_in, torch.Tensor): + val_out = torch.tensor(val_in, dtype=torch.float32, device=device).reshape((1,)) + else: + val_out = val_in.to(torch.float32).to(device).reshape((1,)) + return val_out + + class PulsarPointsRenderer(nn.Module): """ This renderer is a PyTorch3D interface wrapper around the pulsar renderer. @@ -36,6 +45,7 @@ class PulsarPointsRenderer(nn.Module): compositor: Optional[Union[NormWeightedCompositor, AlphaCompositor]] = None, n_channels: int = 3, max_num_spheres: int = int(1e6), # noqa: B008 + **kwargs, ): """ rasterizer (PointsRasterizer): An object encapsulating rasterization parameters. @@ -43,6 +53,8 @@ class PulsarPointsRenderer(nn.Module): n_channels (int): The number of channels of the resulting image. Default: 3. max_num_spheres (int): The maximum number of spheres intended to render with this renderer. Default: 1e6. + kwargs (Any): kwargs to pass on to the pulsar renderer. + See `pytorch3d.renderer.points.pulsar.renderer.Renderer` for all options. """ super().__init__() self.rasterizer = rasterizer @@ -87,6 +99,7 @@ class PulsarPointsRenderer(nn.Module): orthogonal_projection=orthogonal_projection, right_handed_system=True, n_channels=n_channels, + **kwargs, ) def _conf_check(self, point_clouds, kwargs: Dict[str, Any]) -> bool: @@ -165,8 +178,8 @@ class PulsarPointsRenderer(nn.Module): ) return orthogonal_projection - def _extract_intrinsics( - self, orthogonal_projection, kwargs, cloud_idx + def _extract_intrinsics( # noqa: C901 + self, orthogonal_projection, kwargs, cloud_idx, device ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, float]: """ Translate the camera intrinsics from PyTorch3D format to pulsar format. @@ -174,7 +187,7 @@ class PulsarPointsRenderer(nn.Module): # Shorthand: cameras = self.rasterizer.cameras if orthogonal_projection: - focal_length = 0.0 + focal_length = torch.zeros((1,), dtype=torch.float32) if isinstance(cameras, FoVOrthographicCameras): # pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `znear`. znear = kwargs.get("znear", cameras.znear)[cloud_idx] @@ -212,7 +225,10 @@ class PulsarPointsRenderer(nn.Module): raise ValueError( f"The orthographic camera must have positive size! Is: {sensor_width}." # noqa: B950 ) - principal_point_x, principal_point_y = 0.0, 0.0 + principal_point_x, principal_point_y = ( + torch.zeros((1,), dtype=torch.float32), + torch.zeros((1,), dtype=torch.float32), + ) else: # Currently, this means it must be an 'OrthographicCameras' object. focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[ @@ -276,7 +292,10 @@ class PulsarPointsRenderer(nn.Module): "must agree with the resolution width / height (" f"{self.renderer._renderer.width / self.renderer._renderer.height})." # noqa: B950 ) - principal_point_x, principal_point_y = 0.0, 0.0 + principal_point_x, principal_point_y = ( + torch.zeros((1,), dtype=torch.float32), + torch.zeros((1,), dtype=torch.float32), + ) else: # pyre-fixme[16]: `PerspectiveCameras` has no attribute `focal_length`. focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[ @@ -308,7 +327,13 @@ class PulsarPointsRenderer(nn.Module): "Focal length not parsable: %s." % (str(focal_length_conf)) ) focal_length_px = focal_length_conf - focal_length = znear - 1e-6 + focal_length = torch.tensor( + [ + znear - 1e-6, + ], + dtype=torch.float32, + device=focal_length_px.device, + ) sensor_width = focal_length / focal_length_px * 2.0 principal_point_x = ( # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`. @@ -321,6 +346,12 @@ class PulsarPointsRenderer(nn.Module): * 0.5 * self.renderer._renderer.height ) + focal_length = _ensure_float_tensor(focal_length, device) + sensor_width = _ensure_float_tensor(sensor_width, device) + principal_point_x = _ensure_float_tensor(principal_point_x, device) + principal_point_y = _ensure_float_tensor(principal_point_y, device) + znear = _ensure_float_tensor(znear, device) + zfar = _ensure_float_tensor(zfar, device) return ( focal_length, sensor_width, @@ -338,11 +369,17 @@ class PulsarPointsRenderer(nn.Module): R = kwargs.get("R", cameras.R)[cloud_idx] T = kwargs.get("T", cameras.T)[cloud_idx] norm_mat = torch.tensor( - [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]], + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]], dtype=torch.float32, device=R.device, ) - cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...]) + cam_rot = torch.matmul(norm_mat, R[:3, :3][None, ...]).permute((0, 2, 1)) + norm_mat = torch.tensor( + [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + dtype=torch.float32, + device=R.device, + ) + cam_rot = torch.matmul(norm_mat, cam_rot) cam_pos = torch.flatten(torch.matmul(cam_rot, T[..., None])) cam_rot = torch.flatten(matrix_to_rotation_6d(cam_rot)) return cam_pos, cam_rot @@ -374,7 +411,7 @@ class PulsarPointsRenderer(nn.Module): ) else: point_dists = torch.norm((vert_pos - cam_pos), p=2, dim=1, keepdim=False) - vert_rad = raster_rad / focal_length * point_dists + vert_rad = raster_rad / focal_length.to(vert_pos.device) * point_dists if isinstance(self.rasterizer.cameras, PerspectiveCameras): # NDC normalization happens through adjusted focal length. pass @@ -382,6 +419,7 @@ class PulsarPointsRenderer(nn.Module): vert_rad = vert_rad / 2.0 # NDC normalization. return vert_rad + # point_clouds is not typed to avoid a cyclic dependency. def forward(self, point_clouds, **kwargs) -> torch.Tensor: """ Get the rendering of the provided `Pointclouds`. @@ -439,6 +477,8 @@ class PulsarPointsRenderer(nn.Module): for cloud_idx, (vert_pos, vert_col) in enumerate( zip(position_list, features_list) ): + # Get extrinsics. + cam_pos, cam_rot = self._extract_extrinsics(kwargs, cloud_idx) # Get intrinsics. ( focal_length, @@ -447,23 +487,21 @@ class PulsarPointsRenderer(nn.Module): principal_point_y, znear, zfar, - ) = self._extract_intrinsics(orthogonal_projection, kwargs, cloud_idx) - # Get extrinsics. - cam_pos, cam_rot = self._extract_extrinsics(kwargs, cloud_idx) + ) = self._extract_intrinsics( + orthogonal_projection, kwargs, cloud_idx, cam_pos.device + ) # Put everything together. cam_params = torch.cat( ( cam_pos, - cam_rot, - torch.tensor( + cam_rot.to(cam_pos.device), + torch.cat( [ focal_length, sensor_width, principal_point_x, principal_point_y, ], - dtype=torch.float32, - device=cam_pos.device, ), ) ) diff --git a/setup.py b/setup.py index cf041287..3bfcd280 100755 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ import glob import os import runpy +import warnings import torch from setuptools import find_packages, setup @@ -26,22 +27,25 @@ def get_extensions(): sources += source_cuda define_macros += [("WITH_CUDA", None)] cub_home = os.environ.get("CUB_HOME", None) - if cub_home is None: - raise Exception( - "The environment variable `CUB_HOME` was not found. " - "NVIDIA CUB is required for compilation and can be downloaded " - "from `https://github.com/NVIDIA/cub/releases`. You can unpack " - "it to a location of your choice and set the environment variable " - "`CUB_HOME` to the folder containing the `CMakeListst.txt` file." - ) nvcc_args = [ - "-I%s" % (os.path.realpath(cub_home).replace("\\ ", " ")), "-std=c++14", "-DCUDA_HAS_FP16=1", "-D__CUDA_NO_HALF_OPERATORS__", "-D__CUDA_NO_HALF_CONVERSIONS__", "-D__CUDA_NO_HALF2_OPERATORS__", ] + if cub_home is None: + warnings.warn( + "The environment variable `CUB_HOME` was not found. " + "NVIDIA CUB is required for compilation and can be downloaded " + "from `https://github.com/NVIDIA/cub/releases`. You can unpack " + "it to a location of your choice and set the environment variable " + "`CUB_HOME` to the folder containing the `CMakeListst.txt` file." + ) + else: + nvcc_args.insert( + 0, "-I%s" % (os.path.realpath(cub_home).replace("\\ ", " ")) + ) nvcc_flags_env = os.getenv("NVCC_FLAGS", "") if nvcc_flags_env != "": nvcc_args.extend(nvcc_flags_env.split(" "))