diff --git a/docs/examples/pulsar_basic.py b/docs/examples/pulsar_basic.py new file mode 100755 index 00000000..62dccd03 --- /dev/null +++ b/docs/examples/pulsar_basic.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +""" +This example demonstrates the most trivial, direct interface of the pulsar +sphere renderer. It renders and saves an image with 10 random spheres. +Output: basic.png. +""" +from os import path + +import imageio +import torch +from pytorch3d.renderer.points.pulsar import Renderer + + +n_points = 10 +width = 1_000 +height = 1_000 +device = torch.device("cuda") +renderer = Renderer(width, height, n_points).to(device) +# Generate sample data. +vert_pos = torch.rand(n_points, 3, dtype=torch.float32, device=device) * 10.0 +vert_pos[:, 2] += 25.0 +vert_pos[:, :2] -= 5.0 +vert_col = torch.rand(n_points, 3, dtype=torch.float32, device=device) +vert_rad = torch.rand(n_points, dtype=torch.float32, device=device) +cam_params = torch.tensor( + [ + 0.0, + 0.0, + 0.0, # Position 0, 0, 0 (x, y, z). + 0.0, + 0.0, + 0.0, # Rotation 0, 0, 0 (in axis-angle format). + 5.0, # Focal length in world size. + 2.0, # Sensor size in world size. + ], + dtype=torch.float32, + device=device, +) +# Render. +image = renderer( + vert_pos, + vert_col, + vert_rad, + cam_params, + 1.0e-1, # Renderer blending parameter gamma, in [1., 1e-5]. + 45.0, # Maximum depth. +) +print("Writing image to `%s`." % (path.abspath("basic.png"))) +imageio.imsave("basic.png", (image.cpu().detach() * 255.0).to(torch.uint8).numpy()) diff --git a/docs/examples/pulsar_cam.py b/docs/examples/pulsar_cam.py new file mode 100755 index 00000000..dcc08759 --- /dev/null +++ b/docs/examples/pulsar_cam.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +""" +This example demonstrates camera parameter optimization with the plain +pulsar interface. For this, a reference image has been pre-generated +(you can find it at `../../tests/pulsar/reference/examples_TestRenderer_test_cam.png`). +The same scene parameterization is loaded and the camera parameters +distorted. Gradient-based optimization is used to converge towards the +original camera parameters. +""" +from os import path + +import cv2 +import imageio +import numpy as np +import torch +from pytorch3d.renderer.points.pulsar import Renderer +from torch import nn, optim + + +n_points = 20 +width = 1_000 +height = 1_000 +device = torch.device("cuda") + + +class SceneModel(nn.Module): + """ + A simple scene model to demonstrate use of pulsar in PyTorch modules. + + The scene model is parameterized with sphere locations (vert_pos), + channel content (vert_col), radiuses (vert_rad), camera position (cam_pos), + camera rotation (cam_rot) and sensor focal length and width (cam_sensor). + + The forward method of the model renders this scene description. Any + of these parameters could instead be passed as inputs to the forward + method and come from a different model. + """ + + def __init__(self): + super(SceneModel, self).__init__() + self.gamma = 0.1 + # Points. + torch.manual_seed(1) + vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0 + vert_pos[:, 2] += 25.0 + vert_pos[:, :2] -= 5.0 + self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=False)) + self.register_parameter( + "vert_col", + nn.Parameter( + torch.rand(n_points, 3, dtype=torch.float32), requires_grad=False + ), + ) + self.register_parameter( + "vert_rad", + nn.Parameter( + torch.rand(n_points, dtype=torch.float32), requires_grad=False + ), + ) + self.register_parameter( + "cam_pos", + nn.Parameter( + torch.tensor([0.1, 0.1, 0.0], dtype=torch.float32), requires_grad=True + ), + ) + self.register_parameter( + "cam_rot", + nn.Parameter( + torch.tensor( + [ + # We're using the 6D rot. representation for better gradients. + 0.9995, + 0.0300445, + -0.0098482, + -0.0299445, + 0.9995, + 0.0101482, + ], + dtype=torch.float32, + ), + requires_grad=True, + ), + ) + self.register_parameter( + "cam_sensor", + nn.Parameter( + torch.tensor([4.8, 1.8], dtype=torch.float32), requires_grad=True + ), + ) + self.renderer = Renderer(width, height, n_points) + + def forward(self): + return self.renderer.forward( + self.vert_pos, + self.vert_col, + self.vert_rad, + torch.cat([self.cam_pos, self.cam_rot, self.cam_sensor]), + self.gamma, + 45.0, + ) + + +# Load reference. +ref = ( + torch.from_numpy( + imageio.imread( + "../../tests/pulsar/reference/examples_TestRenderer_test_cam.png" + ) + ).to(torch.float32) + / 255.0 +).to(device) +# Set up model. +model = SceneModel().to(device) +# Optimizer. +optimizer = optim.SGD( + [ + {"params": [model.cam_pos], "lr": 1e-4}, # 1e-3 + {"params": [model.cam_rot], "lr": 5e-6}, + {"params": [model.cam_sensor], "lr": 1e-4}, + ] +) + +print("Writing video to `%s`." % (path.abspath("cam.gif"))) +writer = imageio.get_writer("cam.gif", format="gif", fps=25) + +# Optimize. +for i in range(300): + optimizer.zero_grad() + result = model() + # Visualize. + result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8) + cv2.imshow("opt", result_im[:, :, ::-1]) + writer.append_data(result_im) + overlay_img = np.ascontiguousarray( + ((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[ + :, :, ::-1 + ] + ) + overlay_img = cv2.putText( + overlay_img, + "Step %d" % (i), + (10, 40), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (0, 0, 0), + 2, + cv2.LINE_AA, + False, + ) + cv2.imshow("overlay", overlay_img) + cv2.waitKey(1) + # Update. + loss = ((result - ref) ** 2).sum() + print("loss {}: {}".format(i, loss.item())) + loss.backward() + optimizer.step() +writer.close() diff --git a/docs/examples/pulsar_multiview.py b/docs/examples/pulsar_multiview.py new file mode 100755 index 00000000..4be9af72 --- /dev/null +++ b/docs/examples/pulsar_multiview.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +""" +This example demonstrates multiview 3D reconstruction using the plain +pulsar interface. For this, reference images have been pre-generated +(you can find them at `../../tests/pulsar/reference/examples_TestRenderer_test_multiview_%d.png`). +The camera parameters are assumed given. The scene is initialized with +random spheres. Gradient-based optimization is used to optimize sphere +parameters and prune spheres to converge to a 3D representation. +""" +from os import path + +import cv2 +import imageio +import numpy as np +import torch +from pytorch3d.renderer.points.pulsar import Renderer +from torch import nn, optim + + +n_points = 400_000 +width = 1_000 +height = 1_000 +visualize_ids = [0, 1] +device = torch.device("cuda") + + +class SceneModel(nn.Module): + """ + A simple scene model to demonstrate use of pulsar in PyTorch modules. + + The scene model is parameterized with sphere locations (vert_pos), + channel content (vert_col), radiuses (vert_rad), camera position (cam_pos), + camera rotation (cam_rot) and sensor focal length and width (cam_sensor). + + The forward method of the model renders this scene description. Any + of these parameters could instead be passed as inputs to the forward + method and come from a different model. Optionally, camera parameters can + be provided to the forward method in which case the scene is rendered + using those parameters. + """ + + def __init__(self): + super(SceneModel, self).__init__() + self.gamma = 1.0 + # Points. + torch.manual_seed(1) + vert_pos = torch.rand((1, n_points, 3), dtype=torch.float32) * 10.0 + vert_pos[:, :, 2] += 25.0 + vert_pos[:, :, :2] -= 5.0 + self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True)) + self.register_parameter( + "vert_col", + nn.Parameter( + torch.ones(1, n_points, 3, dtype=torch.float32) * 0.5, + requires_grad=True, + ), + ) + self.register_parameter( + "vert_rad", + nn.Parameter( + torch.ones(1, n_points, dtype=torch.float32) * 0.05, requires_grad=True + ), + ) + self.register_parameter( + "vert_opy", + nn.Parameter( + torch.ones(1, n_points, dtype=torch.float32), requires_grad=True + ), + ) + self.register_buffer( + "cam_params", + torch.tensor( + [ + [ + np.sin(angle) * 35.0, + 0.0, + 30.0 - np.cos(angle) * 35.0, + 0.0, + -angle, + 0.0, + 5.0, + 2.0, + ] + for angle in [-1.5, -0.8, -0.4, -0.1, 0.1, 0.4, 0.8, 1.5] + ], + dtype=torch.float32, + ), + ) + self.renderer = Renderer(width, height, n_points) + + def forward(self, cam=None): + if cam is None: + cam = self.cam_params + n_views = 8 + else: + n_views = 1 + return self.renderer.forward( + self.vert_pos.expand(n_views, -1, -1), + self.vert_col.expand(n_views, -1, -1), + self.vert_rad.expand(n_views, -1), + cam, + self.gamma, + 45.0, + ) + + +# Load reference. +ref = torch.stack( + [ + torch.from_numpy( + imageio.imread( + "../../tests/pulsar/reference/examples_TestRenderer_test_multiview_%d.png" + % idx + ) + ).to(torch.float32) + / 255.0 + for idx in range(8) + ] +).to(device) +# Set up model. +model = SceneModel().to(device) +# Optimizer. +optimizer = optim.SGD( + [ + {"params": [model.vert_col], "lr": 1e-1}, + {"params": [model.vert_rad], "lr": 1e-3}, + {"params": [model.vert_pos], "lr": 1e-3}, + ] +) + +# For visualization. +angle = 0.0 +print("Writing video to `%s`." % (path.abspath("multiview.avi"))) +writer = imageio.get_writer("multiview.gif", format="gif", fps=25) + +# Optimize. +for i in range(300): + optimizer.zero_grad() + result = model() + # Visualize. + result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8) + cv2.imshow("opt", result_im[0, :, :, ::-1]) + overlay_img = np.ascontiguousarray( + ((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[ + 0, :, :, ::-1 + ] + ) + overlay_img = cv2.putText( + overlay_img, + "Step %d" % (i), + (10, 40), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (0, 0, 0), + 2, + cv2.LINE_AA, + False, + ) + cv2.imshow("overlay", overlay_img) + cv2.waitKey(1) + # Update. + loss = ((result - ref) ** 2).sum() + print("loss {}: {}".format(i, loss.item())) + loss.backward() + optimizer.step() + # Cleanup. + with torch.no_grad(): + model.vert_col.data = torch.clamp(model.vert_col.data, 0.0, 1.0) + # Remove points. + model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0 + model.vert_rad.data[model.vert_rad < 0.001] = 0.0001 + vd = ( + (model.vert_col - torch.ones(1, 1, 3, dtype=torch.float32).to(device)) + .abs() + .sum(dim=2) + ) + model.vert_pos.data[vd <= 0.2] = -1000.0 + # Rotating visualization. + cam_control = torch.tensor( + [ + [ + np.sin(angle) * 35.0, + 0.0, + 30.0 - np.cos(angle) * 35.0, + 0.0, + -angle, + 0.0, + 5.0, + 2.0, + ] + ], + dtype=torch.float32, + ).to(device) + with torch.no_grad(): + result = model.forward(cam=cam_control)[0] + result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8) + cv2.imshow("vis", result_im[:, :, ::-1]) + writer.append_data(result_im) + angle += 0.05 +writer.close() diff --git a/docs/examples/pulsar_optimization.py b/docs/examples/pulsar_optimization.py new file mode 100755 index 00000000..67b2f81b --- /dev/null +++ b/docs/examples/pulsar_optimization.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +""" +This example demonstrates scene optimization with the plain +pulsar interface. For this, a reference image has been pre-generated +(you can find it at `../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png`). +The scene is initialized with random spheres. Gradient-based +optimization is used to converge towards a faithful +scene representation. +""" +import cv2 +import imageio +import numpy as np +import torch +from pytorch3d.renderer.points.pulsar import Renderer +from torch import nn, optim + + +n_points = 10_000 +width = 1_000 +height = 1_000 +device = torch.device("cuda") + + +class SceneModel(nn.Module): + """ + A simple scene model to demonstrate use of pulsar in PyTorch modules. + + The scene model is parameterized with sphere locations (vert_pos), + channel content (vert_col), radiuses (vert_rad), camera position (cam_pos), + camera rotation (cam_rot) and sensor focal length and width (cam_sensor). + + The forward method of the model renders this scene description. Any + of these parameters could instead be passed as inputs to the forward + method and come from a different model. + """ + + def __init__(self): + super(SceneModel, self).__init__() + self.gamma = 1.0 + # Points. + torch.manual_seed(1) + vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0 + vert_pos[:, 2] += 25.0 + vert_pos[:, :2] -= 5.0 + self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True)) + self.register_parameter( + "vert_col", + nn.Parameter( + torch.ones(n_points, 3, dtype=torch.float32) * 0.5, requires_grad=True + ), + ) + self.register_parameter( + "vert_rad", + nn.Parameter( + torch.ones(n_points, dtype=torch.float32) * 0.3, requires_grad=True + ), + ) + self.register_buffer( + "cam_params", + torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32), + ) + # The volumetric optimization works better with a higher number of tracked + # intersections per ray. + self.renderer = Renderer(width, height, n_points, n_track=32) + + def forward(self): + return self.renderer.forward( + self.vert_pos, + self.vert_col, + self.vert_rad, + self.cam_params, + self.gamma, + 45.0, + return_forward_info=True, + ) + + +# Load reference. +ref = ( + torch.from_numpy( + imageio.imread( + "../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png" + ) + ).to(torch.float32) + / 255.0 +).to(device) +# Set up model. +model = SceneModel().to(device) +# Optimizer. +optimizer = optim.SGD( + [ + {"params": [model.vert_col], "lr": 1e0}, + {"params": [model.vert_rad], "lr": 5e-3}, + {"params": [model.vert_pos], "lr": 1e-2}, + ] +) + +# Optimize. +for i in range(500): + optimizer.zero_grad() + result, result_info = model() + # Visualize. + result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8) + cv2.imshow("opt", result_im[:, :, ::-1]) + overlay_img = np.ascontiguousarray( + ((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[ + :, :, ::-1 + ] + ) + overlay_img = cv2.putText( + overlay_img, + "Step %d" % (i), + (10, 40), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (0, 0, 0), + 2, + cv2.LINE_AA, + False, + ) + cv2.imshow("overlay", overlay_img) + cv2.waitKey(1) + # Update. + loss = ((result - ref) ** 2).sum() + print("loss {}: {}".format(i, loss.item())) + loss.backward() + optimizer.step() + # Cleanup. + with torch.no_grad(): + model.vert_col.data = torch.clamp(model.vert_col.data, 0.0, 1.0) + # Remove points. + model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0 + model.vert_rad.data[model.vert_rad < 0.001] = 0.0001 + vd = ( + (model.vert_col - torch.ones(3, dtype=torch.float32).to(device)) + .abs() + .sum(dim=1) + ) + model.vert_pos.data[vd <= 0.2] = -1000.0 diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index c80800e3..5fae7948 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -1,6 +1,11 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +// clang-format off +#include "./pulsar/global.h" // Include before . #include +// clang-format on +#include "./pulsar/pytorch/renderer.h" +#include "./pulsar/pytorch/tensor_util.h" #include "blending/sigmoid_alpha_blend.h" #include "compositing/alpha_composite.h" #include "compositing/norm_weighted_sum.h" @@ -65,4 +70,90 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("face_point_dist_backward", &FacePointDistanceBackward); m.def("point_face_array_dist_forward", &PointFaceArrayDistanceForward); m.def("point_face_array_dist_backward", &PointFaceArrayDistanceBackward); + + // Pulsar. +#ifdef PULSAR_LOGGING_ENABLED + c10::ShowLogInfoToStderr(); +#endif + py::class_< + pulsar::pytorch::Renderer, + std::shared_ptr>(m, "PulsarRenderer") + .def(py::init< + const uint&, + const uint&, + const uint&, + const bool&, + const bool&, + const float&, + const uint&, + const uint&>()) + .def( + "__eq__", + [](const pulsar::pytorch::Renderer& a, + const pulsar::pytorch::Renderer& b) { return a == b; }, + py::is_operator()) + .def( + "__ne__", + [](const pulsar::pytorch::Renderer& a, + const pulsar::pytorch::Renderer& b) { return !(a == b); }, + py::is_operator()) + .def( + "__repr__", + [](const pulsar::pytorch::Renderer& self) { + std::stringstream ss; + ss << self; + return ss.str(); + }) + .def( + "forward", + &pulsar::pytorch::Renderer::forward, + py::arg("vert_pos"), + py::arg("vert_col"), + py::arg("vert_radii"), + + py::arg("cam_pos"), + py::arg("pixel_0_0_center"), + py::arg("pixel_vec_x"), + py::arg("pixel_vec_y"), + py::arg("focal_length"), + py::arg("principal_point_offsets"), + + py::arg("gamma"), + py::arg("max_depth"), + py::arg("min_depth") /* = 0.f*/, + py::arg( + "bg_col") /* = at::nullopt not exposed properly in pytorch 1.1. */ + , + py::arg("opacity") /* = at::nullopt ... */, + py::arg("percent_allowed_difference") = 0.01f, + py::arg("max_n_hits") = MAX_UINT, + py::arg("mode") = 0) + .def("backward", &pulsar::pytorch::Renderer::backward) + .def_property( + "device_tracker", + [](const pulsar::pytorch::Renderer& self) { + return self.device_tracker; + }, + [](pulsar::pytorch::Renderer& self, const torch::Tensor& val) { + self.device_tracker = val; + }) + .def_property_readonly("width", &pulsar::pytorch::Renderer::width) + .def_property_readonly("height", &pulsar::pytorch::Renderer::height) + .def_property_readonly( + "max_num_balls", &pulsar::pytorch::Renderer::max_num_balls) + .def_property_readonly( + "orthogonal", &pulsar::pytorch::Renderer::orthogonal) + .def_property_readonly( + "right_handed", &pulsar::pytorch::Renderer::right_handed) + .def_property_readonly("n_track", &pulsar::pytorch::Renderer::n_track); + m.def( + "pulsar_sphere_ids_from_result_info_nograd", + &pulsar::pytorch::sphere_ids_from_result_info_nograd); + // Constants. + m.attr("EPS") = py::float_(EPS); + m.attr("MAX_FLOAT") = py::float_(MAX_FLOAT); + m.attr("MAX_INT") = py::int_(MAX_INT); + m.attr("MAX_UINT") = py::int_(MAX_UINT); + m.attr("MAX_USHORT") = py::int_(MAX_USHORT); + m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES); } diff --git a/pytorch3d/csrc/pulsar/constants.h b/pytorch3d/csrc/pulsar/constants.h new file mode 100644 index 00000000..787f0456 --- /dev/null +++ b/pytorch3d/csrc/pulsar/constants.h @@ -0,0 +1,12 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_CONSTANTS_H_ +#define PULSAR_NATIVE_CONSTANTS_H_ + +#define EPS 1E-6 +#define FEPS 1E-6f +#define MAX_FLOAT 3.4E38f +#define MAX_INT 2147483647 +#define MAX_UINT 4294967295u +#define MAX_USHORT 65535u + +#endif diff --git a/pytorch3d/csrc/pulsar/cuda/README.md b/pytorch3d/csrc/pulsar/cuda/README.md new file mode 100644 index 00000000..60c5d07c --- /dev/null +++ b/pytorch3d/csrc/pulsar/cuda/README.md @@ -0,0 +1,5 @@ +# CUDA device compilation units + +This folder contains `.cu` files to create compilation units +for device-specific functions. See `../include/README.md` for +more information. diff --git a/pytorch3d/csrc/pulsar/cuda/commands.h b/pytorch3d/csrc/pulsar/cuda/commands.h new file mode 100644 index 00000000..fa966e8b --- /dev/null +++ b/pytorch3d/csrc/pulsar/cuda/commands.h @@ -0,0 +1,501 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_CUDA_COMMANDS_H_ +#define PULSAR_NATIVE_CUDA_COMMANDS_H_ + +// Definitions for GPU commands. +#include +#include +namespace cg = cooperative_groups; + +#ifdef __DRIVER_TYPES_H__ +#ifndef DEVICE_RESET +#define DEVICE_RESET cudaDeviceReset(); +#endif +#else +#ifndef DEVICE_RESET +#define DEVICE_RESET +#endif +#endif + +#define HANDLECUDA(CMD) CMD +// handleCudaError((CMD), __FILE__, __LINE__) +inline void +handleCudaError(const cudaError_t err, const char* file, const int line) { + if (err != cudaSuccess) { +#ifndef __NVCC__ + fprintf( + stderr, + "%s(%i) : getLastCudaError() CUDA error :" + " (%d) %s.\n", + file, + line, + static_cast(err), + cudaGetErrorString(err)); + DEVICE_RESET + exit(1); +#endif + } +} +inline void +getLastCudaError(const char* errorMessage, const char* file, const int line) { + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) { + fprintf(stderr, "Error: %s.", errorMessage); + handleCudaError(err, file, line); + } +} + +#define ALIGN(VAL) __align__(VAL) +#define SYNC() HANDLECUDE(cudaDeviceSynchronize()) +#define THREADFENCE_B() __threadfence_block() +#define SHFL_SYNC(a, b, c) __shfl_sync((a), (b), (c)) +#define SHARED __shared__ +#define ACTIVEMASK() __activemask() +#define BALLOT(mask, val) __ballot_sync((mask), val) +/** + * Find the cumulative sum within a warp up to the current + * thread lane, with each mask thread contributing base. + */ +template +DEVICE T +WARP_CUMSUM(const cg::coalesced_group& group, const uint& mask, const T& base) { + T ret = base; + T shfl_val; + shfl_val = __shfl_down_sync(mask, ret, 1u); // Deactivate the rightmost lane. + ret += (group.thread_rank() < 31) * shfl_val; + shfl_val = __shfl_down_sync(mask, ret, 2u); + ret += (group.thread_rank() < 30) * shfl_val; + shfl_val = __shfl_down_sync(mask, ret, 4u); // ...4 + ret += (group.thread_rank() < 28) * shfl_val; + shfl_val = __shfl_down_sync(mask, ret, 8u); // ...8 + ret += (group.thread_rank() < 24) * shfl_val; + shfl_val = __shfl_down_sync(mask, ret, 16u); // ...16 + ret += (group.thread_rank() < 16) * shfl_val; + return ret; +} + +template +DEVICE T +WARP_MAX(const cg::coalesced_group& group, const uint& mask, const T& base) { + T ret = base; + ret = max(ret, __shfl_down_sync(mask, ret, 16u)); + ret = max(ret, __shfl_down_sync(mask, ret, 8u)); + ret = max(ret, __shfl_down_sync(mask, ret, 4u)); + ret = max(ret, __shfl_down_sync(mask, ret, 2u)); + ret = max(ret, __shfl_down_sync(mask, ret, 1u)); + return ret; +} + +template +DEVICE T +WARP_SUM(const cg::coalesced_group& group, const uint& mask, const T& base) { + T ret = base; + ret = ret + __shfl_down_sync(mask, ret, 16u); + ret = ret + __shfl_down_sync(mask, ret, 8u); + ret = ret + __shfl_down_sync(mask, ret, 4u); + ret = ret + __shfl_down_sync(mask, ret, 2u); + ret = ret + __shfl_down_sync(mask, ret, 1u); + return ret; +} + +INLINE DEVICE float3 WARP_SUM_FLOAT3( + const cg::coalesced_group& group, + const uint& mask, + const float3& base) { + float3 ret = base; + ret.x = WARP_SUM(group, mask, base.x); + ret.y = WARP_SUM(group, mask, base.y); + ret.z = WARP_SUM(group, mask, base.z); + return ret; +} + +// Floating point. +// #define FMUL(a, b) __fmul_rn((a), (b)) +#define FMUL(a, b) ((a) * (b)) +#define FDIV(a, b) __fdiv_rn((a), (b)) +// #define FSUB(a, b) __fsub_rn((a), (b)) +#define FSUB(a, b) ((a) - (b)) +#define FADD(a, b) __fadd_rn((a), (b)) +#define FSQRT(a) __fsqrt_rn(a) +#define FEXP(a) fasterexp(a) +#define FLN(a) fasterlog(a) +#define FPOW(a, b) __powf((a), (b)) +#define FMAX(a, b) fmax((a), (b)) +#define FMIN(a, b) fmin((a), (b)) +#define FCEIL(a) ceilf(a) +#define FFLOOR(a) floorf(a) +#define FROUND(x) nearbyintf(x) +#define FSATURATE(x) __saturatef(x) +#define FABS(a) abs(a) +#define IASF(a, loc) (loc) = __int_as_float(a) +#define FASI(a, loc) (loc) = __float_as_int(a) +#define FABSLEQAS(a, b, c) \ + ((a) <= (b) ? FSUB((b), (a)) <= (c) : FSUB((a), (b)) < (c)) +/** Calculates x*y+z. */ +#define FMA(x, y, z) __fmaf_rn((x), (y), (z)) +#define I2F(a) __int2float_rn(a) +#define FRCP(x) __frcp_rn(x) +__device__ static float atomicMax(float* address, float val) { + int* address_as_i = (int*)address; + int old = *address_as_i, assumed; + do { + assumed = old; + old = ::atomicCAS( + address_as_i, + assumed, + __float_as_int(::fmaxf(val, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +} +__device__ static float atomicMin(float* address, float val) { + int* address_as_i = (int*)address; + int old = *address_as_i, assumed; + do { + assumed = old; + old = ::atomicCAS( + address_as_i, + assumed, + __float_as_int(::fminf(val, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +} +#define DMAX(a, b) FMAX(a, b) +#define DMIN(a, b) FMIN(a, b) +#define DSQRT(a) sqrt(a) +#define DSATURATE(a) DMIN(1., DMAX(0., (a))) +// half +#define HADD(a, b) __hadd((a), (b)) +#define HSUB2(a, b) __hsub2((a), (b)) +#define HMUL2(a, b) __hmul2((a), (b)) +#define HSQRT(a) hsqrt(a) + +// uint. +#define CLZ(VAL) __clz(VAL) +#define POPC(a) __popc(a) +// +// +// +// +// +// +// +// +// +#define ATOMICADD(PTR, VAL) atomicAdd((PTR), (VAL)) +#define ATOMICADD_F3(PTR, VAL) \ + ATOMICADD(&((PTR)->x), VAL.x); \ + ATOMICADD(&((PTR)->y), VAL.y); \ + ATOMICADD(&((PTR)->z), VAL.z); +#if (CUDART_VERSION >= 10000) +#define ATOMICADD_B(PTR, VAL) atomicAdd_block((PTR), (VAL)) +#else +#define ATOMICADD_B(PTR, VAL) ATOMICADD(PTR, VAL) +#endif +// +// +// +// +// int. +#define IMIN(a, b) min((a), (b)) +#define IMAX(a, b) max((a), (b)) +#define IABS(a) abs(a) + +// Checks. +#define CHECKOK THCudaCheck +#define ARGCHECK THArgCheck + +// Math. +#define NORM3DF(x, y, z) norm3df(x, y, z) +#define RNORM3DF(x, y, z) rnorm3df(x, y, z) + +// High level. +INLINE DEVICE void prefetch_l1(unsigned long addr) { + asm(" prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr)); +} +#define PREFETCH(PTR) prefetch_l1((unsigned long)(PTR)) +#define GET_SORT_WS_SIZE(RES_PTR, KEY_TYPE, VAL_TYPE, NUM_OBJECTS) \ + cub::DeviceRadixSort::SortPairsDescending( \ + (void*)NULL, \ + *(RES_PTR), \ + reinterpret_cast(NULL), \ + reinterpret_cast(NULL), \ + reinterpret_cast(NULL), \ + reinterpret_cast(NULL), \ + (NUM_OBJECTS)); +#define GET_REDUCE_WS_SIZE(RES_PTR, TYPE, REDUCE_OP, NUM_OBJECTS) \ + { \ + TYPE init = TYPE(); \ + cub::DeviceReduce::Reduce( \ + (void*)NULL, \ + *(RES_PTR), \ + (TYPE*)NULL, \ + (TYPE*)NULL, \ + (NUM_OBJECTS), \ + (REDUCE_OP), \ + init); \ + } +#define GET_SELECT_WS_SIZE( \ + RES_PTR, TYPE_SELECTOR, TYPE_SELECTION, NUM_OBJECTS) \ + { \ + cub::DeviceSelect::Flagged( \ + (void*)NULL, \ + *(RES_PTR), \ + (TYPE_SELECTION*)NULL, \ + (TYPE_SELECTOR*)NULL, \ + (TYPE_SELECTION*)NULL, \ + (int*)NULL, \ + (NUM_OBJECTS)); \ + } +#define GET_SUM_WS_SIZE(RES_PTR, TYPE_SUM, NUM_OBJECTS) \ + { \ + cub::DeviceReduce::Sum( \ + (void*)NULL, \ + *(RES_PTR), \ + (TYPE_SUM*)NULL, \ + (TYPE_SUM*)NULL, \ + NUM_OBJECTS); \ + } +#define GET_MM_WS_SIZE(RES_PTR, TYPE, NUM_OBJECTS) \ + { \ + TYPE init = TYPE(); \ + cub::DeviceReduce::Max( \ + (void*)NULL, *(RES_PTR), (TYPE*)NULL, (TYPE*)NULL, (NUM_OBJECTS)); \ + } +#define SORT_DESCENDING( \ + TMPN1, SORT_PTR, SORTED_PTR, VAL_PTR, VAL_SORTED_PTR, NUM_OBJECTS) \ + void* TMPN1 = NULL; \ + size_t TMPN1##_bytes = 0; \ + cub::DeviceRadixSort::SortPairsDescending( \ + TMPN1, \ + TMPN1##_bytes, \ + (SORT_PTR), \ + (SORTED_PTR), \ + (VAL_PTR), \ + (VAL_SORTED_PTR), \ + (NUM_OBJECTS)); \ + HANDLECUDA(cudaMalloc(&TMPN1, TMPN1##_bytes)); \ + cub::DeviceRadixSort::SortPairsDescending( \ + TMPN1, \ + TMPN1##_bytes, \ + (SORT_PTR), \ + (SORTED_PTR), \ + (VAL_PTR), \ + (VAL_SORTED_PTR), \ + (NUM_OBJECTS)); \ + HANDLECUDA(cudaFree(TMPN1)); +#define SORT_DESCENDING_WS( \ + TMPN1, \ + SORT_PTR, \ + SORTED_PTR, \ + VAL_PTR, \ + VAL_SORTED_PTR, \ + NUM_OBJECTS, \ + WORKSPACE_PTR, \ + WORKSPACE_BYTES) \ + cub::DeviceRadixSort::SortPairsDescending( \ + (WORKSPACE_PTR), \ + (WORKSPACE_BYTES), \ + (SORT_PTR), \ + (SORTED_PTR), \ + (VAL_PTR), \ + (VAL_SORTED_PTR), \ + (NUM_OBJECTS)); +#define SORT_ASCENDING_WS( \ + SORT_PTR, \ + SORTED_PTR, \ + VAL_PTR, \ + VAL_SORTED_PTR, \ + NUM_OBJECTS, \ + WORKSPACE_PTR, \ + WORKSPACE_BYTES, \ + STREAM) \ + cub::DeviceRadixSort::SortPairs( \ + (WORKSPACE_PTR), \ + (WORKSPACE_BYTES), \ + (SORT_PTR), \ + (SORTED_PTR), \ + (VAL_PTR), \ + (VAL_SORTED_PTR), \ + (NUM_OBJECTS), \ + 0, \ + sizeof(*(SORT_PTR)) * 8, \ + (STREAM)); +#define SUM_WS( \ + SUM_PTR, OUT_PTR, NUM_OBJECTS, WORKSPACE_PTR, WORKSPACE_BYTES, STREAM) \ + cub::DeviceReduce::Sum( \ + (WORKSPACE_PTR), \ + (WORKSPACE_BYTES), \ + (SUM_PTR), \ + (OUT_PTR), \ + (NUM_OBJECTS), \ + (STREAM)); +#define MIN_WS( \ + MIN_PTR, OUT_PTR, NUM_OBJECTS, WORKSPACE_PTR, WORKSPACE_BYTES, STREAM) \ + cub::DeviceReduce::Min( \ + (WORKSPACE_PTR), \ + (WORKSPACE_BYTES), \ + (MIN_PTR), \ + (OUT_PTR), \ + (NUM_OBJECTS), \ + (STREAM)); +#define MAX_WS( \ + MAX_PTR, OUT_PTR, NUM_OBJECTS, WORKSPACE_PTR, WORKSPACE_BYTES, STREAM) \ + cub::DeviceReduce::Min( \ + (WORKSPACE_PTR), \ + (WORKSPACE_BYTES), \ + (MAX_PTR), \ + (OUT_PTR), \ + (NUM_OBJECTS), \ + (STREAM)); +// +// +// +// TODO: rewrite using nested contexts instead of temporary names. +#define REDUCE(REDUCE_PTR, RESULT_PTR, NUM_ITEMS, REDUCE_OP, REDUCE_INIT) \ + cub::DeviceReduce::Reduce( \ + TMPN1, \ + TMPN1##_bytes, \ + (REDUCE_PTR), \ + (RESULT_PTR), \ + (NUM_ITEMS), \ + (REDUCE_OP), \ + (REDUCE_INIT)); \ + HANDLECUDA(cudaMalloc(&TMPN1, TMPN1##_bytes)); \ + cub::DeviceReduce::Reduce( \ + TMPN1, \ + TMPN1##_bytes, \ + (REDUCE_PTR), \ + (RESULT_PTR), \ + (NUM_ITEMS), \ + (REDUCE_OP), \ + (REDUCE_INIT)); \ + HANDLECUDA(cudaFree(TMPN1)); +#define REDUCE_WS( \ + REDUCE_PTR, \ + RESULT_PTR, \ + NUM_ITEMS, \ + REDUCE_OP, \ + REDUCE_INIT, \ + WORKSPACE_PTR, \ + WORSPACE_BYTES, \ + STREAM) \ + cub::DeviceReduce::Reduce( \ + (WORKSPACE_PTR), \ + (WORSPACE_BYTES), \ + (REDUCE_PTR), \ + (RESULT_PTR), \ + (NUM_ITEMS), \ + (REDUCE_OP), \ + (REDUCE_INIT), \ + (STREAM)); +#define SELECT_FLAGS_WS( \ + FLAGS_PTR, \ + ITEM_PTR, \ + OUT_PTR, \ + NUM_SELECTED_PTR, \ + NUM_ITEMS, \ + WORKSPACE_PTR, \ + WORSPACE_BYTES, \ + STREAM) \ + cub::DeviceSelect::Flagged( \ + (WORKSPACE_PTR), \ + (WORSPACE_BYTES), \ + (ITEM_PTR), \ + (FLAGS_PTR), \ + (OUT_PTR), \ + (NUM_SELECTED_PTR), \ + (NUM_ITEMS), \ + stream = (STREAM)); + +#define COPY_HOST_DEV(PTR_D, PTR_H, TYPE, SIZE) \ + HANDLECUDA(cudaMemcpy( \ + (PTR_D), (PTR_H), sizeof(TYPE) * (SIZE), cudaMemcpyHostToDevice)) +#define COPY_DEV_HOST(PTR_H, PTR_D, TYPE, SIZE) \ + HANDLECUDA(cudaMemcpy( \ + (PTR_H), (PTR_D), sizeof(TYPE) * (SIZE), cudaMemcpyDeviceToHost)) +#define COPY_DEV_DEV(PTR_T, PTR_S, TYPE, SIZE) \ + HANDLECUDA(cudaMemcpy( \ + (PTR_T), (PTR_S), sizeof(TYPE) * (SIZE), cudaMemcpyDeviceToDevice)) +// +// We *must* use cudaMallocManaged for pointers on device that should +// interact with pytorch. However, this comes at a significant speed penalty. +// We're using plain CUDA pointers for the rendering operations and +// explicitly copy results to managed pointers wrapped for pytorch (see +// pytorch/util.h). +#define MALLOC(VAR, TYPE, SIZE) cudaMalloc(&(VAR), sizeof(TYPE) * (SIZE)) +#define FREE(PTR) HANDLECUDA(cudaFree(PTR)) +#define MEMSET(VAR, VAL, TYPE, SIZE, STREAM) \ + HANDLECUDA(cudaMemsetAsync((VAR), (VAL), sizeof(TYPE) * (SIZE), (STREAM))) + +#define LAUNCH_MAX_PARALLEL_1D(FUNC, N, STREAM, ...) \ + { \ + int64_t max_threads = \ + at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock; \ + uint num_threads = min((N), max_threads); \ + uint num_blocks = iDivCeil((N), num_threads); \ + FUNC<<>>(__VA_ARGS__); \ + } +#define LAUNCH_PARALLEL_1D(FUNC, N, TN, STREAM, ...) \ + { \ + uint num_threads = min(static_cast(N), static_cast(TN)); \ + uint num_blocks = iDivCeil((N), num_threads); \ + FUNC<<>>(__VA_ARGS__); \ + } +#define LAUNCH_MAX_PARALLEL_2D(FUNC, NX, NY, STREAM, ...) \ + { \ + int64_t max_threads = \ + at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock; \ + int64_t max_threads_sqrt = static_cast(sqrt(max_threads)); \ + dim3 num_threads, num_blocks; \ + num_threads.x = min((NX), max_threads_sqrt); \ + num_blocks.x = iDivCeil((NX), num_threads.x); \ + num_threads.y = min((NY), max_threads_sqrt); \ + num_blocks.y = iDivCeil((NY), num_threads.y); \ + num_threads.z = 1; \ + num_blocks.z = 1; \ + FUNC<<>>(__VA_ARGS__); \ + } +#define LAUNCH_PARALLEL_2D(FUNC, NX, NY, TX, TY, STREAM, ...) \ + { \ + dim3 num_threads, num_blocks; \ + num_threads.x = min((NX), (TX)); \ + num_blocks.x = iDivCeil((NX), num_threads.x); \ + num_threads.y = min((NY), (TY)); \ + num_blocks.y = iDivCeil((NY), num_threads.y); \ + num_threads.z = 1; \ + num_blocks.z = 1; \ + FUNC<<>>(__VA_ARGS__); \ + } + +#define GET_PARALLEL_IDX_1D(VARNAME, N) \ + const uint VARNAME = __mul24(blockIdx.x, blockDim.x) + threadIdx.x; \ + if (VARNAME >= (N)) { \ + return; \ + } +#define GET_PARALLEL_IDS_2D(VAR_X, VAR_Y, WIDTH, HEIGHT) \ + const uint VAR_X = __mul24(blockIdx.x, blockDim.x) + threadIdx.x; \ + const uint VAR_Y = __mul24(blockIdx.y, blockDim.y) + threadIdx.y; \ + if (VAR_X >= (WIDTH) || VAR_Y >= (HEIGHT)) \ + return; +#define END_PARALLEL() +#define END_PARALLEL_NORET() +#define END_PARALLEL_2D_NORET() +#define END_PARALLEL_2D() +#define RETURN_PARALLEL() return +#define CHECKLAUNCH() THCudaCheck(cudaGetLastError()); +#define ISONDEVICE true +#define SYNCDEVICE() HANDLECUDA(cudaDeviceSynchronize()) +#define START_TIME(TN) \ + cudaEvent_t __time_start_##TN, __time_stop_##TN; \ + cudaEventCreate(&__time_start_##TN); \ + cudaEventCreate(&__time_stop_##TN); \ + cudaEventRecord(__time_start_##TN); +#define STOP_TIME(TN) cudaEventRecord(__time_stop_##TN); +#define GET_TIME(TN, TOPTR) \ + cudaEventSynchronize(__time_stop_##TN); \ + cudaEventElapsedTime((TOPTR), __time_start_##TN, __time_stop_##TN); +#define START_TIME_CU(TN) START_TIME(CN) +#define STOP_TIME_CU(TN) STOP_TIME(TN) +#define GET_TIME_CU(TN, TOPTR) GET_TIME(TN, TOPTR) + +#endif diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.backward.gpu.cu b/pytorch3d/csrc/pulsar/cuda/renderer.backward.gpu.cu new file mode 100644 index 00000000..6969a3fc --- /dev/null +++ b/pytorch3d/csrc/pulsar/cuda/renderer.backward.gpu.cu @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.backward.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.backward_dbg.gpu.cu b/pytorch3d/csrc/pulsar/cuda/renderer.backward_dbg.gpu.cu new file mode 100644 index 00000000..e38f3b5b --- /dev/null +++ b/pytorch3d/csrc/pulsar/cuda/renderer.backward_dbg.gpu.cu @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.backward_dbg.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.calc_gradients.gpu.cu b/pytorch3d/csrc/pulsar/cuda/renderer.calc_gradients.gpu.cu new file mode 100644 index 00000000..0668eced --- /dev/null +++ b/pytorch3d/csrc/pulsar/cuda/renderer.calc_gradients.gpu.cu @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.calc_gradients.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.calc_signature.gpu.cu b/pytorch3d/csrc/pulsar/cuda/renderer.calc_signature.gpu.cu new file mode 100644 index 00000000..8e05de28 --- /dev/null +++ b/pytorch3d/csrc/pulsar/cuda/renderer.calc_signature.gpu.cu @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.calc_signature.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.construct.gpu.cu b/pytorch3d/csrc/pulsar/cuda/renderer.construct.gpu.cu new file mode 100644 index 00000000..f04df519 --- /dev/null +++ b/pytorch3d/csrc/pulsar/cuda/renderer.construct.gpu.cu @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.construct.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.create_selector.gpu.cu b/pytorch3d/csrc/pulsar/cuda/renderer.create_selector.gpu.cu new file mode 100644 index 00000000..c9a64f78 --- /dev/null +++ b/pytorch3d/csrc/pulsar/cuda/renderer.create_selector.gpu.cu @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.create_selector.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.destruct.gpu.cu b/pytorch3d/csrc/pulsar/cuda/renderer.destruct.gpu.cu new file mode 100644 index 00000000..e5dec0db --- /dev/null +++ b/pytorch3d/csrc/pulsar/cuda/renderer.destruct.gpu.cu @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.destruct.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.fill_bg.gpu.cu b/pytorch3d/csrc/pulsar/cuda/renderer.fill_bg.gpu.cu new file mode 100644 index 00000000..01e6c6f9 --- /dev/null +++ b/pytorch3d/csrc/pulsar/cuda/renderer.fill_bg.gpu.cu @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.fill_bg.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.forward.gpu.cu b/pytorch3d/csrc/pulsar/cuda/renderer.forward.gpu.cu new file mode 100644 index 00000000..c73b8f69 --- /dev/null +++ b/pytorch3d/csrc/pulsar/cuda/renderer.forward.gpu.cu @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.forward.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.norm_cam_gradients.gpu.cu b/pytorch3d/csrc/pulsar/cuda/renderer.norm_cam_gradients.gpu.cu new file mode 100644 index 00000000..2c26f5a9 --- /dev/null +++ b/pytorch3d/csrc/pulsar/cuda/renderer.norm_cam_gradients.gpu.cu @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.norm_cam_gradients.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.norm_sphere_gradients.gpu.cu b/pytorch3d/csrc/pulsar/cuda/renderer.norm_sphere_gradients.gpu.cu new file mode 100644 index 00000000..4ee8128e --- /dev/null +++ b/pytorch3d/csrc/pulsar/cuda/renderer.norm_sphere_gradients.gpu.cu @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.norm_sphere_gradients.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/cuda/renderer.render.gpu.cu b/pytorch3d/csrc/pulsar/cuda/renderer.render.gpu.cu new file mode 100644 index 00000000..c9a664e1 --- /dev/null +++ b/pytorch3d/csrc/pulsar/cuda/renderer.render.gpu.cu @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.render.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/global.h b/pytorch3d/csrc/pulsar/global.h new file mode 100644 index 00000000..6c0e94c1 --- /dev/null +++ b/pytorch3d/csrc/pulsar/global.h @@ -0,0 +1,85 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_GLOBAL_H +#define PULSAR_GLOBAL_H + +#include "./constants.h" +#ifndef WIN32 +#include +#endif + +#if defined(_WIN64) || defined(_WIN32) +#define uint unsigned int +#define ushort unsigned short +#endif + +#include "./logging.h" // <- include before torch/extension.h + +#define MAX_GRAD_SPHERES 128 + +#ifdef __CUDACC__ +#define INLINE __forceinline__ +#define HOST __host__ +#define DEVICE __device__ +#define GLOBAL __global__ +#define RESTRICT __restrict__ +#define DEBUGBREAK() +#pragma diag_suppress = attribute_not_allowed +#pragma diag_suppress = 1866 +#pragma diag_suppress = 2941 +#pragma diag_suppress = 2951 +#pragma diag_suppress = 2967 +#else // __CUDACC__ +#define INLINE inline +#define HOST +#define DEVICE +#define GLOBAL +#define RESTRICT +#define DEBUGBREAK() std::raise(SIGINT) +// Don't care about pytorch warnings; they shouldn't clutter our warnings. +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Weverything" +#include +#include +#pragma clang diagnostic pop +namespace py = pybind11; +inline float3 make_float3(const float& x, const float& y, const float& z) { + float3 res; + res.x = x; + res.y = y; + res.z = z; + return res; +} + +inline bool operator==(const float3& a, const float3& b) { + return a.x == b.x && a.y == b.y && a.z == b.z; +} +#endif // __CUDACC__ +#define IHD INLINE HOST DEVICE + +// An assertion command that can be used on host and device. +#ifdef PULSAR_ASSERTIONS +#ifdef __CUDACC__ +#define PASSERT(VAL) \ + if (!(VAL)) { \ + printf( \ + "Pulsar assertion failed in %s, line %d: %s.\n", \ + __FILE__, \ + __LINE__, \ + #VAL); \ + } +#else +#define PASSERT(VAL) \ + if (!(VAL)) { \ + printf( \ + "Pulsar assertion failed in %s, line %d: %s.\n", \ + __FILE__, \ + __LINE__, \ + #VAL); \ + std::raise(SIGINT); \ + } +#endif +#else +#define PASSERT(VAL) +#endif + +#endif diff --git a/pytorch3d/csrc/pulsar/host/README.md b/pytorch3d/csrc/pulsar/host/README.md new file mode 100644 index 00000000..34f1bade --- /dev/null +++ b/pytorch3d/csrc/pulsar/host/README.md @@ -0,0 +1,5 @@ +# Device-specific host compilation units + +This folder contains `.cpp` files to create compilation units +for device specific functions. See `../include/README.md` for +more information. diff --git a/pytorch3d/csrc/pulsar/host/commands.h b/pytorch3d/csrc/pulsar/host/commands.h new file mode 100644 index 00000000..737c6224 --- /dev/null +++ b/pytorch3d/csrc/pulsar/host/commands.h @@ -0,0 +1,383 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_COMMANDS_H_ +#define PULSAR_NATIVE_COMMANDS_H_ + +#ifdef _MSC_VER +#include +#define __builtin_popcount (int)__popcnt +#endif + +// Definitions for CPU commands. +// #include +// #include + +namespace cg { +struct coalesced_group { + INLINE uint thread_rank() const { + return 0u; + } + INLINE uint size() const { + return 1u; + } + INLINE uint ballot(uint val) const { + return static_cast(val > 0); + } +}; + +struct thread_block { + INLINE uint thread_rank() const { + return 0u; + } + INLINE uint size() const { + return 1u; + } + INLINE void sync() const {} +}; + +INLINE coalesced_group coalesced_threads() { + coalesced_group ret; + return ret; +} + +INLINE thread_block this_thread_block() { + thread_block ret; + return ret; +} +} // namespace cg +#define SHFL_SYNC(a, b, c) (b) +template +T WARP_CUMSUM( + const cg::coalesced_group& group, + const uint& mask, + const T& base) { + return base; +} + +template +DEVICE T +WARP_MAX(const cg::coalesced_group& group, const uint& mask, const T& base) { + return base; +} + +template +DEVICE T +WARP_SUM(const cg::coalesced_group& group, const uint& mask, const T& base) { + return base; +} + +INLINE DEVICE float3 WARP_SUM_FLOAT3( + const cg::coalesced_group& group, + const uint& mask, + const float3& base) { + return base; +} + +#define ACTIVEMASK() (1u << 31) +#define ALIGN(VAL) +#define SYNC() +#define THREADFENCE_B() +#define BALLOT(mask, val) (val != 0) +#define SHARED +// Floating point. +#define FMAX(a, b) std::fmax((a), (b)) +#define FMIN(a, b) std::fmin((a), (b)) +INLINE float atomicMax(float* address, float val) { + *address = std::max(*address, val); + return *address; +} +INLINE float atomicMin(float* address, float val) { + *address = std::min(*address, val); + return *address; +} +#define FMUL(a, b) ((a) * (b)) +#define FDIV(a, b) ((a) / (b)) +#define FSUB(a, b) ((a) - (b)) +#define FABSLEQAS(a, b, c) \ + ((a) <= (b) ? FSUB((b), (a)) <= (c) : FSUB((a), (b)) < (c)) +#define FADD(a, b) ((a) + (b)) +#define FSQRT(a) sqrtf(a) +#define FEXP(a) fasterexp(a) +#define FLN(a) fasterlog(a) +#define FPOW(a, b) powf((a), (b)) +#define FROUND(x) roundf(x) +#define FCEIL(a) ceilf(a) +#define FFLOOR(a) floorf(a) +#define FSATURATE(x) std::max(0.f, std::min(1.f, x)) +#define FABS(a) abs(a) +#define FMA(x, y, z) ((x) * (y) + (z)) +#define I2F(a) static_cast(a) +#define FRCP(x) (1.f / (x)) +#define IASF(x, loc) memcpy(&(loc), &(x), sizeof(x)) +#define FASI(x, loc) memcpy(&(loc), &(x), sizeof(x)) +#define DMAX(a, b) std::max((a), (b)) +#define DMIN(a, b) std::min((a), (b)) +#define DSATURATE(a) DMIN(1., DMAX(0., (a))) +#define DSQRT(a) sqrt(a) +// +// +// +// +// +// +// +// +// +// +// +// +// uint. +#define CLZ(VAL) _clz(VAL) +template +INLINE T ATOMICADD(T* address, T val) { + T old = *address; + *address += val; + return old; +} +template +INLINE void ATOMICADD_F3(T* address, T val) { + ATOMICADD(&(address->x), val.x); + ATOMICADD(&(address->y), val.y); + ATOMICADD(&(address->z), val.z); +} +#define ATOMICADD_B(a, b) ATOMICADD((a), (b)) +#define POPC(a) __builtin_popcount(a) + +// int. +#define IMIN(a, b) std::min((a), (b)) +#define IMAX(a, b) std::max((a), (b)) +#define IABS(a) abs(a) + +// Checks. +#define CHECKOK THCheck +#define ARGCHECK THArgCheck + +// Math. +#define NORM3DF(x, y, z) sqrtf(x* x + y * y + z * z) +#define RNORM3DF(x, y, z) (1.f / sqrtf(x * x + y * y + z * z)) + +// High level. +#define PREFETCH(PTR) +#define GET_SORT_WS_SIZE(RES_PTR, KEY_TYPE, VAL_TYPE, NUM_OBJECTS) \ + *(RES_PTR) = 0; +#define GET_REDUCE_WS_SIZE(RES_PTR, TYPE, REDUCE_OP, NUM_OBJECTS) \ + *(RES_PTR) = 0; +#define GET_SELECT_WS_SIZE( \ + RES_PTR, TYPE_SELECTOR, TYPE_SELECTION, NUM_OBJECTS) \ + *(RES_PTR) = 0; +#define GET_SUM_WS_SIZE(RES_PTR, TYPE_SUM, NUM_OBJECTS) *(RES_PTR) = 0; +#define GET_MM_WS_SIZE(RES_PTR, TYPE, NUM_OBJECTS) *(RES_PTR) = 0; + +#define SORT_DESCENDING( \ + TMPN1, SORT_PTR, SORTED_PTR, VAL_PTR, VAL_SORTED_PTR, NUM_OBJECTS) \ + std::vector TMPN1(NUM_OBJECTS); \ + std::iota(TMPN1.begin(), TMPN1.end(), 0); \ + const auto TMPN1##_val_ptr = (SORT_PTR); \ + std::sort( \ + TMPN1.begin(), TMPN1.end(), [&TMPN1##_val_ptr](size_t i1, size_t i2) { \ + return TMPN1##_val_ptr[i1] > TMPN1##_val_ptr[i2]; \ + }); \ + for (int i = 0; i < (NUM_OBJECTS); ++i) { \ + (SORTED_PTR)[i] = (SORT_PTR)[TMPN1[i]]; \ + } \ + for (int i = 0; i < (NUM_OBJECTS); ++i) { \ + (VAL_SORTED_PTR)[i] = (VAL_PTR)[TMPN1[i]]; \ + } + +#define SORT_ASCENDING( \ + SORT_PTR, SORTED_PTR, VAL_PTR, VAL_SORTED_PTR, NUM_OBJECTS, STREAM) \ + { \ + std::vector TMPN1(NUM_OBJECTS); \ + std::iota(TMPN1.begin(), TMPN1.end(), 0); \ + const auto TMPN1_val_ptr = (SORT_PTR); \ + std::sort( \ + TMPN1.begin(), \ + TMPN1.end(), \ + [&TMPN1_val_ptr](size_t i1, size_t i2) -> bool { \ + return TMPN1_val_ptr[i1] < TMPN1_val_ptr[i2]; \ + }); \ + for (int i = 0; i < (NUM_OBJECTS); ++i) { \ + (SORTED_PTR)[i] = (SORT_PTR)[TMPN1[i]]; \ + } \ + for (int i = 0; i < (NUM_OBJECTS); ++i) { \ + (VAL_SORTED_PTR)[i] = (VAL_PTR)[TMPN1[i]]; \ + } \ + } + +#define SORT_DESCENDING_WS( \ + TMPN1, \ + SORT_PTR, \ + SORTED_PTR, \ + VAL_PTR, \ + VAL_SORTED_PTR, \ + NUM_OBJECTS, \ + WORSPACE_PTR, \ + WORKSPACE_SIZE) \ + SORT_DESCENDING( \ + TMPN1, SORT_PTR, SORTED_PTR, VAL_PTR, VAL_SORTED_PTR, NUM_OBJECTS) + +#define SORT_ASCENDING_WS( \ + SORT_PTR, \ + SORTED_PTR, \ + VAL_PTR, \ + VAL_SORTED_PTR, \ + NUM_OBJECTS, \ + WORSPACE_PTR, \ + WORKSPACE_SIZE, \ + STREAM) \ + SORT_ASCENDING( \ + SORT_PTR, SORTED_PTR, VAL_PTR, VAL_SORTED_PTR, NUM_OBJECTS, STREAM) + +#define REDUCE(REDUCE_PTR, RESULT_PTR, NUM_ITEMS, REDUCE_OP, REDUCE_INIT) \ + { \ + *(RESULT_PTR) = (REDUCE_INIT); \ + for (int i = 0; i < (NUM_ITEMS); ++i) { \ + *(RESULT_PTR) = REDUCE_OP(*(RESULT_PTR), (REDUCE_PTR)[i]); \ + } \ + } +#define REDUCE_WS( \ + REDUCE_PTR, \ + RESULT_PTR, \ + NUM_ITEMS, \ + REDUCE_OP, \ + REDUCE_INIT, \ + WORKSPACE_PTR, \ + WORKSPACE_SIZE, \ + STREAM) \ + REDUCE(REDUCE_PTR, RESULT_PTR, NUM_ITEMS, REDUCE_OP, REDUCE_INIT) + +#define SELECT_FLAGS_WS( \ + FLAGS_PTR, \ + ITEM_PTR, \ + OUT_PTR, \ + NUM_SELECTED_PTR, \ + NUM_ITEMS, \ + WORKSPACE_PTR, \ + WORSPACE_BYTES, \ + STREAM) \ + { \ + *NUM_SELECTED_PTR = 0; \ + ptrdiff_t write_pos = 0; \ + for (int i = 0; i < NUM_ITEMS; ++i) { \ + if (FLAGS_PTR[i]) { \ + OUT_PTR[write_pos++] = ITEM_PTR[i]; \ + *NUM_SELECTED_PTR += 1; \ + } \ + } \ + } + +template +void SUM_WS( + T* SUM_PTR, + T* OUT_PTR, + size_t NUM_OBJECTS, + char* WORKSPACE_PTR, + size_t WORKSPACE_BYTES, + cudaStream_t STREAM) { + *(OUT_PTR) = T(); + for (int i = 0; i < (NUM_OBJECTS); ++i) { + *(OUT_PTR) = *(OUT_PTR) + (SUM_PTR)[i]; + } +} + +template +void MIN_WS( + T* MIN_PTR, + T* OUT_PTR, + size_t NUM_OBJECTS, + char* WORKSPACE_PTR, + size_t WORKSPACE_BYTES, + cudaStream_t STREAM) { + *(OUT_PTR) = T(); + for (int i = 0; i < (NUM_OBJECTS); ++i) { + *(OUT_PTR) = std::min(*(OUT_PTR), (MIN_PTR)[i]); + } +} + +template +void MAX_WS( + T* MAX_PTR, + T* OUT_PTR, + size_t NUM_OBJECTS, + char* WORKSPACE_PTR, + size_t WORKSPACE_BYTES, + cudaStream_t STREAM) { + *(OUT_PTR) = T(); + for (int i = 0; i < (NUM_OBJECTS); ++i) { + *(OUT_PTR) = std::max(*(OUT_PTR), (MAX_PTR)[i]); + } +} +// +// +// +// +#define COPY_HOST_DEV(PTR_D, PTR_H, TYPE, SIZE) \ + std::memcpy((PTR_D), (PTR_H), sizeof(TYPE) * (SIZE)) +// +#define COPY_DEV_HOST(PTR_H, PTR_D, TYPE, SIZE) \ + std::memcpy((PTR_H), (PTR_D), sizeof(TYPE) * (SIZE)) +// +#define COPY_DEV_DEV(PTR_T, PTR_S, TYPE, SIZE) \ + std::memcpy((PTR_T), (PTR_S), sizeof(TYPE) * SIZE) +// + +#define MALLOC(VAR, TYPE, SIZE) MALLOC_HOST(VAR, TYPE, SIZE) +#define FREE(PTR) FREE_HOST(PTR) +#define MEMSET(VAR, VAL, TYPE, SIZE, STREAM) \ + memset((VAR), (VAL), sizeof(TYPE) * (SIZE)) +// + +#define LAUNCH_MAX_PARALLEL_1D(FUNC, N, STREAM, ...) FUNC(__VA_ARGS__); +#define LAUNCH_PARALLEL_1D(FUNC, N, TN, STREAM, ...) FUNC(__VA_ARGS__); +#define LAUNCH_MAX_PARALLEL_2D(FUNC, NX, NY, STREAM, ...) FUNC(__VA_ARGS__); +#define LAUNCH_PARALLEL_2D(FUNC, NX, NY, TX, TY, STREAM, ...) FUNC(__VA_ARGS__); +// +// +// +// +// +#define GET_PARALLEL_IDX_1D(VARNAME, N) \ + for (uint VARNAME = 0; VARNAME < (N); ++VARNAME) { +#define GET_PARALLEL_IDS_2D(VAR_X, VAR_Y, WIDTH, HEIGHT) \ + int2 blockDim; \ + blockDim.x = 1; \ + blockDim.y = 1; \ + uint __parallel_2d_width = WIDTH; \ + uint __parallel_2d_height = HEIGHT; \ + for (uint VAR_Y = 0; VAR_Y < __parallel_2d_height; ++(VAR_Y)) { \ + for (uint VAR_X = 0; VAR_X < __parallel_2d_width; ++(VAR_X)) { +// +// +// +#define END_PARALLEL() \ + end_parallel:; \ + } +#define END_PARALLEL_NORET() } +#define END_PARALLEL_2D() \ + end_parallel:; \ + } \ + } +#define END_PARALLEL_2D_NORET() \ + } \ + } +#define RETURN_PARALLEL() goto end_parallel; +#define CHECKLAUNCH() +#define ISONDEVICE false +#define SYNCDEVICE() +#define START_TIME(TN) \ + auto __time_start_##TN = std::chrono::steady_clock::now(); +#define STOP_TIME(TN) auto __time_stop_##TN = std::chrono::steady_clock::now(); +#define GET_TIME(TN, TOPTR) \ + *TOPTR = std::chrono::duration_cast( \ + __time_stop_##TN - __time_start_##TN) \ + .count() +#define START_TIME_CU(TN) \ + cudaEvent_t __time_start_##TN, __time_stop_##TN; \ + cudaEventCreate(&__time_start_##TN); \ + cudaEventCreate(&__time_stop_##TN); \ + cudaEventRecord(__time_start_##TN); +#define STOP_TIME_CU(TN) cudaEventRecord(__time_stop_##TN); +#define GET_TIME_CU(TN, TOPTR) \ + cudaEventSynchronize(__time_stop_##TN); \ + cudaEventElapsedTime((TOPTR), __time_start_##TN, __time_stop_##TN); + +#endif diff --git a/pytorch3d/csrc/pulsar/host/renderer.backward.cpu.cpp b/pytorch3d/csrc/pulsar/host/renderer.backward.cpu.cpp new file mode 100644 index 00000000..6969a3fc --- /dev/null +++ b/pytorch3d/csrc/pulsar/host/renderer.backward.cpu.cpp @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.backward.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/host/renderer.backward_dbg.cpu.cpp b/pytorch3d/csrc/pulsar/host/renderer.backward_dbg.cpu.cpp new file mode 100644 index 00000000..e38f3b5b --- /dev/null +++ b/pytorch3d/csrc/pulsar/host/renderer.backward_dbg.cpu.cpp @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.backward_dbg.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/host/renderer.calc_gradients.cpu.cpp b/pytorch3d/csrc/pulsar/host/renderer.calc_gradients.cpu.cpp new file mode 100644 index 00000000..0668eced --- /dev/null +++ b/pytorch3d/csrc/pulsar/host/renderer.calc_gradients.cpu.cpp @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.calc_gradients.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/host/renderer.calc_signature.cpu.cpp b/pytorch3d/csrc/pulsar/host/renderer.calc_signature.cpu.cpp new file mode 100644 index 00000000..8e05de28 --- /dev/null +++ b/pytorch3d/csrc/pulsar/host/renderer.calc_signature.cpu.cpp @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.calc_signature.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/host/renderer.construct.cpu.cpp b/pytorch3d/csrc/pulsar/host/renderer.construct.cpu.cpp new file mode 100644 index 00000000..f04df519 --- /dev/null +++ b/pytorch3d/csrc/pulsar/host/renderer.construct.cpu.cpp @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.construct.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/host/renderer.create_selector.cpu.cpp b/pytorch3d/csrc/pulsar/host/renderer.create_selector.cpu.cpp new file mode 100644 index 00000000..c9a64f78 --- /dev/null +++ b/pytorch3d/csrc/pulsar/host/renderer.create_selector.cpu.cpp @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.create_selector.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/host/renderer.destruct.cpu.cpp b/pytorch3d/csrc/pulsar/host/renderer.destruct.cpu.cpp new file mode 100644 index 00000000..e5dec0db --- /dev/null +++ b/pytorch3d/csrc/pulsar/host/renderer.destruct.cpu.cpp @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.destruct.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/host/renderer.fill_bg.cpu.cpp b/pytorch3d/csrc/pulsar/host/renderer.fill_bg.cpu.cpp new file mode 100644 index 00000000..01e6c6f9 --- /dev/null +++ b/pytorch3d/csrc/pulsar/host/renderer.fill_bg.cpu.cpp @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.fill_bg.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/host/renderer.forward.cpu.cpp b/pytorch3d/csrc/pulsar/host/renderer.forward.cpu.cpp new file mode 100644 index 00000000..c73b8f69 --- /dev/null +++ b/pytorch3d/csrc/pulsar/host/renderer.forward.cpu.cpp @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.forward.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/host/renderer.norm_cam_gradients.cpu.cpp b/pytorch3d/csrc/pulsar/host/renderer.norm_cam_gradients.cpu.cpp new file mode 100644 index 00000000..2c26f5a9 --- /dev/null +++ b/pytorch3d/csrc/pulsar/host/renderer.norm_cam_gradients.cpu.cpp @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.norm_cam_gradients.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/host/renderer.norm_sphere_gradients.cpu.cpp b/pytorch3d/csrc/pulsar/host/renderer.norm_sphere_gradients.cpu.cpp new file mode 100644 index 00000000..4ee8128e --- /dev/null +++ b/pytorch3d/csrc/pulsar/host/renderer.norm_sphere_gradients.cpu.cpp @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.norm_sphere_gradients.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/host/renderer.render.cpu.cpp b/pytorch3d/csrc/pulsar/host/renderer.render.cpu.cpp new file mode 100644 index 00000000..c9a664e1 --- /dev/null +++ b/pytorch3d/csrc/pulsar/host/renderer.render.cpu.cpp @@ -0,0 +1,2 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "../include/renderer.render.instantiate.h" diff --git a/pytorch3d/csrc/pulsar/include/README.md b/pytorch3d/csrc/pulsar/include/README.md new file mode 100644 index 00000000..e963ff04 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/README.md @@ -0,0 +1,16 @@ +# The `include` folder + +This folder contains header files with implementations of several useful +algorithms. These implementations are usually done in files called `x.device.h` +and use macros that route every device specific command to the right +implementation (see `commands.h`). + +If you're using a device specific implementation, include `x.device.h`. +This gives you the high-speed, device specific implementation that lets +you work with all the details of the datastructure. All function calls are +inlined. If you need to work with the high-level interface and be able to +dynamically pick a device, only include `x.h`. The functions there are +templated with a boolean `DEV` flag and are instantiated in device specific +compilation units. You will not be able to use any other functions, but can +use `func(params)` to work on a CUDA device, or `func(params)` +to work on the host. diff --git a/pytorch3d/csrc/pulsar/include/camera.device.h b/pytorch3d/csrc/pulsar/include/camera.device.h new file mode 100644 index 00000000..5633bc62 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/camera.device.h @@ -0,0 +1,18 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_CAMERA_DEVICE_H_ +#define PULSAR_NATIVE_INCLUDE_CAMERA_DEVICE_H_ + +#include "../global.h" +#include "./camera.h" +#include "./commands.h" + +namespace pulsar { +IHD CamGradInfo::CamGradInfo() { + cam_pos = make_float3(0.f, 0.f, 0.f); + pixel_0_0_center = make_float3(0.f, 0.f, 0.f); + pixel_dir_x = make_float3(0.f, 0.f, 0.f); + pixel_dir_y = make_float3(0.f, 0.f, 0.f); +} +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/camera.h b/pytorch3d/csrc/pulsar/include/camera.h new file mode 100644 index 00000000..2bf2a454 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/camera.h @@ -0,0 +1,72 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_CAMERA_H_ +#define PULSAR_NATIVE_INCLUDE_CAMERA_H_ + +#include "../global.h" + +namespace pulsar { +/** + * Everything that's needed to raycast with our camera model. + */ +struct CamInfo { + float3 eye; /** Position in world coordinates. */ + float3 pixel_0_0_center; /** LUC center of pixel position in world + coordinates. */ + float3 pixel_dir_x; /** Direction for increasing x for one pixel to the next, + * in world coordinates. */ + float3 pixel_dir_y; /** Direction for increasing y for one pixel to the next, + * in world coordinates. */ + float3 sensor_dir_z; /** Normalized direction vector from eye through the + * sensor in z direction (optical axis). */ + float half_pixel_size; /** Half size of a pixel, in world coordinates. This + * must be consistent with pixel_dir_x and pixel_dir_y! + */ + float focal_length; /** The focal length, if applicable. */ + uint aperture_width; /** Full image width in px, possibly not fully used + * in case of a shifted principal point. */ + uint aperture_height; /** Full image height in px, possibly not fully used + * in case of a shifted principal point. */ + uint film_width; /** Resulting image width. */ + uint film_height; /** Resulting image height. */ + /** The top left coordinates (inclusive) of the film in the full aperture. */ + uint film_border_left, film_border_top; + int32_t principal_point_offset_x; /** Horizontal principal point offset. */ + int32_t principal_point_offset_y; /** Vertical principal point offset. */ + float min_dist; /** Minimum distance for a ball to be rendered. */ + float max_dist; /** Maximum distance for a ball to be rendered. */ + float norm_fac; /** 1 / (max_dist - min_dist), pre-computed. */ + /** The depth where to place the background, in normalized coordinates where + * 0. is the backmost depth and 1. the frontmost. */ + float background_normalization_depth; + /** The number of image content channels to use. Usually three. */ + uint n_channels; + /** Whether to use an orthogonal instead of a perspective projection. */ + bool orthogonal_projection; + /** Whether to use a right-handed system (inverts the z axis). */ + bool right_handed; +}; + +inline bool operator==(const CamInfo& a, const CamInfo& b) { + return a.film_width == b.film_width && a.film_height == b.film_height && + a.background_normalization_depth == b.background_normalization_depth && + a.n_channels == b.n_channels && + a.orthogonal_projection == b.orthogonal_projection && + a.right_handed == b.right_handed; +}; + +struct CamGradInfo { + HOST DEVICE CamGradInfo(); + float3 cam_pos; + float3 pixel_0_0_center; + float3 pixel_dir_x; + float3 pixel_dir_y; +}; + +// TODO: remove once https://github.com/NVlabs/cub/issues/172 is resolved. +struct IntWrapper { + int val; +}; + +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/closest_sphere_tracker.device.h b/pytorch3d/csrc/pulsar/include/closest_sphere_tracker.device.h new file mode 100644 index 00000000..85423e50 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/closest_sphere_tracker.device.h @@ -0,0 +1,131 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_CLOSEST_SPHERE_TRACKER_DEVICE_H_ +#define PULSAR_NATIVE_INCLUDE_CLOSEST_SPHERE_TRACKER_DEVICE_H_ + +#include "../global.h" + +namespace pulsar { +namespace Renderer { + +/** + * A facility to track the closest spheres to the camera. + * + * Their max number is defined by MAX_GRAD_SPHERES (this is defined in + * `pulsar/native/global.h`). This is done to keep the performance as high as + * possible because this struct needs to do updates continuously on the GPU. + */ +struct ClosestSphereTracker { + public: + IHD ClosestSphereTracker(const int& n_track) : n_hits(0), n_track(n_track) { + PASSERT(n_track < MAX_GRAD_SPHERES); + // Initialize the sphere IDs to -1 and the weights to 0. + for (int i = 0; i < n_track; ++i) { + this->most_important_sphere_ids[i] = -1; + this->closest_sphere_intersection_depths[i] = MAX_FLOAT; + } + }; + + IHD void track( + const uint& sphere_idx, + const float& intersection_depth, + const uint& coord_x, + const uint& coord_y) { + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_TRACKER_PIX, + "tracker|tracking sphere %u (depth: %f).\n", + sphere_idx, + intersection_depth); + for (int i = IMIN(this->n_hits, n_track) - 1; i >= -1; --i) { + if (i < 0 || + this->closest_sphere_intersection_depths[i] < intersection_depth) { + // Write position is i+1. + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_TRACKER_PIX, + "tracker|determined writing position: %d.\n", + i + 1); + if (i + 1 < n_track) { + // Shift every other sphere back. + for (int j = n_track - 1; j > i + 1; --j) { + this->closest_sphere_intersection_depths[j] = + this->closest_sphere_intersection_depths[j - 1]; + this->most_important_sphere_ids[j] = + this->most_important_sphere_ids[j - 1]; + } + this->closest_sphere_intersection_depths[i + 1] = intersection_depth; + this->most_important_sphere_ids[i + 1] = sphere_idx; + } + break; + } + } +#if PULSAR_LOG_TRACKER_PIX + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_TRACKER_PIX, + "tracker|sphere list after adding sphere %u:\n", + sphere_idx); + for (int i = 0; i < n_track; ++i) { + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_TRACKER_PIX, + "tracker|sphere %d: %d (depth: %f).\n", + i, + this->most_important_sphere_ids[i], + this->closest_sphere_intersection_depths[i]); + } +#endif // PULSAR_LOG_TRACKER_PIX + this->n_hits += 1; + } + + /** + * Get the number of hits registered. + */ + IHD int get_n_hits() const { + return this->n_hits; + } + + /** + * Get the idx closest sphere ID. + * + * For example, get_closest_sphere_id(0) gives the overall closest + * sphere id. + * + * This method is implemented for highly optimized scenarios and will *not* + * perform an index check at runtime if assertions are disabled. idx must be + * >=0 and < IMIN(n_hits, n_track) for a valid result, if it is >= + * n_hits it will return -1. + */ + IHD int get_closest_sphere_id(const int& idx) { + PASSERT(idx >= 0 && idx < n_track); + return this->most_important_sphere_ids[idx]; + } + + /** + * Get the idx closest sphere normalized_depth. + * + * For example, get_closest_sphere_depth(0) gives the overall closest + * sphere depth (normalized). + * + * This method is implemented for highly optimized scenarios and will *not* + * perform an index check at runtime if assertions are disabled. idx must be + * >=0 and < IMIN(n_hits, n_track) for a valid result, if it is >= + * n_hits it will return 1. + FEPS. + */ + IHD float get_closest_sphere_depth(const int& idx) { + PASSERT(idx >= 0 && idx < n_track); + return this->closest_sphere_intersection_depths[idx]; + } + + private: + /** The number of registered hits so far. */ + int n_hits; + /** The number of intersections to track. Must be (malloc(sizeof(TYPE) * (SIZE))) +#define FREE_HOST(PTR) free(PTR) + +/* Include command definitions depending on CPU or GPU use. */ + +#ifdef __CUDACC__ +// TODO: find out which compiler we're using here and use the suppression. +// #pragma push +// #pragma diag_suppress = 68 +#include +#include +// #pragma pop +#include "../cuda/commands.h" +#else +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Weverything" +#include +#pragma clang diagnostic pop +#include "../host/commands.h" +#endif + +#endif diff --git a/pytorch3d/csrc/pulsar/include/fastermath.h b/pytorch3d/csrc/pulsar/include/fastermath.h new file mode 100644 index 00000000..2276ae6a --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/fastermath.h @@ -0,0 +1,87 @@ +#ifndef PULSAR_NATIVE_INCLUDE_FASTERMATH_H_ +#define PULSAR_NATIVE_INCLUDE_FASTERMATH_H_ + +/*=====================================================================* + * Copyright (C) 2011 Paul Mineiro * + * All rights reserved. * + * * + * Redistribution and use in source and binary forms, with * + * or without modification, are permitted provided that the * + * following conditions are met: * + * * + * * Redistributions of source code must retain the * + * above copyright notice, this list of conditions and * + * the following disclaimer. * + * * + * * Redistributions in binary form must reproduce the * + * above copyright notice, this list of conditions and * + * the following disclaimer in the documentation and/or * + * other materials provided with the distribution. * + * * + * * Neither the name of Paul Mineiro nor the names * + * of other contributors may be used to endorse or promote * + * products derived from this software without specific * + * prior written permission. * + * * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND * + * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, * + * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES * + * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE * + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER * + * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, * + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES * + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE * + * GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR * + * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF * + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY * + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE * + * POSSIBILITY OF SUCH DAMAGE. * + * * + * Contact: Paul Mineiro * + *=====================================================================*/ + +#include +#include "./commands.h" + +#ifdef __cplusplus +#define cast_uint32_t static_cast +#else +#define cast_uint32_t (uint32_t) +#endif + +IHD float fasterlog2(float x) { + union { + float f; + uint32_t i; + } vx = {x}; + float y = vx.i; + y *= 1.1920928955078125e-7f; + return y - 126.94269504f; +} + +IHD float fasterlog(float x) { + // return 0.69314718f * fasterlog2 (x); + union { + float f; + uint32_t i; + } vx = {x}; + float y = vx.i; + y *= 8.2629582881927490e-8f; + return y - 87.989971088f; +} + +IHD float fasterpow2(float p) { + float clipp = (p < -126) ? -126.0f : p; + union { + uint32_t i; + float f; + } v = {cast_uint32_t((1 << 23) * (clipp + 126.94269504f))}; + return v.f; +} + +IHD float fasterexp(float p) { + return fasterpow2(1.442695040f * p); +} + +#endif diff --git a/pytorch3d/csrc/pulsar/include/math.h b/pytorch3d/csrc/pulsar/include/math.h new file mode 100644 index 00000000..48995bcc --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/math.h @@ -0,0 +1,150 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_IMPL_MATH_H_ +#define PULSAR_NATIVE_IMPL_MATH_H_ + +#include "./camera.h" +#include "./commands.h" +#include "./fastermath.h" + +/** + * Get the direction of val. + * + * Returns +1 if val is positive, -1 if val is zero or negative. + */ +IHD int sign_dir(const int& val) { + return -(static_cast((val <= 0)) << 1) + 1; +}; + +/** + * Get the direction of val. + * + * Returns +1 if val is positive, -1 if val is zero or negative. + */ +IHD float sign_dir(const float& val) { + return static_cast(1 - (static_cast((val <= 0)) << 1)); +}; + +/** + * Integer ceil division. + */ +IHD uint iDivCeil(uint a, uint b) { + return (a % b != 0) ? (a / b + 1) : (a / b); +} + +IHD float3 outer_product_sum(const float3& a) { + return make_float3( + a.x * a.x + a.x * a.y + a.x * a.z, + a.x * a.y + a.y * a.y + a.y * a.z, + a.x * a.z + a.y * a.z + a.z * a.z); +} + +// TODO: put intrinsics here. +IHD float3 operator+(const float3& a, const float3& b) { + return make_float3(a.x + b.x, a.y + b.y, a.z + b.z); +} + +IHD void operator+=(float3& a, const float3& b) { + a.x += b.x; + a.y += b.y; + a.z += b.z; +} + +IHD void operator-=(float3& a, const float3& b) { + a.x -= b.x; + a.y -= b.y; + a.z -= b.z; +} + +IHD void operator/=(float3& a, const float& b) { + a.x /= b; + a.y /= b; + a.z /= b; +} + +IHD void operator*=(float3& a, const float& b) { + a.x *= b; + a.y *= b; + a.z *= b; +} + +IHD float3 operator/(const float3& a, const float& b) { + return make_float3(a.x / b, a.y / b, a.z / b); +} + +IHD float3 operator-(const float3& a, const float3& b) { + return make_float3(a.x - b.x, a.y - b.y, a.z - b.z); +} + +IHD float3 operator*(const float3& a, const float& b) { + return make_float3(a.x * b, a.y * b, a.z * b); +} + +IHD float3 operator*(const float3& a, const float3& b) { + return make_float3(a.x * b.x, a.y * b.y, a.z * b.z); +} + +IHD float3 operator*(const float& a, const float3& b) { + return b * a; +} + +IHD float length(const float3& v) { + // TODO: benchmark what's faster. + return NORM3DF(v.x, v.y, v.z); + // return __fsqrt_rn(v.x * v.x + v.y * v.y + v.z * v.z); +} + +/** + * Left-hand multiplication of the constructed rotation matrix with the vector. + */ +IHD float3 rotate( + const float3& v, + const float3& dir_x, + const float3& dir_y, + const float3& dir_z) { + return make_float3( + dir_x.x * v.x + dir_x.y * v.y + dir_x.z * v.z, + dir_y.x * v.x + dir_y.y * v.y + dir_y.z * v.z, + dir_z.x * v.x + dir_z.y * v.y + dir_z.z * v.z); +} + +IHD float3 normalize(const float3& v) { + return v * RNORM3DF(v.x, v.y, v.z); +} + +INLINE DEVICE float dot(const float3& a, const float3& b) { + return FADD(FADD(FMUL(a.x, b.x), FMUL(a.y, b.y)), FMUL(a.z, b.z)); +} + +INLINE DEVICE float3 cross(const float3& a, const float3& b) { + // TODO: faster + return make_float3( + a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x); +} + +namespace pulsar { +IHD CamGradInfo operator+(const CamGradInfo& a, const CamGradInfo& b) { + CamGradInfo res; + res.cam_pos = a.cam_pos + b.cam_pos; + res.pixel_0_0_center = a.pixel_0_0_center + b.pixel_0_0_center; + res.pixel_dir_x = a.pixel_dir_x + b.pixel_dir_x; + res.pixel_dir_y = a.pixel_dir_y + b.pixel_dir_y; + return res; +} + +IHD CamGradInfo operator*(const CamGradInfo& a, const float& b) { + CamGradInfo res; + res.cam_pos = a.cam_pos * b; + res.pixel_0_0_center = a.pixel_0_0_center * b; + res.pixel_dir_x = a.pixel_dir_x * b; + res.pixel_dir_y = a.pixel_dir_y * b; + return res; +} + +IHD IntWrapper operator+(const IntWrapper& a, const IntWrapper& b) { + IntWrapper res; + res.val = a.val + b.val; + return res; +} +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.backward.device.h b/pytorch3d/csrc/pulsar/include/renderer.backward.device.h new file mode 100644 index 00000000..175be513 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.backward.device.h @@ -0,0 +1,182 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_RENDERER_BACKWARD_DEVICE_H_ +#define PULSAR_NATIVE_RENDERER_BACKWARD_DEVICE_H_ + +#include "./camera.device.h" +#include "./math.h" +#include "./renderer.h" + +namespace pulsar { +namespace Renderer { + +template +void backward( + Renderer* self, + const float* grad_im, + const float* image, + const float* forw_info, + const float* vert_pos, + const float* vert_col, + const float* vert_rad, + const CamInfo& cam, + const float& gamma, + float percent_allowed_difference, + const uint& max_n_hits, + const float* vert_opy_d, + const size_t& num_balls, + const uint& mode, + const bool& dif_pos, + const bool& dif_col, + const bool& dif_rad, + const bool& dif_cam, + const bool& dif_opy, + cudaStream_t stream) { + ARGCHECK(gamma > 0.f && gamma <= 1.f, 6, "gamma must be in [0., 1.]"); + ARGCHECK( + percent_allowed_difference >= 0.f && percent_allowed_difference <= 1.f, + 7, + "percent_allowed_difference must be in [0., 1.]"); + ARGCHECK(max_n_hits >= 1u, 8, "max_n_hits must be >= 1"); + ARGCHECK( + num_balls > 0 && num_balls <= self->max_num_balls, + 9, + "num_balls must be >0 and less than max num balls!"); + ARGCHECK( + cam.film_width == self->cam.film_width && + cam.film_height == self->cam.film_height, + 5, + "cam film size must agree"); + ARGCHECK(mode <= 1, 10, "mode must be <= 1!"); + if (percent_allowed_difference < EPS) { + LOG(WARNING) << "percent_allowed_difference < " << FEPS << "! Clamping to " + << FEPS << "."; + percent_allowed_difference = FEPS; + } + if (percent_allowed_difference > 1.f - FEPS) { + LOG(WARNING) << "percent_allowed_difference > " << (1.f - FEPS) + << "! Clamping to " << (1.f - FEPS) << "."; + percent_allowed_difference = 1.f - FEPS; + } + LOG_IF(INFO, PULSAR_LOG_RENDER) << "Rendering backward pass..."; + // Update camera. + self->cam.eye = cam.eye; + self->cam.pixel_0_0_center = cam.pixel_0_0_center - cam.eye; + self->cam.pixel_dir_x = cam.pixel_dir_x; + self->cam.pixel_dir_y = cam.pixel_dir_y; + self->cam.sensor_dir_z = cam.sensor_dir_z; + self->cam.half_pixel_size = cam.half_pixel_size; + self->cam.focal_length = cam.focal_length; + self->cam.aperture_width = cam.aperture_width; + self->cam.aperture_height = cam.aperture_height; + self->cam.min_dist = cam.min_dist; + self->cam.max_dist = cam.max_dist; + self->cam.norm_fac = cam.norm_fac; + self->cam.principal_point_offset_x = cam.principal_point_offset_x; + self->cam.principal_point_offset_y = cam.principal_point_offset_y; + self->cam.film_border_left = cam.film_border_left; + self->cam.film_border_top = cam.film_border_top; +#ifdef PULSAR_TIMINGS_ENABLED + START_TIME(calc_signature); +#endif + LAUNCH_MAX_PARALLEL_1D( + calc_signature, + num_balls, + stream, + *self, + reinterpret_cast(vert_pos), + vert_col, + vert_rad, + num_balls); + CHECKLAUNCH(); +#ifdef PULSAR_TIMINGS_ENABLED + STOP_TIME(calc_signature); + START_TIME(calc_gradients); +#endif + MEMSET(self->grad_pos_d, 0, float3, num_balls, stream); + MEMSET(self->grad_col_d, 0, float, num_balls * self->cam.n_channels, stream); + MEMSET(self->grad_rad_d, 0, float, num_balls, stream); + MEMSET(self->grad_cam_d, 0, float, 12, stream); + MEMSET(self->grad_cam_buf_d, 0, CamGradInfo, num_balls, stream); + MEMSET(self->grad_opy_d, 0, float, num_balls, stream); + MEMSET(self->ids_sorted_d, 0, int, num_balls, stream); + LAUNCH_PARALLEL_2D( + calc_gradients, + self->cam.film_width, + self->cam.film_height, + GRAD_BLOCK_SIZE, + GRAD_BLOCK_SIZE, + stream, + self->cam, + grad_im, + gamma, + reinterpret_cast(vert_pos), + vert_col, + vert_rad, + vert_opy_d, + num_balls, + image, + forw_info, + self->di_d, + self->ii_d, + dif_pos, + dif_col, + dif_rad, + dif_cam, + dif_opy, + self->grad_rad_d, + self->grad_col_d, + self->grad_pos_d, + self->grad_cam_buf_d, + self->grad_opy_d, + self->ids_sorted_d, + self->n_track); + CHECKLAUNCH(); +#ifdef PULSAR_TIMINGS_ENABLED + STOP_TIME(calc_gradients); + START_TIME(normalize); +#endif + LAUNCH_MAX_PARALLEL_1D( + norm_sphere_gradients, num_balls, stream, *self, num_balls); + CHECKLAUNCH(); + if (dif_cam) { + SUM_WS( + self->grad_cam_buf_d, + reinterpret_cast(self->grad_cam_d), + static_cast(num_balls), + self->workspace_d, + self->workspace_size, + stream); + CHECKLAUNCH(); + SUM_WS( + (IntWrapper*)(self->ids_sorted_d), + (IntWrapper*)(self->n_grad_contributions_d), + static_cast(num_balls), + self->workspace_d, + self->workspace_size, + stream); + CHECKLAUNCH(); + LAUNCH_MAX_PARALLEL_1D( + norm_cam_gradients, static_cast(1), stream, *self); + CHECKLAUNCH(); + } +#ifdef PULSAR_TIMINGS_ENABLED + STOP_TIME(normalize); + float time_ms; + // This blocks the result and prevents batch-processing from parallelizing. + GET_TIME(calc_signature, &time_ms); + std::cout << "Time for signature calculation: " << time_ms << " ms" + << std::endl; + GET_TIME(calc_gradients, &time_ms); + std::cout << "Time for gradient calculation: " << time_ms << " ms" + << std::endl; + GET_TIME(normalize, &time_ms); + std::cout << "Time for aggregation and normalization: " << time_ms << " ms" + << std::endl; +#endif + LOG_IF(INFO, PULSAR_LOG_RENDER) << "Backward pass complete."; +} + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.backward.instantiate.h b/pytorch3d/csrc/pulsar/include/renderer.backward.instantiate.h new file mode 100644 index 00000000..3ad16599 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.backward.instantiate.h @@ -0,0 +1,30 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "./renderer.backward.device.h" + +namespace pulsar { +namespace Renderer { + +template void backward( + Renderer* self, + const float* grad_im, + const float* image, + const float* forw_info, + const float* vert_pos, + const float* vert_col, + const float* vert_rad, + const CamInfo& cam, + const float& gamma, + float percent_allowed_difference, + const uint& max_n_hits, + const float* vert_opy, + const size_t& num_balls, + const uint& mode, + const bool& dif_pos, + const bool& dif_col, + const bool& dif_rad, + const bool& dif_cam, + const bool& dif_opy, + cudaStream_t stream); + +} // namespace Renderer +} // namespace pulsar diff --git a/pytorch3d/csrc/pulsar/include/renderer.backward_dbg.device.h b/pytorch3d/csrc/pulsar/include/renderer.backward_dbg.device.h new file mode 100644 index 00000000..5e1c0172 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.backward_dbg.device.h @@ -0,0 +1,150 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_RENDERER_BACKWARD_DBG_DEVICE_H_ +#define PULSAR_NATIVE_RENDERER_BACKWARD_DBG_DEVICE_H_ + +#include "./camera.device.h" +#include "./math.h" +#include "./renderer.h" + +namespace pulsar { +namespace Renderer { + +template +void backward_dbg( + Renderer* self, + const float* grad_im, + const float* image, + const float* forw_info, + const float* vert_pos, + const float* vert_col, + const float* vert_rad, + const CamInfo& cam, + const float& gamma, + float percent_allowed_difference, + const uint& max_n_hits, + const float* vert_opy_d, + const size_t& num_balls, + const uint& mode, + const bool& dif_pos, + const bool& dif_col, + const bool& dif_rad, + const bool& dif_cam, + const bool& dif_opy, + const uint& pos_x, + const uint& pos_y, + cudaStream_t stream) { + ARGCHECK(gamma > 0.f && gamma <= 1.f, 6, "gamma must be in [0., 1.]"); + ARGCHECK( + percent_allowed_difference >= 0.f && percent_allowed_difference <= 1.f, + 7, + "percent_allowed_difference must be in [0., 1.]"); + ARGCHECK(max_n_hits >= 1u, 8, "max_n_hits must be >= 1"); + ARGCHECK( + num_balls > 0 && num_balls <= self->max_num_balls, + 9, + "num_balls must be >0 and less than max num balls!"); + ARGCHECK( + cam.film_width == self->cam.film_width && + cam.film_height == self->cam.film_height, + 5, + "cam film size must agree"); + ARGCHECK(mode <= 1, 10, "mode must be <= 1!"); + if (percent_allowed_difference < EPS) { + LOG(WARNING) << "percent_allowed_difference < " << FEPS << "! Clamping to " + << FEPS << "."; + percent_allowed_difference = FEPS; + } + ARGCHECK( + pos_x < cam.film_width && pos_y < cam.film_height, + 15, + "pos_x must be < width and pos_y < height."); + if (percent_allowed_difference > 1.f - FEPS) { + LOG(WARNING) << "percent_allowed_difference > " << (1.f - FEPS) + << "! Clamping to " << (1.f - FEPS) << "."; + percent_allowed_difference = 1.f - FEPS; + } + LOG_IF(INFO, PULSAR_LOG_RENDER) + << "Rendering debug backward pass for x: " << pos_x << ", y: " << pos_y; + // Update camera. + self->cam.eye = cam.eye; + self->cam.pixel_0_0_center = cam.pixel_0_0_center - cam.eye; + self->cam.pixel_dir_x = cam.pixel_dir_x; + self->cam.pixel_dir_y = cam.pixel_dir_y; + self->cam.sensor_dir_z = cam.sensor_dir_z; + self->cam.half_pixel_size = cam.half_pixel_size; + self->cam.focal_length = cam.focal_length; + self->cam.aperture_width = cam.aperture_width; + self->cam.aperture_height = cam.aperture_height; + self->cam.min_dist = cam.min_dist; + self->cam.max_dist = cam.max_dist; + self->cam.norm_fac = cam.norm_fac; + self->cam.principal_point_offset_x = cam.principal_point_offset_x; + self->cam.principal_point_offset_y = cam.principal_point_offset_y; + self->cam.film_border_left = cam.film_border_left; + self->cam.film_border_top = cam.film_border_top; + LAUNCH_MAX_PARALLEL_1D( + calc_signature, + num_balls, + stream, + *self, + reinterpret_cast(vert_pos), + vert_col, + vert_rad, + num_balls); + CHECKLAUNCH(); + MEMSET(self->grad_pos_d, 0, float3, num_balls, stream); + MEMSET(self->grad_col_d, 0, float, num_balls * self->cam.n_channels, stream); + MEMSET(self->grad_rad_d, 0, float, num_balls, stream); + MEMSET(self->grad_cam_d, 0, float, 12, stream); + MEMSET(self->grad_cam_buf_d, 0, CamGradInfo, num_balls, stream); + MEMSET(self->grad_opy_d, 0, float, num_balls, stream); + MEMSET(self->ids_sorted_d, 0, int, num_balls, stream); + LAUNCH_MAX_PARALLEL_2D( + calc_gradients, + (int64_t)1, + (int64_t)1, + stream, + self->cam, + grad_im, + gamma, + reinterpret_cast(vert_pos), + vert_col, + vert_rad, + vert_opy_d, + num_balls, + image, + forw_info, + self->di_d, + self->ii_d, + dif_pos, + dif_col, + dif_rad, + dif_cam, + dif_opy, + self->grad_rad_d, + self->grad_col_d, + self->grad_pos_d, + self->grad_cam_buf_d, + self->grad_opy_d, + self->ids_sorted_d, + self->n_track, + pos_x, + pos_y); + CHECKLAUNCH(); + // We're not doing sphere gradient normalization here. + SUM_WS( + self->grad_cam_buf_d, + reinterpret_cast(self->grad_cam_d), + static_cast(1), + self->workspace_d, + self->workspace_size, + stream); + CHECKLAUNCH(); + // We're not doing camera gradient normalization here. + LOG_IF(INFO, PULSAR_LOG_RENDER) << "Debug backward pass complete."; +} + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.backward_dbg.instantiate.h b/pytorch3d/csrc/pulsar/include/renderer.backward_dbg.instantiate.h new file mode 100644 index 00000000..c15108f9 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.backward_dbg.instantiate.h @@ -0,0 +1,32 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "./renderer.backward_dbg.device.h" + +namespace pulsar { +namespace Renderer { + +template void backward_dbg( + Renderer* self, + const float* grad_im, + const float* image, + const float* forw_info, + const float* vert_pos, + const float* vert_col, + const float* vert_rad, + const CamInfo& cam, + const float& gamma, + float percent_allowed_difference, + const uint& max_n_hits, + const float* vert_opy, + const size_t& num_balls, + const uint& mode, + const bool& dif_pos, + const bool& dif_col, + const bool& dif_rad, + const bool& dif_cam, + const bool& dif_opy, + const uint& pos_x, + const uint& pos_y, + cudaStream_t stream); + +} // namespace Renderer +} // namespace pulsar diff --git a/pytorch3d/csrc/pulsar/include/renderer.calc_gradients.device.h b/pytorch3d/csrc/pulsar/include/renderer.calc_gradients.device.h new file mode 100644 index 00000000..edf6ea8b --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.calc_gradients.device.h @@ -0,0 +1,191 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CALC_GRADIENTS_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_CALC_GRADIENTS_H_ + +#include "../global.h" +#include "./commands.h" +#include "./renderer.h" + +#include "./renderer.draw.device.h" + +namespace pulsar { +namespace Renderer { + +template +GLOBAL void calc_gradients( + const CamInfo cam, /** Camera in world coordinates. */ + float const* const RESTRICT grad_im, /** The gradient image. */ + const float + gamma, /** The transparency parameter used in the forward pass. */ + float3 const* const RESTRICT vert_poss, /** Vertex position vector. */ + float const* const RESTRICT vert_cols, /** Vertex color vector. */ + float const* const RESTRICT vert_rads, /** Vertex radius vector. */ + float const* const RESTRICT opacity, /** Vertex opacity. */ + const uint num_balls, /** Number of balls. */ + float const* const RESTRICT result_d, /** Result image. */ + float const* const RESTRICT forw_info_d, /** Forward pass info. */ + DrawInfo const* const RESTRICT di_d, /** Draw information. */ + IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */ + // Mode switches. + const bool calc_grad_pos, + const bool calc_grad_col, + const bool calc_grad_rad, + const bool calc_grad_cam, + const bool calc_grad_opy, + // Out variables. + float* const RESTRICT grad_rad_d, /** Radius gradients. */ + float* const RESTRICT grad_col_d, /** Color gradients. */ + float3* const RESTRICT grad_pos_d, /** Position gradients. */ + CamGradInfo* const RESTRICT grad_cam_buf_d, /** Camera gradient buffer. */ + float* const RESTRICT grad_opy_d, /** Opacity gradient buffer. */ + int* const RESTRICT + grad_contributed_d, /** Gradient contribution counter. */ + // Infrastructure. + const int n_track, + const uint offs_x, + const uint offs_y /** Debug offsets. */ +) { + uint limit_x = cam.film_width, limit_y = cam.film_height; + if (offs_x != 0) { + // We're in debug mode. + limit_x = 1; + limit_y = 1; + } + GET_PARALLEL_IDS_2D(coord_x_base, coord_y_base, limit_x, limit_y); + // coord_x_base and coord_y_base are in the film coordinate system. + // We now need to translate to the aperture coordinate system. If + // the principal point was shifted left/up nothing has to be + // subtracted - only shift needs to be added in case it has been + // shifted down/right. + const uint film_coord_x = coord_x_base + offs_x; + const uint ap_coord_x = film_coord_x + + 2 * static_cast(std::max(0, cam.principal_point_offset_x)); + const uint film_coord_y = coord_y_base + offs_y; + const uint ap_coord_y = film_coord_y + + 2 * static_cast(std::max(0, cam.principal_point_offset_y)); + const float3 ray_dir = /** Ray cast through the pixel, normalized. */ + cam.pixel_0_0_center + ap_coord_x * cam.pixel_dir_x + + ap_coord_y * cam.pixel_dir_y; + const float norm_ray_dir = length(ray_dir); + // ray_dir_norm *must* be calculated here in the same way as in the draw + // function to have the same values withno other numerical instabilities + // (for example, ray_dir * FRCP(norm_ray_dir) does not work)! + float3 ray_dir_norm; /** Ray cast through the pixel, normalized. */ + float2 projected_ray; /** Ray intersection with the sensor. */ + if (cam.orthogonal_projection) { + ray_dir_norm = cam.sensor_dir_z; + projected_ray.x = static_cast(ap_coord_x); + projected_ray.y = static_cast(ap_coord_y); + } else { + ray_dir_norm = normalize( + cam.pixel_0_0_center + ap_coord_x * cam.pixel_dir_x + + ap_coord_y * cam.pixel_dir_y); + // This is a reasonable assumption for normal focal lengths and image sizes. + PASSERT(FABS(ray_dir_norm.z) > FEPS); + projected_ray.x = ray_dir_norm.x / ray_dir_norm.z * cam.focal_length; + projected_ray.y = ray_dir_norm.y / ray_dir_norm.z * cam.focal_length; + } + float* result = const_cast( + result_d + film_coord_y * cam.film_width * cam.n_channels + + film_coord_x * cam.n_channels); + const float* grad_im_l = grad_im + + film_coord_y * cam.film_width * cam.n_channels + + film_coord_x * cam.n_channels; + // For writing... + float3 grad_pos; + float grad_rad, grad_opy; + CamGradInfo grad_cam_local = CamGradInfo(); + // Set up shared infrastructure. + const int fwi_loc = film_coord_y * cam.film_width * (3 + 2 * n_track) + + film_coord_x * (3 + 2 * n_track); + float sm_m = forw_info_d[fwi_loc]; + float sm_d = forw_info_d[fwi_loc + 1]; + PULSAR_LOG_DEV_APIX( + PULSAR_LOG_GRAD, + "grad|sm_m: %f, sm_d: %f, result: " + "%f, %f, %f; grad_im: %f, %f, %f.\n", + sm_m, + sm_d, + result[0], + result[1], + result[2], + grad_im_l[0], + grad_im_l[1], + grad_im_l[2]); + // Start processing. + for (int grad_idx = 0; grad_idx < n_track; ++grad_idx) { + int sphere_idx; + FASI(forw_info_d[fwi_loc + 3 + 2 * grad_idx], sphere_idx); + PASSERT( + sphere_idx == -1 || + sphere_idx >= 0 && static_cast(sphere_idx) < num_balls); + if (sphere_idx >= 0) { + // TODO: make more efficient. + grad_pos = make_float3(0.f, 0.f, 0.f); + grad_rad = 0.f; + grad_cam_local = CamGradInfo(); + const DrawInfo di = di_d[sphere_idx]; + grad_opy = 0.f; + draw( + di, + opacity == NULL ? 1.f : opacity[sphere_idx], + cam, + gamma, + ray_dir_norm, + projected_ray, + // Mode switches. + false, // draw only + calc_grad_pos, + calc_grad_col, + calc_grad_rad, + calc_grad_cam, + calc_grad_opy, + // Position info. + ap_coord_x, + ap_coord_y, + sphere_idx, + // Optional in. + &ii_d[sphere_idx], + &ray_dir, + &norm_ray_dir, + grad_im_l, + NULL, + // In/out + &sm_d, + &sm_m, + result, + // Optional out. + NULL, + NULL, + &grad_pos, + grad_col_d + sphere_idx * cam.n_channels, + &grad_rad, + &grad_cam_local, + &grad_opy); + ATOMICADD(&(grad_rad_d[sphere_idx]), grad_rad); + // Color has been added directly. + ATOMICADD_F3(&(grad_pos_d[sphere_idx]), grad_pos); + ATOMICADD_F3( + &(grad_cam_buf_d[sphere_idx].cam_pos), grad_cam_local.cam_pos); + if (!cam.orthogonal_projection) { + ATOMICADD_F3( + &(grad_cam_buf_d[sphere_idx].pixel_0_0_center), + grad_cam_local.pixel_0_0_center); + } + ATOMICADD_F3( + &(grad_cam_buf_d[sphere_idx].pixel_dir_x), + grad_cam_local.pixel_dir_x); + ATOMICADD_F3( + &(grad_cam_buf_d[sphere_idx].pixel_dir_y), + grad_cam_local.pixel_dir_y); + ATOMICADD(&(grad_opy_d[sphere_idx]), grad_opy); + ATOMICADD(&(grad_contributed_d[sphere_idx]), 1); + } + } + END_PARALLEL_2D_NORET(); +}; + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.calc_gradients.instantiate.h b/pytorch3d/csrc/pulsar/include/renderer.calc_gradients.instantiate.h new file mode 100644 index 00000000..14c70386 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.calc_gradients.instantiate.h @@ -0,0 +1,41 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "./renderer.calc_gradients.device.h" + +namespace pulsar { +namespace Renderer { + +template GLOBAL void calc_gradients( + const CamInfo cam, /** Camera in world coordinates. */ + float const* const RESTRICT grad_im, /** The gradient image. */ + const float + gamma, /** The transparency parameter used in the forward pass. */ + float3 const* const RESTRICT vert_poss, /** Vertex position vector. */ + float const* const RESTRICT vert_cols, /** Vertex color vector. */ + float const* const RESTRICT vert_rads, /** Vertex radius vector. */ + float const* const RESTRICT opacity, /** Vertex opacity. */ + const uint num_balls, /** Number of balls. */ + float const* const RESTRICT result_d, /** Result image. */ + float const* const RESTRICT forw_info_d, /** Forward pass info. */ + DrawInfo const* const RESTRICT di_d, /** Draw information. */ + IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */ + // Mode switches. + const bool calc_grad_pos, + const bool calc_grad_col, + const bool calc_grad_rad, + const bool calc_grad_cam, + const bool calc_grad_opy, + // Out variables. + float* const RESTRICT grad_rad_d, /** Radius gradients. */ + float* const RESTRICT grad_col_d, /** Color gradients. */ + float3* const RESTRICT grad_pos_d, /** Position gradients. */ + CamGradInfo* const RESTRICT grad_cam_buf_d, /** Camera gradient buffer. */ + float* const RESTRICT grad_opy_d, /** Opacity gradient buffer. */ + int* const RESTRICT + grad_contributed_d, /** Gradient contribution counter. */ + // Infrastructure. + const int n_track, + const uint offs_x, + const uint offs_y); + +} // namespace Renderer +} // namespace pulsar diff --git a/pytorch3d/csrc/pulsar/include/renderer.calc_signature.device.h b/pytorch3d/csrc/pulsar/include/renderer.calc_signature.device.h new file mode 100644 index 00000000..84b3e0aa --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.calc_signature.device.h @@ -0,0 +1,194 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CALC_SIGNATURE_DEVICE_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_CALC_SIGNATURE_DEVICE_H_ + +#include "../global.h" +#include "./camera.device.h" +#include "./commands.h" +#include "./math.h" +#include "./renderer.get_screen_area.device.h" +#include "./renderer.h" + +namespace pulsar { +namespace Renderer { + +template +GLOBAL void calc_signature( + Renderer renderer, + float3 const* const RESTRICT vert_poss, + float const* const RESTRICT vert_cols, + float const* const RESTRICT vert_rads, + const uint num_balls) { + /* We're not using RESTRICT here for the pointers within `renderer`. Just one + value is being read from each of the pointers, so the effect would be + negligible or non-existent. */ + GET_PARALLEL_IDX_1D(idx, num_balls); + // Create aliases. + // For reading... + const float3& vert_pos = vert_poss[idx]; /** Vertex position. */ + const float* vert_col = + vert_cols + idx * renderer.cam.n_channels; /** Vertex color. */ + const float& vert_rad = vert_rads[idx]; /** Vertex radius. */ + const CamInfo& cam = renderer.cam; /** Camera in world coordinates. */ + // For writing... + /** Ball ID (either original index of the ball or -1 if not visible). */ + int& id_out = renderer.ids_d[idx]; + /** Intersection helper structure for the ball. */ + IntersectInfo& intersect_helper_out = renderer.ii_d[idx]; + /** Draw helper structure for this ball. */ + DrawInfo& draw_helper_out = renderer.di_d[idx]; + /** Minimum possible intersection depth for this ball. */ + float& closest_possible_intersect_out = renderer.min_depth_d[idx]; + PULSAR_LOG_DEV( + PULSAR_LOG_CALC_SIGNATURE, + "signature %d|vert_pos: %.9f, %.9f, %.9f, vert_col (first three): " + "%.9f, %.9f, %.9f.\n", + idx, + vert_pos.x, + vert_pos.y, + vert_pos.z, + vert_col[0], + vert_col[1], + vert_col[2]); + // Set flags to invalid for a potential early return. + id_out = -1; // Invalid ID. + closest_possible_intersect_out = + MAX_FLOAT; // These spheres are sorted to the very end. + intersect_helper_out.max.x = MAX_USHORT; // No intersection possible. + intersect_helper_out.min.x = MAX_USHORT; + intersect_helper_out.max.y = MAX_USHORT; + intersect_helper_out.min.y = MAX_USHORT; + // Start processing. + /** Ball center in the camera coordinate system. */ + const float3 ball_center_cam = vert_pos - cam.eye; + /** Distance to the ball center in the camera coordinate system. */ + const float t_center = length(ball_center_cam); + /** Closest possible intersection with this ball from the camera. */ + float closest_possible_intersect; + if (cam.orthogonal_projection) { + const float3 ball_center_cam_rot = rotate( + ball_center_cam, + cam.pixel_dir_x / length(cam.pixel_dir_x), + cam.pixel_dir_y / length(cam.pixel_dir_y), + cam.sensor_dir_z); + closest_possible_intersect = ball_center_cam_rot.z - vert_rad; + } else { + closest_possible_intersect = t_center - vert_rad; + } + PULSAR_LOG_DEV( + PULSAR_LOG_CALC_SIGNATURE, + "signature %d|t_center: %f. vert_rad: %f. " + "closest_possible_intersect: %f.\n", + idx, + t_center, + vert_rad, + closest_possible_intersect); + /** + * Corner points of the enclosing projected rectangle of the ball. + * They are first calculated in the camera coordinate system, then + * converted to the pixel coordinate system. + */ + float x_1, x_2, y_1, y_2; + bool hits_screen_plane; + float3 ray_center_norm = ball_center_cam / t_center; + PASSERT(vert_rad >= 0.f); + if (closest_possible_intersect < cam.min_dist || + closest_possible_intersect > cam.max_dist) { + PULSAR_LOG_DEV( + PULSAR_LOG_CALC_SIGNATURE, + "signature %d|ignoring sphere out of min/max bounds: %.9f, " + "min: %.9f, max: %.9f.\n", + idx, + closest_possible_intersect, + cam.min_dist, + cam.max_dist); + RETURN_PARALLEL(); + } + // Find the relevant region on the screen plane. + hits_screen_plane = get_screen_area( + ball_center_cam, + ray_center_norm, + vert_rad, + cam, + idx, + &x_1, + &x_2, + &y_1, + &y_2); + if (!hits_screen_plane) + RETURN_PARALLEL(); + PULSAR_LOG_DEV( + PULSAR_LOG_CALC_SIGNATURE, + "signature %d|in pixels: x_1: %f, x_2: %f, y_1: %f, y_2: %f.\n", + idx, + x_1, + x_2, + y_1, + y_2); + // Check whether the pixel coordinates are on screen. + if (FMAX(x_1, x_2) <= static_cast(cam.film_border_left) || + FMIN(x_1, x_2) >= + static_cast(cam.film_border_left + cam.film_width) - 0.5f || + FMAX(y_1, y_2) <= static_cast(cam.film_border_top) || + FMIN(y_1, y_2) > + static_cast(cam.film_border_top + cam.film_height) - 0.5f) + RETURN_PARALLEL(); + // Write results. + id_out = idx; + intersect_helper_out.min.x = static_cast( + FMAX(FMIN(x_1, x_2), static_cast(cam.film_border_left))); + intersect_helper_out.min.y = static_cast( + FMAX(FMIN(y_1, y_2), static_cast(cam.film_border_top))); + // In the following calculations, the max that needs to be stored is + // exclusive. + // That means that the calculated value needs to be `ceil`ed and incremented + // to find the correct value. + intersect_helper_out.max.x = static_cast(FMIN( + FCEIL(FMAX(x_1, x_2)) + 1, + static_cast(cam.film_border_left + cam.film_width))); + intersect_helper_out.max.y = static_cast(FMIN( + FCEIL(FMAX(y_1, y_2)) + 1, + static_cast(cam.film_border_top + cam.film_height))); + PULSAR_LOG_DEV( + PULSAR_LOG_CALC_SIGNATURE, + "signature %d|limits after refining: x_1: %u, x_2: %u, " + "y_1: %u, y_2: %u.\n", + idx, + intersect_helper_out.min.x, + intersect_helper_out.max.x, + intersect_helper_out.min.y, + intersect_helper_out.max.y); + if (intersect_helper_out.min.x == MAX_USHORT) { + id_out = -1; + RETURN_PARALLEL(); + } + PULSAR_LOG_DEV( + PULSAR_LOG_CALC_SIGNATURE, + "signature %d|writing info. closest_possible_intersect: %.9f. " + "ray_center_norm: %.9f, %.9f, %.9f. t_center: %.9f. radius: %.9f.\n", + idx, + closest_possible_intersect, + ray_center_norm.x, + ray_center_norm.y, + ray_center_norm.z, + t_center, + vert_rad); + closest_possible_intersect_out = closest_possible_intersect; + draw_helper_out.ray_center_norm = ray_center_norm; + draw_helper_out.t_center = t_center; + draw_helper_out.radius = vert_rad; + if (cam.n_channels <= 3) { + draw_helper_out.first_color = vert_col[0]; + for (uint c_id = 1; c_id < cam.n_channels; ++c_id) { + draw_helper_out.color_union.color[c_id - 1] = vert_col[c_id]; + } + } else { + draw_helper_out.color_union.ptr = const_cast(vert_col); + } + END_PARALLEL(); +}; + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.calc_signature.instantiate.h b/pytorch3d/csrc/pulsar/include/renderer.calc_signature.instantiate.h new file mode 100644 index 00000000..b87bcff7 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.calc_signature.instantiate.h @@ -0,0 +1,18 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CALC_SIGNATURE_INSTANTIATE_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_CALC_SIGNATURE_INSTANTIATE_H_ + +#include "./renderer.calc_signature.device.h" + +namespace pulsar { +namespace Renderer { +template GLOBAL void calc_signature( + Renderer renderer, + float3 const* const RESTRICT vert_poss, + float const* const RESTRICT vert_cols, + float const* const RESTRICT vert_rads, + const uint num_balls); +} +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.construct.device.h b/pytorch3d/csrc/pulsar/include/renderer.construct.device.h new file mode 100644 index 00000000..55bde54d --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.construct.device.h @@ -0,0 +1,104 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CONSTRUCT_DEVICE_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_CONSTRUCT_DEVICE_H_ + +#include "../global.h" +#include "./camera.device.h" +#include "./commands.h" +#include "./math.h" +#include "./renderer.h" + +namespace pulsar { +namespace Renderer { + +template +HOST void construct( + Renderer* self, + const size_t& max_num_balls, + const int& width, + const int& height, + const bool& orthogonal_projection, + const bool& right_handed_system, + const float& background_normalization_depth, + const uint& n_channels, + const uint& n_track) { + ARGCHECK( + (max_num_balls > 0 && max_num_balls < MAX_INT), + 2, + ("the maximum number of balls must be >0 and <" + + std::to_string(MAX_INT) + ". Is " + std::to_string(max_num_balls) + ".") + .c_str()); + ARGCHECK(width > 1, 3, "the image width must be > 1"); + ARGCHECK(height > 1, 4, "the image height must be > 1"); + ARGCHECK( + background_normalization_depth > 0.f && + background_normalization_depth < 1.f, + 6, + "background_normalization_depth must be in ]0., 1.[."); + ARGCHECK(n_channels > 0, 7, "n_channels must be >0!"); + ARGCHECK( + n_track > 0 && n_track <= MAX_GRAD_SPHERES, + 8, + ("n_track must be >0 and <" + std::to_string(MAX_GRAD_SPHERES) + ". Is " + + std::to_string(n_track) + ".") + .c_str()); + self->cam.film_width = width; + self->cam.film_height = height; + self->max_num_balls = max_num_balls; + MALLOC(self->result_d, float, width* height* n_channels); + self->cam.orthogonal_projection = orthogonal_projection; + self->cam.right_handed = right_handed_system; + self->cam.background_normalization_depth = background_normalization_depth; + self->cam.n_channels = n_channels; + MALLOC(self->min_depth_d, float, max_num_balls); + MALLOC(self->min_depth_sorted_d, float, max_num_balls); + MALLOC(self->ii_d, IntersectInfo, max_num_balls); + MALLOC(self->ii_sorted_d, IntersectInfo, max_num_balls); + MALLOC(self->ids_d, int, max_num_balls); + MALLOC(self->ids_sorted_d, int, max_num_balls); + size_t sort_id_size = 0; + GET_SORT_WS_SIZE(&sort_id_size, float, int, max_num_balls); + CHECKLAUNCH(); + size_t sort_ii_size = 0; + GET_SORT_WS_SIZE(&sort_ii_size, float, IntersectInfo, max_num_balls); + CHECKLAUNCH(); + size_t sort_di_size = 0; + GET_SORT_WS_SIZE(&sort_di_size, float, DrawInfo, max_num_balls); + CHECKLAUNCH(); + size_t select_ii_size = 0; + GET_SELECT_WS_SIZE(&select_ii_size, char, IntersectInfo, max_num_balls); + size_t select_di_size = 0; + GET_SELECT_WS_SIZE(&select_di_size, char, DrawInfo, max_num_balls); + size_t sum_size = 0; + GET_SUM_WS_SIZE(&sum_size, CamGradInfo, max_num_balls); + size_t sum_cont_size = 0; + GET_SUM_WS_SIZE(&sum_cont_size, int, max_num_balls); + size_t reduce_size = 0; + GET_REDUCE_WS_SIZE( + &reduce_size, IntersectInfo, IntersectInfoMinMax(), max_num_balls); + self->workspace_size = IMAX( + IMAX(IMAX(sort_id_size, sort_ii_size), sort_di_size), + IMAX( + IMAX(select_di_size, select_ii_size), + IMAX(IMAX(sum_size, sum_cont_size), reduce_size))); + MALLOC(self->workspace_d, char, self->workspace_size); + MALLOC(self->di_d, DrawInfo, max_num_balls); + MALLOC(self->di_sorted_d, DrawInfo, max_num_balls); + MALLOC(self->region_flags_d, char, max_num_balls); + MALLOC(self->num_selected_d, size_t, 1); + MALLOC(self->forw_info_d, float, width* height*(3 + 2 * n_track)); + MALLOC(self->min_max_pixels_d, IntersectInfo, 1); + MALLOC(self->grad_pos_d, float3, max_num_balls); + MALLOC(self->grad_col_d, float, max_num_balls* n_channels); + MALLOC(self->grad_rad_d, float, max_num_balls); + MALLOC(self->grad_cam_d, float, 12); + MALLOC(self->grad_cam_buf_d, CamGradInfo, max_num_balls); + MALLOC(self->grad_opy_d, float, max_num_balls); + MALLOC(self->n_grad_contributions_d, int, 1); + self->n_track = static_cast(n_track); +} + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.construct.instantiate.h b/pytorch3d/csrc/pulsar/include/renderer.construct.instantiate.h new file mode 100644 index 00000000..09964f2b --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.construct.instantiate.h @@ -0,0 +1,22 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CONSTRUCT_INSTANTIATE_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_CONSTRUCT_INSTANTIATE_H_ + +#include "./renderer.construct.device.h" + +namespace pulsar { +namespace Renderer { +template void construct( + Renderer* self, + const size_t& max_num_balls, + const int& width, + const int& height, + const bool& orthogonal_projection, + const bool& right_handed_system, + const float& background_normalization_depth, + const uint& n_channels, + const uint& n_track); +} +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.create_selector.device.h b/pytorch3d/csrc/pulsar/include/renderer.create_selector.device.h new file mode 100644 index 00000000..42ef5c25 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.create_selector.device.h @@ -0,0 +1,34 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CREATE_SELECTOR_DEVICE_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_CREATE_SELECTOR_DEVICE_H_ + +#include "../global.h" +#include "./commands.h" +#include "./renderer.h" + +namespace pulsar { +namespace Renderer { + +template +GLOBAL void create_selector( + IntersectInfo const* const RESTRICT ii_sorted_d, + const uint num_balls, + const int min_x, + const int max_x, + const int min_y, + const int max_y, + /* Out variables. */ + char* RESTRICT region_flags_d) { + GET_PARALLEL_IDX_1D(idx, num_balls); + bool hit = (static_cast(ii_sorted_d[idx].min.x) <= max_x) && + (static_cast(ii_sorted_d[idx].max.x) > min_x) && + (static_cast(ii_sorted_d[idx].min.y) <= max_y) && + (static_cast(ii_sorted_d[idx].max.y) > min_y); + region_flags_d[idx] = hit; + END_PARALLEL_NORET(); +} + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.create_selector.instantiate.h b/pytorch3d/csrc/pulsar/include/renderer.create_selector.instantiate.h new file mode 100644 index 00000000..bafd9fac --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.create_selector.instantiate.h @@ -0,0 +1,23 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CREATE_SELECTOR_INSTANTIATE_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_CREATE_SELECTOR_INSTANTIATE_H_ + +#include "./renderer.create_selector.device.h" + +namespace pulsar { +namespace Renderer { + +template GLOBAL void create_selector( + IntersectInfo const* const RESTRICT ii_sorted_d, + const uint num_balls, + const int min_x, + const int max_x, + const int min_y, + const int max_y, + /* Out variables. */ + char* RESTRICT region_flags_d); + +} +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.destruct.device.h b/pytorch3d/csrc/pulsar/include/renderer.destruct.device.h new file mode 100644 index 00000000..a3a1044e --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.destruct.device.h @@ -0,0 +1,82 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_DESTRUCT_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_DESTRUCT_H_ + +#include "../global.h" +#include "./commands.h" +#include "./renderer.h" + +namespace pulsar { +namespace Renderer { + +template +HOST void destruct(Renderer* self) { + if (self->result_d != NULL) + FREE(self->result_d); + self->result_d = NULL; + if (self->min_depth_d != NULL) + FREE(self->min_depth_d); + self->min_depth_d = NULL; + if (self->min_depth_sorted_d != NULL) + FREE(self->min_depth_sorted_d); + self->min_depth_sorted_d = NULL; + if (self->ii_d != NULL) + FREE(self->ii_d); + self->ii_d = NULL; + if (self->ii_sorted_d != NULL) + FREE(self->ii_sorted_d); + self->ii_sorted_d = NULL; + if (self->ids_d != NULL) + FREE(self->ids_d); + self->ids_d = NULL; + if (self->ids_sorted_d != NULL) + FREE(self->ids_sorted_d); + self->ids_sorted_d = NULL; + if (self->workspace_d != NULL) + FREE(self->workspace_d); + self->workspace_d = NULL; + if (self->di_d != NULL) + FREE(self->di_d); + self->di_d = NULL; + if (self->di_sorted_d != NULL) + FREE(self->di_sorted_d); + self->di_sorted_d = NULL; + if (self->region_flags_d != NULL) + FREE(self->region_flags_d); + self->region_flags_d = NULL; + if (self->num_selected_d != NULL) + FREE(self->num_selected_d); + self->num_selected_d = NULL; + if (self->forw_info_d != NULL) + FREE(self->forw_info_d); + self->forw_info_d = NULL; + if (self->min_max_pixels_d != NULL) + FREE(self->min_max_pixels_d); + self->min_max_pixels_d = NULL; + if (self->grad_pos_d != NULL) + FREE(self->grad_pos_d); + self->grad_pos_d = NULL; + if (self->grad_col_d != NULL) + FREE(self->grad_col_d); + self->grad_col_d = NULL; + if (self->grad_rad_d != NULL) + FREE(self->grad_rad_d); + self->grad_rad_d = NULL; + if (self->grad_cam_d != NULL) + FREE(self->grad_cam_d); + self->grad_cam_d = NULL; + if (self->grad_cam_buf_d != NULL) + FREE(self->grad_cam_buf_d); + self->grad_cam_buf_d = NULL; + if (self->grad_opy_d != NULL) + FREE(self->grad_opy_d); + self->grad_opy_d = NULL; + if (self->n_grad_contributions_d != NULL) + FREE(self->n_grad_contributions_d); + self->n_grad_contributions_d = NULL; +} + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.destruct.instantiate.h b/pytorch3d/csrc/pulsar/include/renderer.destruct.instantiate.h new file mode 100644 index 00000000..ce3d10f2 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.destruct.instantiate.h @@ -0,0 +1,13 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_DESTRUCT_INSTANTIATE_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_DESTRUCT_INSTANTIATE_H_ + +#include "./renderer.destruct.device.h" + +namespace pulsar { +namespace Renderer { +template void destruct(Renderer* self); +} +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.draw.device.h b/pytorch3d/csrc/pulsar/include/renderer.draw.device.h new file mode 100644 index 00000000..379319a7 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.draw.device.h @@ -0,0 +1,839 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CALC_SIGNATURE_DEVICE_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_CALC_SIGNATURE_DEVICE_H_ + +#include "../global.h" +#include "./camera.device.h" +#include "./commands.h" +#include "./math.h" +#include "./renderer.h" + +namespace pulsar { +namespace Renderer { + +/** + * Draw a ball into the `result`. + * + * Returns whether a hit was noticed. See README for an explanation of sphere + * points and variable notation. + */ +INLINE DEVICE bool draw( + /* In variables. */ + const DrawInfo& draw_info, /** The draw information for this ball. */ + const float& opacity, /** The sphere opacity. */ + const CamInfo& + cam, /** Camera information. Doesn't have to be normalized. */ + const float& gamma, /** 'Transparency' indicator (see paper for details). */ + const float3& ray_dir_norm, /** The direction of the ray, normalized. */ + const float2& projected_ray, /** The intersection of the ray with the image + in pixel space. */ + /** Mode switches. */ + const bool& draw_only, /** Whether we are in draw vs. grad mode. */ + const bool& calc_grad_pos, /** Calculate position gradients. */ + const bool& calc_grad_col, /** Calculate color gradients. */ + const bool& calc_grad_rad, /** Calculate radius gradients. */ + const bool& calc_grad_cam, /** Calculate camera gradients. */ + const bool& calc_grad_opy, /** Calculate opacity gradients. */ + /** Position info. */ + const uint& coord_x, /** The pixel position x to draw at. */ + const uint& coord_y, /** The pixel position y to draw at. */ + const uint& idx, /** The id of the sphere to process. */ + /* Optional in variables. */ + IntersectInfo const* const RESTRICT + intersect_info, /** The intersect information for this ball. */ + float3 const* const RESTRICT ray_dir, /** The ray direction (not normalized) + to draw at. Only used for grad computation. */ + float const* const RESTRICT norm_ray_dir, /** The length of the direction + vector. Only used for grad computation. */ + float const* const RESTRICT grad_pix, /** The gradient for this pixel. Only + used for grad computation. */ + float const* const RESTRICT + ln_pad_over_1minuspad, /** Allowed percentage indicator. */ + /* In or out variables, depending on mode. */ + float* const RESTRICT sm_d, /** Normalization denominator. */ + float* const RESTRICT + sm_m, /** Maximum of normalization weight factors observed. */ + float* const RESTRICT + result, /** Result pixel color. Must be zeros initially. */ + /* Optional out variables. */ + float* const RESTRICT depth_threshold, /** The depth threshold to use. Only + used for rendering. */ + float* const RESTRICT intersection_depth_norm_out, /** The intersection + depth. Only set when rendering. */ + float3* const RESTRICT grad_pos, /** Gradient w.r.t. position. */ + float* const RESTRICT grad_col, /** Gradient w.r.t. color. */ + float* const RESTRICT grad_rad, /** Gradient w.r.t. radius. */ + CamGradInfo* const RESTRICT grad_cam, /** Gradient w.r.t. camera. */ + float* const RESTRICT grad_opy /** Gradient w.r.t. opacity. */ +) { + // TODO: variable reuse? + PASSERT( + isfinite(draw_info.ray_center_norm.x) && + isfinite(draw_info.ray_center_norm.y) && + isfinite(draw_info.ray_center_norm.z)); + PASSERT(isfinite(draw_info.t_center) && draw_info.t_center >= 0.f); + PASSERT( + isfinite(draw_info.radius) && draw_info.radius >= 0.f && + draw_info.radius <= draw_info.t_center); + PASSERT(isfinite(ray_dir_norm.x)); + PASSERT(isfinite(ray_dir_norm.y)); + PASSERT(isfinite(ray_dir_norm.z)); + PASSERT(isfinite(*sm_d)); + PASSERT( + cam.orthogonal_projection && cam.focal_length == 0.f || + cam.focal_length > 0.f); + PASSERT(gamma <= 1.f && gamma >= 1e-5f); + /** The ball center in the camera coordinate system. */ + float3 center = draw_info.ray_center_norm * draw_info.t_center; + /** The vector from the reference point to the ball center. */ + float3 raydiff; + if (cam.orthogonal_projection) { + center = rotate( + center, + cam.pixel_dir_x / length(cam.pixel_dir_x), + cam.pixel_dir_y / length(cam.pixel_dir_y), + cam.sensor_dir_z); + raydiff = + make_float3( // TODO: make offset consistent with `get_screen_area`. + center.x - + (projected_ray.x - + static_cast(cam.aperture_width) * .5f) * + (2.f * cam.half_pixel_size), + center.y - + (projected_ray.y - + static_cast(cam.aperture_height) * .5f) * + (2.f * cam.half_pixel_size), + 0.f); + } else { + /** The reference point on the ray; the point in the same distance + * from the camera as the ball center, but along the ray. + */ + const float3 rayref = ray_dir_norm * draw_info.t_center; + raydiff = center - rayref; + } + /** The closeness of the reference point to ball center in world coords. + * + * In [0., radius]. + */ + const float closeness_world = length(raydiff); + /** The reciprocal radius. */ + const float radius_rcp = FRCP(draw_info.radius); + /** The closeness factor normalized with the ball radius. + * + * In [0., 1.]. + */ + float closeness = FSATURATE(FMA(-closeness_world, radius_rcp, 1.f)); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_DRAW_PIX, + "drawprep %u|center: %.9f, %.9f, %.9f. raydiff: %.9f, " + "%.9f, %.9f. closeness_world: %.9f. closeness: %.9f\n", + idx, + center.x, + center.y, + center.z, + raydiff.x, + raydiff.y, + raydiff.z, + closeness_world, + closeness); + /** Whether this is the 'center pixel' for this ball, the pixel that + * is closest to its projected center. This information is used to + * make sure to draw 'tiny' spheres with less than one pixel in + * projected size. + */ + bool ray_through_center_pixel; + float projected_radius, projected_x, projected_y; + if (cam.orthogonal_projection) { + projected_x = center.x / (2.f * cam.half_pixel_size) + + (static_cast(cam.aperture_width) - 1.f) / 2.f; + projected_y = center.y / (2.f * cam.half_pixel_size) + + (static_cast(cam.aperture_height) - 1.f) / 2.f; + projected_radius = draw_info.radius / (2.f * cam.half_pixel_size); + ray_through_center_pixel = + (FABS(FSUB(projected_x, projected_ray.x)) < 0.5f + FEPS && + FABS(FSUB(projected_y, projected_ray.y)) < 0.5f + FEPS); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_DRAW_PIX, + "drawprep %u|closeness_world: %.9f. closeness: %.9f. " + "projected (x, y): %.9f, %.9f. projected_ray (x, y): " + "%.9f, %.9f. ray_through_center_pixel: %d.\n", + idx, + closeness_world, + closeness, + projected_x, + projected_y, + projected_ray.x, + projected_ray.y, + ray_through_center_pixel); + } else { + // Misusing this variable for half pixel size projected to the depth + // at which the sphere resides. Leave some slack for numerical + // inaccuracy (factor 1.5). + projected_x = FMUL(cam.half_pixel_size * 1.5, draw_info.t_center) * + FRCP(cam.focal_length); + projected_radius = FMUL(draw_info.radius, cam.focal_length) * + FRCP(draw_info.t_center) / (2.f * cam.half_pixel_size); + ray_through_center_pixel = projected_x > closeness_world; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_DRAW_PIX, + "drawprep %u|closeness_world: %.9f. closeness: %.9f. " + "projected half pixel size: %.9f. " + "ray_through_center_pixel: %d.\n", + idx, + closeness_world, + closeness, + projected_x, + ray_through_center_pixel); + } + if (draw_only && draw_info.radius < closeness_world && + !ray_through_center_pixel) { + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_DRAW_PIX, + "drawprep %u|Abandoning since no hit has been detected.\n", + idx); + return false; + } else { + // This is always a hit since we are following the forward execution pass. + // p2 is the closest intersection point with the sphere. + } + if (ray_through_center_pixel && projected_radius < 1.f) { + // Make a tiny sphere visible. + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_DRAW_PIX, + "drawprep %u|Setting closeness to 1 (projected radius: %.9f).\n", + idx, + projected_radius); + closeness = 1.; + } + PASSERT(closeness >= 0.f && closeness <= 1.f); + /** Distance between the camera (`o`) and `p1`, the closest point to the + * ball center along the casted ray. + * + * In [t_center - radius, t_center]. + */ + float o__p1_; + /** The distance from ball center to p1. + * + * In [0., sqrt(t_center ^ 2 - (t_center - radius) ^ 2)]. + */ + float c__p1_; + if (cam.orthogonal_projection) { + o__p1_ = FABS(center.z); + c__p1_ = length(raydiff); + } else { + o__p1_ = dot(center, ray_dir_norm); + /** + * This is being calculated as sqrt(t_center^2 - o__p1_^2) = + * sqrt((t_center + o__p1_) * (t_center - o__p1_)) to avoid + * catastrophic cancellation in floating point representations. + */ + c__p1_ = FSQRT( + (draw_info.t_center + o__p1_) * FMAX(draw_info.t_center - o__p1_, 0.f)); + // PASSERT(o__p1_ >= draw_info.t_center - draw_info.radius); + // Numerical errors lead to too large values. + o__p1_ = FMIN(o__p1_, draw_info.t_center); + // PASSERT(o__p1_ <= draw_info.t_center); + } + /** The distance from the closest point to the sphere center (p1) + * to the closest intersection point (p2). + * + * In [0., radius]. + */ + const float p1__p2_ = + FSQRT((draw_info.radius + c__p1_) * FMAX(draw_info.radius - c__p1_, 0.f)); + PASSERT(p1__p2_ >= 0.f && p1__p2_ <= draw_info.radius); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_DRAW_PIX, + "drawprep %u|o__p1_: %.9f, c__p1_: %.9f, p1__p2_: %.9f.\n", + idx, + o__p1_, + c__p1_, + p1__p2_); + /** The intersection depth of the ray with this ball. + * + * In [t_center - radius, t_center]. + */ + const float intersection_depth = (o__p1_ - p1__p2_); + PASSERT( + cam.orthogonal_projection && + (intersection_depth >= center.z - draw_info.radius && + intersection_depth <= center.z) || + intersection_depth >= draw_info.t_center - draw_info.radius && + intersection_depth <= draw_info.t_center); + /** Normalized distance of the closest intersection point; in [0., 1.]. */ + const float norm_dist = + FMUL(FSUB(intersection_depth, cam.min_dist), cam.norm_fac); + PASSERT(norm_dist >= 0.f && norm_dist <= 1.f); + /** Scaled, normalized distance in [1., 0.] (closest, farthest). */ + const float norm_dist_scaled = FSUB(1.f, norm_dist) / gamma * opacity; + PASSERT(norm_dist_scaled >= 0.f && norm_dist_scaled <= 1.f / gamma); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_DRAW_PIX, + "drawprep %u|intersection_depth: %.9f, norm_dist: %.9f, " + "norm_dist_scaled: %.9f.\n", + idx, + intersection_depth, + norm_dist, + norm_dist_scaled); + float const* const col_ptr = + cam.n_channels > 3 ? draw_info.color_union.ptr : &draw_info.first_color; + // The implementation for the numerically stable weighted softmax is based + // on https://arxiv.org/pdf/1805.02867.pdf . + if (draw_only) { + /** The old maximum observed value. */ + const float sm_m_old = *sm_m; + *sm_m = FMAX(*sm_m, norm_dist_scaled); + const float coeff_exp = FEXP(norm_dist_scaled - *sm_m); + PASSERT(isfinite(coeff_exp)); + /** The color coefficient for the ball color; in [0., 1.]. */ + const float coeff = closeness * coeff_exp * opacity; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_DRAW_PIX, + "draw %u|coeff: %.9f. closeness: %.9f. coeff_exp: %.9f. " + "opacity: %.9f.\n", + idx, + coeff, + closeness, + coeff_exp, + opacity); + // Rendering. + if (sm_m_old == *sm_m) { + // Use the fact that exp(0) = 1 to avoid the exp calculation for + // the case that the maximum remains the same (which it should + // most of the time). + *sm_d = FADD(*sm_d, coeff); + for (uint c_id = 0; c_id < cam.n_channels; ++c_id) { + PASSERT(isfinite(result[c_id])); + result[c_id] = FMA(coeff, col_ptr[c_id], result[c_id]); + } + } else { + const float exp_correction = FEXP(sm_m_old - *sm_m); + *sm_d = FMA(*sm_d, exp_correction, coeff); + for (uint c_id = 0; c_id < cam.n_channels; ++c_id) { + PASSERT(isfinite(result[c_id])); + result[c_id] = + FMA(coeff, col_ptr[c_id], FMUL(result[c_id], exp_correction)); + } + } + PASSERT(isfinite(*sm_d)); + *intersection_depth_norm_out = intersection_depth; + // Update the depth threshold. + *depth_threshold = + 1.f - (FLN(*sm_d + FEPS) + *ln_pad_over_1minuspad + *sm_m) * gamma; + *depth_threshold = + FMA(*depth_threshold, FSUB(cam.max_dist, cam.min_dist), cam.min_dist); + } else { + // Gradient computation. + const float coeff_exp = FEXP(norm_dist_scaled - *sm_m); + const float gamma_rcp = FRCP(gamma); + const float radius_sq = FMUL(draw_info.radius, draw_info.radius); + const float coeff = FMAX( + FMIN(closeness * coeff_exp * opacity, *sm_d - FEPS), + 0.f); // in [0., sm_d - FEPS]. + PASSERT(coeff >= 0.f && coeff <= *sm_d); + const float otherw = *sm_d - coeff; // in [FEPS, sm_d]. + const float p1__p2_safe = FMAX(p1__p2_, FEPS); // in [eps, t_center]. + const float cam_range = FSUB(cam.max_dist, cam.min_dist); // in ]0, inf[ + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|pos: %.9f, %.9f, %.9f. pixeldirx: %.9f, %.9f, %.9f. " + "pixeldiry: %.9f, %.9f, %.9f. pixel00center: %.9f, %.9f, %.9f.\n", + idx, + draw_info.ray_center_norm.x * draw_info.t_center, + draw_info.ray_center_norm.y * draw_info.t_center, + draw_info.ray_center_norm.z * draw_info.t_center, + cam.pixel_dir_x.x, + cam.pixel_dir_x.y, + cam.pixel_dir_x.z, + cam.pixel_dir_y.x, + cam.pixel_dir_y.y, + cam.pixel_dir_y.z, + cam.pixel_0_0_center.x, + cam.pixel_0_0_center.y, + cam.pixel_0_0_center.z); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|ray_dir: %.9f, %.9f, %.9f. " + "ray_dir_norm: %.9f, %.9f, %.9f. " + "draw_info.ray_center_norm: %.9f, %.9f, %.9f.\n", + idx, + ray_dir->x, + ray_dir->y, + ray_dir->z, + ray_dir_norm.x, + ray_dir_norm.y, + ray_dir_norm.z, + draw_info.ray_center_norm.x, + draw_info.ray_center_norm.y, + draw_info.ray_center_norm.z); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|coeff_exp: %.9f. " + "norm_dist_scaled: %.9f. cam.norm_fac: %f.\n", + idx, + coeff_exp, + norm_dist_scaled, + cam.norm_fac); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|p1__p2_: %.9f. p1__p2_safe: %.9f.\n", + idx, + p1__p2_, + p1__p2_safe); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|o__p1_: %.9f. c__p1_: %.9f.\n", + idx, + o__p1_, + c__p1_); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|intersection_depth: %f. norm_dist: %f. " + "coeff: %.9f. closeness: %f. coeff_exp: %f. opacity: " + "%f. color: %f, %f, %f.\n", + idx, + intersection_depth, + norm_dist, + coeff, + closeness, + coeff_exp, + opacity, + draw_info.first_color, + draw_info.color_union.color[0], + draw_info.color_union.color[1]); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|t_center: %.9f. " + "radius: %.9f. max_dist: %f. min_dist: %f. gamma: %f.\n", + idx, + draw_info.t_center, + draw_info.radius, + cam.max_dist, + cam.min_dist, + gamma); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|sm_d: %f. sm_m: %f. grad_pix (first three): %f, %f, %f.\n", + idx, + *sm_d, + *sm_m, + grad_pix[0], + grad_pix[1], + grad_pix[2]); + PULSAR_LOG_DEV_PIX(PULSAR_LOG_GRAD, "grad %u|otherw: %f.\n", idx, otherw); + if (calc_grad_col) { + const float sm_d_norm = FRCP(FMAX(*sm_d, FEPS)); + // First do the multiplication of coeff (in [0., sm_d]) and 1/sm_d. The + // result is a factor in [0., 1.] to be multiplied with the incoming + // gradient. + for (uint c_id = 0; c_id < cam.n_channels; ++c_id) { + ATOMICADD(grad_col + c_id, grad_pix[c_id] * FMUL(coeff, sm_d_norm)); + } + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdcol.x: %f. dresDdcol.x: %f.\n", + idx, + FMUL(coeff, sm_d_norm) * grad_pix[0], + coeff * sm_d_norm); + } + // We disable the computation for too small spheres. + // The comparison is made this way to avoid subtraction of unsigned types. + if (calc_grad_cam || calc_grad_pos || calc_grad_rad || calc_grad_opy) { + //! First find dimDdcoeff. + const float n0 = + otherw * FRCP(FMAX(*sm_d * *sm_d, FEPS)); // in [0., 1. / sm_d]. + PASSERT(isfinite(n0) && n0 >= 0. && n0 <= 1. / *sm_d + 1e2f * FEPS); + // We'll aggergate dimDdcoeff over all the 'color' channels. + float dimDdcoeff = 0.f; + const float otherw_safe_rcp = FRCP(FMAX(otherw, FEPS)); + float othercol; + for (uint c_id = 0; c_id < cam.n_channels; ++c_id) { + othercol = + (result[c_id] * *sm_d - col_ptr[c_id] * coeff) * otherw_safe_rcp; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|othercol[%u]: %.9f.\n", + idx, + c_id, + othercol); + dimDdcoeff += + FMUL(FMUL(grad_pix[c_id], FSUB(col_ptr[c_id], othercol)), n0); + } + PASSERT(isfinite(dimDdcoeff)); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdcoeff: %.9f, n0: %f.\n", + idx, + dimDdcoeff, + n0); + if (calc_grad_opy) { + //! dimDdopacity. + *grad_opy += dimDdcoeff * coeff_exp * closeness * + (1.f + opacity * (1.f - norm_dist) * gamma_rcp); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dcoeffDdopacity: %.9f, dimDdopacity: %.9f.\n", + idx, + coeff_exp * closeness, + dimDdcoeff * coeff_exp * closeness); + } + if (intersect_info->max.x >= intersect_info->min.x + 3 && + intersect_info->max.y >= intersect_info->min.y + 3) { + //! Now find dcoeffDdintersection_depth and dcoeffDdcloseness. + const float dcoeffDdintersection_depth = + -closeness * coeff_exp * opacity * opacity / (gamma * cam_range); + const float dcoeffDdcloseness = coeff_exp * opacity; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dcoeffDdintersection_depth: %.9f. " + "dimDdintersection_depth: %.9f. " + "dcoeffDdcloseness: %.9f. dimDdcloseness: %.9f.\n", + idx, + dcoeffDdintersection_depth, + dimDdcoeff * dcoeffDdintersection_depth, + dcoeffDdcloseness, + dimDdcoeff * dcoeffDdcloseness); + //! Here, the execution paths for orthogonal and pinyhole camera split. + if (cam.orthogonal_projection) { + if (calc_grad_rad) { + //! Find dcoeffDdrad. + float dcoeffDdrad = + dcoeffDdcloseness * (closeness_world / radius_sq) - + dcoeffDdintersection_depth * draw_info.radius / p1__p2_safe; + PASSERT(isfinite(dcoeffDdrad)); + *grad_rad += FMUL(dimDdcoeff, dcoeffDdrad); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdrad: %.9f. dcoeffDdrad: %.9f.\n", + idx, + FMUL(dimDdcoeff, dcoeffDdrad), + dcoeffDdrad); + } + if (calc_grad_pos || calc_grad_cam) { + float3 dimDdcenter = raydiff / + p1__p2_safe; /* making it dintersection_depthDdcenter. */ + dimDdcenter.z = sign_dir(center.z); + PASSERT(FABS(center.z) >= cam.min_dist && cam.min_dist >= FEPS); + dimDdcenter *= dcoeffDdintersection_depth; // dcoeffDdcenter + dimDdcenter -= dcoeffDdcloseness * /* dclosenessDdcenter. */ + raydiff * FRCP(FMAX(length(raydiff) * draw_info.radius, FEPS)); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dcoeffDdcenter: %.9f, %.9f, %.9f.\n", + idx, + dimDdcenter.x, + dimDdcenter.y, + dimDdcenter.z); + // Now dcoeffDdcenter is stored in dimDdcenter. + dimDdcenter *= dimDdcoeff; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdcenter: %.9f, %.9f, %.9f.\n", + idx, + dimDdcenter.x, + dimDdcenter.y, + dimDdcenter.z); + // Prepare for posglob and cam pos. + const float pixel_size = length(cam.pixel_dir_x); + // pixel_size is the same as length(pixeldiry)! + const float pixel_size_rcp = FRCP(pixel_size); + float3 dcenterDdposglob = + (cam.pixel_dir_x + cam.pixel_dir_y) * pixel_size_rcp + + cam.sensor_dir_z; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dcenterDdposglob: %.9f, %.9f, %.9f.\n", + idx, + dcenterDdposglob.x, + dcenterDdposglob.y, + dcenterDdposglob.z); + if (calc_grad_pos) { + //! dcenterDdposglob. + *grad_pos += dimDdcenter * dcenterDdposglob; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdpos: %.9f, %.9f, %.9f.\n", + idx, + dimDdcenter.x * dcenterDdposglob.x, + dimDdcenter.y * dcenterDdposglob.y, + dimDdcenter.z * dcenterDdposglob.z); + } + if (calc_grad_cam) { + //! Camera. + grad_cam->cam_pos -= dimDdcenter * dcenterDdposglob; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdeye: %.9f, %.9f, %.9f.\n", + idx, + -dimDdcenter.x * dcenterDdposglob.x, + -dimDdcenter.y * dcenterDdposglob.y, + -dimDdcenter.z * dcenterDdposglob.z); + // coord_world + /* + float3 dclosenessDdcoord_world = + raydiff * FRCP(FMAX(draw_info.radius * length(raydiff), FEPS)); + float3 dintersection_depthDdcoord_world = -2.f * raydiff; + */ + float3 dimDdcoord_world = /* dcoeffDdcoord_world */ + dcoeffDdcloseness * raydiff * + FRCP(FMAX(draw_info.radius * length(raydiff), FEPS)) - + dcoeffDdintersection_depth * raydiff / p1__p2_safe; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dcoeffDdcoord_world: %.9f, %.9f, %.9f.\n", + idx, + dimDdcoord_world.x, + dimDdcoord_world.y, + dimDdcoord_world.z); + dimDdcoord_world *= dimDdcoeff; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdcoord_world: %.9f, %.9f, %.9f.\n", + idx, + dimDdcoord_world.x, + dimDdcoord_world.y, + dimDdcoord_world.z); + // The third component of dimDdcoord_world is 0! + PASSERT(dimDdcoord_world.z == 0.f); + float3 coord_world = center - raydiff; + coord_world.z = 0.f; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|coord_world: %.9f, %.9f, %.9f.\n", + idx, + coord_world.x, + coord_world.y, + coord_world.z); + // Do this component-wise to save unnecessary matmul steps. + grad_cam->pixel_dir_x += dimDdcoord_world.x * cam.pixel_dir_x * + coord_world.x * pixel_size_rcp * pixel_size_rcp; + grad_cam->pixel_dir_x += dimDdcoord_world.y * cam.pixel_dir_x * + coord_world.y * pixel_size_rcp * pixel_size_rcp; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdpixel_dir_x|coord_world: %.9f, %.9f, %.9f.\n", + idx, + grad_cam->pixel_dir_x.x, + grad_cam->pixel_dir_x.y, + grad_cam->pixel_dir_x.z); + // dcenterkDdpixel_dir_k. + float3 center_in_pixels = draw_info.ray_center_norm * + draw_info.t_center * pixel_size_rcp; + grad_cam->pixel_dir_x += dimDdcenter.x * + (center_in_pixels - + outer_product_sum(cam.pixel_dir_x) * center_in_pixels * + pixel_size_rcp * pixel_size_rcp); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dcenter0dpixel_dir_x: %.9f, %.9f, %.9f.\n", + idx, + (center_in_pixels - + outer_product_sum(cam.pixel_dir_x) * center_in_pixels * + pixel_size_rcp * pixel_size_rcp) + .x, + (center_in_pixels - + outer_product_sum(cam.pixel_dir_x) * center_in_pixels * + pixel_size_rcp * pixel_size_rcp) + .y, + (center_in_pixels - + outer_product_sum(cam.pixel_dir_x) * center_in_pixels * + pixel_size_rcp * pixel_size_rcp) + .z); + grad_cam->pixel_dir_y += dimDdcenter.y * + (center_in_pixels - + outer_product_sum(cam.pixel_dir_y) * center_in_pixels * + pixel_size_rcp * pixel_size_rcp); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dcenter1dpixel_dir_y: %.9f, %.9f, %.9f.\n", + idx, + (center_in_pixels - + outer_product_sum(cam.pixel_dir_y) * center_in_pixels * + pixel_size_rcp * pixel_size_rcp) + .x, + (center_in_pixels - + outer_product_sum(cam.pixel_dir_y) * center_in_pixels * + pixel_size_rcp * pixel_size_rcp) + .y, + (center_in_pixels - + outer_product_sum(cam.pixel_dir_y) * center_in_pixels * + pixel_size_rcp * pixel_size_rcp) + .z); + // dcenterzDdpixel_dir_k. + float sensordirz_norm_rcp = FRCP( + FMAX(length(cross(cam.pixel_dir_y, cam.pixel_dir_x)), FEPS)); + grad_cam->pixel_dir_x += dimDdcenter.z * + (dot(center, cam.sensor_dir_z) * + cross(cam.pixel_dir_y, cam.sensor_dir_z) - + cross(cam.pixel_dir_y, center)) * + sensordirz_norm_rcp; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dcenterzDdpixel_dir_x: %.9f, %.9f, %.9f.\n", + idx, + ((dot(center, cam.sensor_dir_z) * + cross(cam.pixel_dir_y, cam.sensor_dir_z) - + cross(cam.pixel_dir_y, center)) * + sensordirz_norm_rcp) + .x, + ((dot(center, cam.sensor_dir_z) * + cross(cam.pixel_dir_y, cam.sensor_dir_z) - + cross(cam.pixel_dir_y, center)) * + sensordirz_norm_rcp) + .y, + ((dot(center, cam.sensor_dir_z) * + cross(cam.pixel_dir_y, cam.sensor_dir_z) - + cross(cam.pixel_dir_y, center)) * + sensordirz_norm_rcp) + .z); + grad_cam->pixel_dir_y += dimDdcenter.z * + (dot(center, cam.sensor_dir_z) * + cross(cam.pixel_dir_x, cam.sensor_dir_z) - + cross(cam.pixel_dir_x, center)) * + sensordirz_norm_rcp; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dcenterzDdpixel_dir_y: %.9f, %.9f, %.9f.\n", + idx, + ((dot(center, cam.sensor_dir_z) * + cross(cam.pixel_dir_x, cam.sensor_dir_z) - + cross(cam.pixel_dir_x, center)) * + sensordirz_norm_rcp) + .x, + ((dot(center, cam.sensor_dir_z) * + cross(cam.pixel_dir_x, cam.sensor_dir_z) - + cross(cam.pixel_dir_x, center)) * + sensordirz_norm_rcp) + .y, + ((dot(center, cam.sensor_dir_z) * + cross(cam.pixel_dir_x, cam.sensor_dir_z) - + cross(cam.pixel_dir_x, center)) * + sensordirz_norm_rcp) + .z); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdpixel_dir_x: %.9f, %.9f, %.9f.\n", + idx, + grad_cam->pixel_dir_x.x, + grad_cam->pixel_dir_x.y, + grad_cam->pixel_dir_x.z); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdpixel_dir_y: %.9f, %.9f, %.9f.\n", + idx, + grad_cam->pixel_dir_y.x, + grad_cam->pixel_dir_y.y, + grad_cam->pixel_dir_y.z); + } + } + } else { + if (calc_grad_rad) { + //! Find dcoeffDdrad. + float dcoeffDdrad = + dcoeffDdcloseness * (closeness_world / radius_sq) - + dcoeffDdintersection_depth * draw_info.radius / p1__p2_safe; + PASSERT(isfinite(dcoeffDdrad)); + *grad_rad += FMUL(dimDdcoeff, dcoeffDdrad); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdrad: %.9f. dcoeffDdrad: %.9f.\n", + idx, + FMUL(dimDdcoeff, dcoeffDdrad), + dcoeffDdrad); + } + if (calc_grad_pos || calc_grad_cam) { + const float3 tmp1 = center - ray_dir_norm * o__p1_; + const float3 tmp1n = tmp1 / p1__p2_safe; + const float ray_dir_normDotRaydiff = dot(ray_dir_norm, raydiff); + const float3 dcoeffDdray = dcoeffDdintersection_depth * + (tmp1 - o__p1_ * tmp1n) / *norm_ray_dir + + dcoeffDdcloseness * + (ray_dir_norm * -ray_dir_normDotRaydiff + raydiff) / + (closeness_world * draw_info.radius) * + (draw_info.t_center / *norm_ray_dir); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dcoeffDdray: %.9f, %.9f, %.9f. dimDdray: " + "%.9f, %.9f, %.9f.\n", + idx, + dcoeffDdray.x, + dcoeffDdray.y, + dcoeffDdray.z, + dimDdcoeff * dcoeffDdray.x, + dimDdcoeff * dcoeffDdray.y, + dimDdcoeff * dcoeffDdray.z); + const float3 dcoeffDdcenter = + dcoeffDdintersection_depth * (ray_dir_norm + tmp1n) + + dcoeffDdcloseness * + (draw_info.ray_center_norm * ray_dir_normDotRaydiff - + raydiff) / + (closeness_world * draw_info.radius); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dcoeffDdcenter: %.9f, %.9f, %.9f. " + "dimDdcenter: %.9f, %.9f, %.9f.\n", + idx, + dcoeffDdcenter.x, + dcoeffDdcenter.y, + dcoeffDdcenter.z, + dimDdcoeff * dcoeffDdcenter.x, + dimDdcoeff * dcoeffDdcenter.y, + dimDdcoeff * dcoeffDdcenter.z); + if (calc_grad_pos) { + *grad_pos += dimDdcoeff * dcoeffDdcenter; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdposglob: %.9f, %.9f, %.9f.\n", + idx, + dimDdcoeff * dcoeffDdcenter.x, + dimDdcoeff * dcoeffDdcenter.y, + dimDdcoeff * dcoeffDdcenter.z); + } + if (calc_grad_cam) { + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdeye: %.9f, %.9f, %.9f.\n", + idx, + -dimDdcoeff * (dcoeffDdcenter.x + dcoeffDdray.x), + -dimDdcoeff * (dcoeffDdcenter.y + dcoeffDdray.y), + -dimDdcoeff * (dcoeffDdcenter.z + dcoeffDdray.z)); + grad_cam->cam_pos += -dimDdcoeff * (dcoeffDdcenter + dcoeffDdray); + grad_cam->pixel_0_0_center += dimDdcoeff * dcoeffDdray; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdpixel00centerglob: %.9f, %.9f, %.9f.\n", + idx, + dimDdcoeff * dcoeffDdray.x, + dimDdcoeff * dcoeffDdray.y, + dimDdcoeff * dcoeffDdray.z); + grad_cam->pixel_dir_x += + (dimDdcoeff * static_cast(coord_x)) * dcoeffDdray; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdpixel_dir_x: %.9f, %.9f, %.9f.\n", + idx, + (dimDdcoeff * static_cast(coord_x)) * dcoeffDdray.x, + (dimDdcoeff * static_cast(coord_x)) * dcoeffDdray.y, + (dimDdcoeff * static_cast(coord_x)) * dcoeffDdray.z); + grad_cam->pixel_dir_y += + (dimDdcoeff * static_cast(coord_y)) * dcoeffDdray; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_GRAD, + "grad %u|dimDdpixel_dir_y: %.9f, %.9f, %.9f.\n", + idx, + (dimDdcoeff * static_cast(coord_y)) * dcoeffDdray.x, + (dimDdcoeff * static_cast(coord_y)) * dcoeffDdray.y, + (dimDdcoeff * static_cast(coord_y)) * dcoeffDdray.z); + } + } + } + } + } + } + return true; +}; + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.fill_bg.device.h b/pytorch3d/csrc/pulsar/include/renderer.fill_bg.device.h new file mode 100644 index 00000000..4137d2d1 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.fill_bg.device.h @@ -0,0 +1,55 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_FILL_BG_DEVICE_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_FILL_BG_DEVICE_H_ + +#include "../global.h" +#include "./camera.h" +#include "./commands.h" +#include "./renderer.h" + +namespace pulsar { +namespace Renderer { + +template +GLOBAL void fill_bg( + Renderer renderer, + const CamInfo cam, + float const* const bg_col_d, + const float gamma, + const uint mode) { + GET_PARALLEL_IDS_2D(coord_x, coord_y, cam.film_width, cam.film_height); + int write_loc = coord_y * cam.film_width * (3 + 2 * renderer.n_track) + + coord_x * (3 + 2 * renderer.n_track); + if (renderer.forw_info_d[write_loc + 1] // sm_d + == 0.f) { + // This location has not been processed yet. + // Write first the forw_info: + // sm_m + renderer.forw_info_d[write_loc] = + cam.background_normalization_depth / gamma; + // sm_d + renderer.forw_info_d[write_loc + 1] = 1.f; + // max_closest_possible_intersection_hit + renderer.forw_info_d[write_loc + 2] = -1.f; + // sphere IDs and intersection depths. + for (int i = 0; i < renderer.n_track; ++i) { + int sphere_id = -1; + IASF(sphere_id, renderer.forw_info_d[write_loc + 3 + i * 2]); + renderer.forw_info_d[write_loc + 3 + i * 2 + 1] = -1.f; + } + if (mode == 0) { + // Image background. + for (int i = 0; i < cam.n_channels; ++i) { + renderer.result_d + [coord_y * cam.film_width * cam.n_channels + + coord_x * cam.n_channels + i] = bg_col_d[i]; + } + } + } + END_PARALLEL_2D_NORET(); +}; + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.fill_bg.instantiate.h b/pytorch3d/csrc/pulsar/include/renderer.fill_bg.instantiate.h new file mode 100644 index 00000000..2b4e279d --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.fill_bg.instantiate.h @@ -0,0 +1,15 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "./renderer.fill_bg.device.h" + +namespace pulsar { +namespace Renderer { + +template GLOBAL void fill_bg( + Renderer renderer, + const CamInfo norm, + float const* const bg_col_d, + const float gamma, + const uint mode); + +} // namespace Renderer +} // namespace pulsar diff --git a/pytorch3d/csrc/pulsar/include/renderer.forward.device.h b/pytorch3d/csrc/pulsar/include/renderer.forward.device.h new file mode 100644 index 00000000..60f52850 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.forward.device.h @@ -0,0 +1,293 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_FORWARD_DEVICE_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_FORWARD_DEVICE_H_ + +#include "../global.h" +#include "./camera.device.h" +#include "./commands.h" +#include "./math.h" +#include "./renderer.h" + +namespace pulsar { +namespace Renderer { + +template +void forward( + Renderer* self, + const float* vert_pos, + const float* vert_col, + const float* vert_rad, + const CamInfo& cam, + const float& gamma, + float percent_allowed_difference, + const uint& max_n_hits, + const float* bg_col_d, + const float* opacity_d, + const size_t& num_balls, + const uint& mode, + cudaStream_t stream) { + ARGCHECK(gamma > 0.f && gamma <= 1.f, 6, "gamma must be in [0., 1.]"); + ARGCHECK( + percent_allowed_difference >= 0.f && percent_allowed_difference <= 1.f, + 7, + "percent_allowed_difference must be in [0., 1.]"); + ARGCHECK(max_n_hits >= 1u, 8, "max_n_hits must be >= 1"); + ARGCHECK( + num_balls > 0 && num_balls <= self->max_num_balls, + 9, + ("num_balls must be >0 and <= max num balls! (" + + std::to_string(num_balls) + " vs. " + + std::to_string(self->max_num_balls) + ")") + .c_str()); + ARGCHECK( + cam.film_width == self->cam.film_width && + cam.film_height == self->cam.film_height, + 5, + "cam result width and height must agree"); + ARGCHECK(mode <= 1, 10, "mode must be <= 1!"); + if (percent_allowed_difference > 1.f - FEPS) { + LOG(WARNING) << "percent_allowed_difference > " << (1.f - FEPS) + << "! Clamping to " << (1.f - FEPS) << "."; + percent_allowed_difference = 1.f - FEPS; + } + LOG_IF(INFO, PULSAR_LOG_RENDER) << "Rendering forward pass..."; + // Update camera and transform into a new virtual camera system with + // centered principal point and subsection rendering. + self->cam.eye = cam.eye; + self->cam.pixel_0_0_center = cam.pixel_0_0_center - cam.eye; + self->cam.pixel_dir_x = cam.pixel_dir_x; + self->cam.pixel_dir_y = cam.pixel_dir_y; + self->cam.sensor_dir_z = cam.sensor_dir_z; + self->cam.half_pixel_size = cam.half_pixel_size; + self->cam.focal_length = cam.focal_length; + self->cam.aperture_width = cam.aperture_width; + self->cam.aperture_height = cam.aperture_height; + self->cam.min_dist = cam.min_dist; + self->cam.max_dist = cam.max_dist; + self->cam.norm_fac = cam.norm_fac; + self->cam.principal_point_offset_x = cam.principal_point_offset_x; + self->cam.principal_point_offset_y = cam.principal_point_offset_y; + self->cam.film_border_left = cam.film_border_left; + self->cam.film_border_top = cam.film_border_top; +#ifdef PULSAR_TIMINGS_ENABLED + START_TIME(calc_signature); +#endif + LAUNCH_MAX_PARALLEL_1D( + calc_signature, + num_balls, + stream, + *self, + reinterpret_cast(vert_pos), + vert_col, + vert_rad, + num_balls); + CHECKLAUNCH(); +#ifdef PULSAR_TIMINGS_ENABLED + STOP_TIME(calc_signature); + START_TIME(sort); +#endif + SORT_ASCENDING_WS( + self->min_depth_d, + self->min_depth_sorted_d, + self->ids_d, + self->ids_sorted_d, + num_balls, + self->workspace_d, + self->workspace_size, + stream); + SORT_ASCENDING_WS( + self->min_depth_d, + self->min_depth_sorted_d, + self->ii_d, + self->ii_sorted_d, + num_balls, + self->workspace_d, + self->workspace_size, + stream); + SORT_ASCENDING_WS( + self->min_depth_d, + self->min_depth_sorted_d, + self->di_d, + self->di_sorted_d, + num_balls, + self->workspace_d, + self->workspace_size, + stream); + CHECKLAUNCH(); +#ifdef PULSAR_TIMINGS_ENABLED + STOP_TIME(sort); + START_TIME(minmax); +#endif + IntersectInfo pixel_minmax; + pixel_minmax.min.x = MAX_USHORT; + pixel_minmax.min.y = MAX_USHORT; + pixel_minmax.max.x = 0; + pixel_minmax.max.y = 0; + REDUCE_WS( + self->ii_sorted_d, + self->min_max_pixels_d, + num_balls, + IntersectInfoMinMax(), + pixel_minmax, + self->workspace_d, + self->workspace_size, + stream); + COPY_DEV_HOST(&pixel_minmax, self->min_max_pixels_d, IntersectInfo, 1); + LOG_IF(INFO, PULSAR_LOG_RENDER) + << "Region with pixels to render: " << pixel_minmax.min.x << ":" + << pixel_minmax.max.x << " (x), " << pixel_minmax.min.y << ":" + << pixel_minmax.max.y << " (y)."; +#ifdef PULSAR_TIMINGS_ENABLED + STOP_TIME(minmax); + START_TIME(render); +#endif + MEMSET( + self->result_d, + 0, + float, + self->cam.film_width * self->cam.film_height * self->cam.n_channels, + stream); + MEMSET( + self->forw_info_d, + 0, + float, + self->cam.film_width * self->cam.film_height * (3 + 2 * self->n_track), + stream); + if (pixel_minmax.max.y > pixel_minmax.min.y && + pixel_minmax.max.x > pixel_minmax.min.x) { + PASSERT( + pixel_minmax.min.x >= static_cast(self->cam.film_border_left) && + pixel_minmax.min.x < + static_cast( + self->cam.film_border_left + self->cam.film_width) && + pixel_minmax.max.x <= + static_cast( + self->cam.film_border_left + self->cam.film_width) && + pixel_minmax.min.y >= static_cast(self->cam.film_border_top) && + pixel_minmax.min.y < + static_cast( + self->cam.film_border_top + self->cam.film_height) && + pixel_minmax.max.y <= + static_cast( + self->cam.film_border_top + self->cam.film_height)); + // Cut the image in 3x3 regions. + int y_step = RENDER_BLOCK_SIZE * + iDivCeil(pixel_minmax.max.y - pixel_minmax.min.y, + 3u * RENDER_BLOCK_SIZE); + int x_step = RENDER_BLOCK_SIZE * + iDivCeil(pixel_minmax.max.x - pixel_minmax.min.x, + 3u * RENDER_BLOCK_SIZE); + LOG_IF(INFO, PULSAR_LOG_RENDER) << "Using image slices of size " << x_step + << ", " << y_step << " (W, H)."; + for (int y_min = pixel_minmax.min.y; y_min < pixel_minmax.max.y; + y_min += y_step) { + for (int x_min = pixel_minmax.min.x; x_min < pixel_minmax.max.x; + x_min += x_step) { + // Create region selection. + LAUNCH_MAX_PARALLEL_1D( + create_selector, + num_balls, + stream, + self->ii_sorted_d, + num_balls, + x_min, + x_min + x_step, + y_min, + y_min + y_step, + self->region_flags_d); + CHECKLAUNCH(); + SELECT_FLAGS_WS( + self->region_flags_d, + self->ii_sorted_d, + self->ii_d, + self->num_selected_d, + num_balls, + self->workspace_d, + self->workspace_size, + stream); + CHECKLAUNCH(); + SELECT_FLAGS_WS( + self->region_flags_d, + self->di_sorted_d, + self->di_d, + self->num_selected_d, + num_balls, + self->workspace_d, + self->workspace_size, + stream); + CHECKLAUNCH(); + SELECT_FLAGS_WS( + self->region_flags_d, + self->ids_sorted_d, + self->ids_d, + self->num_selected_d, + num_balls, + self->workspace_d, + self->workspace_size, + stream); + CHECKLAUNCH(); + LAUNCH_PARALLEL_2D( + render, + x_step, + y_step, + RENDER_BLOCK_SIZE, + RENDER_BLOCK_SIZE, + stream, + self->num_selected_d, + self->ii_d, + self->di_d, + self->min_depth_d, + self->ids_d, + opacity_d, + self->cam, + gamma, + percent_allowed_difference, + max_n_hits, + bg_col_d, + mode, + x_min, + y_min, + x_step, + y_step, + self->result_d, + self->forw_info_d, + self->n_track); + CHECKLAUNCH(); + } + } + } + if (mode == 0) { + LAUNCH_MAX_PARALLEL_2D( + fill_bg, + static_cast(self->cam.film_width), + static_cast(self->cam.film_height), + stream, + *self, + self->cam, + bg_col_d, + gamma, + mode); + CHECKLAUNCH(); + } +#ifdef PULSAR_TIMINGS_ENABLED + STOP_TIME(render); + float time_ms; + // This blocks the result and prevents batch-processing from parallelizing. + GET_TIME(calc_signature, &time_ms); + std::cout << "Time for signature calculation: " << time_ms << " ms" + << std::endl; + GET_TIME(sort, &time_ms); + std::cout << "Time for sorting: " << time_ms << " ms" << std::endl; + GET_TIME(minmax, &time_ms); + std::cout << "Time for minmax pixel calculation: " << time_ms << " ms" + << std::endl; + GET_TIME(render, &time_ms); + std::cout << "Time for rendering: " << time_ms << " ms" << std::endl; +#endif + LOG_IF(INFO, PULSAR_LOG_RENDER) << "Forward pass complete."; +} + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.forward.instantiate.h b/pytorch3d/csrc/pulsar/include/renderer.forward.instantiate.h new file mode 100644 index 00000000..31fae46b --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.forward.instantiate.h @@ -0,0 +1,23 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "./renderer.forward.device.h" + +namespace pulsar { +namespace Renderer { + +template void forward( + Renderer* self, + const float* vert_pos, + const float* vert_col, + const float* vert_rad, + const CamInfo& cam, + const float& gamma, + float percent_allowed_difference, + const uint& max_n_hits, + const float* bg_col_d, + const float* opacity_d, + const size_t& num_balls, + const uint& mode, + cudaStream_t stream); + +} // namespace Renderer +} // namespace pulsar diff --git a/pytorch3d/csrc/pulsar/include/renderer.get_screen_area.device.h b/pytorch3d/csrc/pulsar/include/renderer.get_screen_area.device.h new file mode 100644 index 00000000..e2b7f504 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.get_screen_area.device.h @@ -0,0 +1,137 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_GET_SCREEN_AREA_DEVICE_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_GET_SCREEN_AREA_DEVICE_H_ + +#include "../global.h" +#include "./camera.device.h" +#include "./commands.h" +#include "./math.h" + +namespace pulsar { +namespace Renderer { + +/** + * Find the closest enclosing screen area rectangle in pixels that encloses a + * ball. + * + * The method returns the two x and the two y values of the boundaries. They + * are not ordered yet and you need to find min and max for the left/right and + * lower/upper boundary. + * + * The return values are floats and need to be rounded appropriately. + */ +INLINE DEVICE bool get_screen_area( + const float3& ball_center_cam, + const float3& ray_center_norm, + const float& vert_rad, + const CamInfo& cam, + const uint& idx, + /* Out variables. */ + float* x_1, + float* x_2, + float* y_1, + float* y_2) { + float cos_alpha = dot(cam.sensor_dir_z, ray_center_norm); + float2 o__c_, alpha, theta; + if (cos_alpha < EPS) { + PULSAR_LOG_DEV( + PULSAR_LOG_CALC_SIGNATURE, + "signature %d|ball not visible. cos_alpha: %.9f.\n", + idx, + cos_alpha); + // No intersection, ball won't be visible. + return false; + } + // Multiply the direction vector with the camera rotation matrix + // to have the optical axis being the canonical z vector (0, 0, 1). + // TODO: optimize. + const float3 ball_center_cam_rot = rotate( + ball_center_cam, + cam.pixel_dir_x / length(cam.pixel_dir_x), + cam.pixel_dir_y / length(cam.pixel_dir_y), + cam.sensor_dir_z); + PULSAR_LOG_DEV( + PULSAR_LOG_CALC_SIGNATURE, + "signature %d|ball_center_cam_rot: %f, %f, %f.\n", + idx, + ball_center_cam.x, + ball_center_cam.y, + ball_center_cam.z); + const float pixel_size_norm_fac = FRCP(2.f * cam.half_pixel_size); + const float optical_offset_x = + (static_cast(cam.aperture_width) - 1.f) * .5f; + const float optical_offset_y = + (static_cast(cam.aperture_height) - 1.f) * .5f; + if (cam.orthogonal_projection) { + *x_1 = + FMA(ball_center_cam_rot.x - vert_rad, + pixel_size_norm_fac, + optical_offset_x); + *x_2 = + FMA(ball_center_cam_rot.x + vert_rad, + pixel_size_norm_fac, + optical_offset_x); + *y_1 = + FMA(ball_center_cam_rot.y - vert_rad, + pixel_size_norm_fac, + optical_offset_y); + *y_2 = + FMA(ball_center_cam_rot.y + vert_rad, + pixel_size_norm_fac, + optical_offset_y); + return true; + } else { + o__c_.x = FMAX( + FSQRT( + ball_center_cam_rot.x * ball_center_cam_rot.x + + ball_center_cam_rot.z * ball_center_cam_rot.z), + FEPS); + o__c_.y = FMAX( + FSQRT( + ball_center_cam_rot.y * ball_center_cam_rot.y + + ball_center_cam_rot.z * ball_center_cam_rot.z), + FEPS); + PULSAR_LOG_DEV( + PULSAR_LOG_CALC_SIGNATURE, + "signature %d|o__c_: %f, %f.\n", + idx, + o__c_.x, + o__c_.y); + alpha.x = sign_dir(ball_center_cam_rot.x) * + acos(FMIN(FMAX(ball_center_cam_rot.z / o__c_.x, -1.f), 1.f)); + alpha.y = -sign_dir(ball_center_cam_rot.y) * + acos(FMIN(FMAX(ball_center_cam_rot.z / o__c_.y, -1.f), 1.f)); + theta.x = asin(FMIN(FMAX(vert_rad / o__c_.x, -1.f), 1.f)); + theta.y = asin(FMIN(FMAX(vert_rad / o__c_.y, -1.f), 1.f)); + PULSAR_LOG_DEV( + PULSAR_LOG_CALC_SIGNATURE, + "signature %d|alpha.x: %f, alpha.y: %f, theta.x: %f, theta.y: %f.\n", + idx, + alpha.x, + alpha.y, + theta.x, + theta.y); + *x_1 = tan(alpha.x - theta.x) * cam.focal_length; + *x_2 = tan(alpha.x + theta.x) * cam.focal_length; + *y_1 = tan(alpha.y - theta.y) * cam.focal_length; + *y_2 = tan(alpha.y + theta.y) * cam.focal_length; + PULSAR_LOG_DEV( + PULSAR_LOG_CALC_SIGNATURE, + "signature %d|in sensor plane: x_1: %f, x_2: %f, y_1: %f, y_2: %f.\n", + idx, + *x_1, + *x_2, + *y_1, + *y_2); + *x_1 = FMA(*x_1, pixel_size_norm_fac, optical_offset_x); + *x_2 = FMA(*x_2, pixel_size_norm_fac, optical_offset_x); + *y_1 = FMA(*y_1, -pixel_size_norm_fac, optical_offset_y); + *y_2 = FMA(*y_2, -pixel_size_norm_fac, optical_offset_y); + return true; + } +}; + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.h b/pytorch3d/csrc/pulsar/include/renderer.h new file mode 100644 index 00000000..dfcfae53 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.h @@ -0,0 +1,461 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_H_ + +#include + +#include "../global.h" +#include "./camera.h" + +namespace pulsar { +namespace Renderer { + +//! Remember to order struct members from larger size to smaller size +//! to avoid padding (for more info, see for example here: +//! http://www.catb.org/esr/structure-packing/). + +/** + * This is the information that's needed to do a fast screen point + * intersection with one of the balls. + * + * Aim to keep this below 8 bytes (256 bytes per cache-line / 32 threads in a + * warp = 8 bytes per thread). + */ +struct IntersectInfo { + ushort2 min; /** minimum x, y in pixel coordinates. */ + ushort2 max; /** maximum x, y in pixel coordinates. */ +}; +static_assert( + sizeof(IntersectInfo) == 8, + "The compiled size of `IntersectInfo` is wrong."); + +/** + * Reduction operation to find the limits of multiple IntersectInfo objects. + */ +struct IntersectInfoMinMax { + IHD IntersectInfo + operator()(const IntersectInfo& a, const IntersectInfo& b) const { + // Treat the special case of an invalid intersect info object or one for + // a ball out of bounds. + if (b.max.x == MAX_USHORT && b.min.x == MAX_USHORT && + b.max.y == MAX_USHORT && b.min.y == MAX_USHORT) { + return a; + } + if (a.max.x == MAX_USHORT && a.min.x == MAX_USHORT && + a.max.y == MAX_USHORT && a.min.y == MAX_USHORT) { + return b; + } + IntersectInfo result; + result.min.x = std::min(a.min.x, b.min.x); + result.min.y = std::min(a.min.y, b.min.y); + result.max.x = std::max(a.max.x, b.max.x); + result.max.y = std::max(a.max.y, b.max.y); + return result; + } +}; + +/** + * All information that's needed to draw a ball. + * + * It's necessary to keep this information in float (not half) format, + * because the loss in accuracy would be too high and lead to artifacts. + */ +struct DrawInfo { + float3 ray_center_norm; /** Ray to the ball center, normalized. */ + /** Ball color. + * + * This might be the full color in the case of n_channels <= 3. Otherwise, + * a pointer to the original 'color' data is stored in the following union. + */ + float first_color; + union { + float color[2]; + float* ptr; + } color_union; + float t_center; /** Distance from the camera to the ball center. */ + float radius; /** Ball radius. */ +}; +static_assert( + sizeof(DrawInfo) == 8 * 4, + "The compiled size of `DrawInfo` is wrong."); + +/** + * An object to collect all associated data with the renderer. + * + * The `_d` suffixed pointers point to memory 'on-device', potentially on the + * GPU. All other variables are expected to point to CPU memory. + */ +struct Renderer { + /** Dummy initializer to make sure all pointers are set to NULL to + * be safe for the device-specific 'construct' and 'destruct' methods. + */ + inline Renderer() { + max_num_balls = 0; + result_d = NULL; + min_depth_d = NULL; + min_depth_sorted_d = NULL; + ii_d = NULL; + ii_sorted_d = NULL; + ids_d = NULL; + ids_sorted_d = NULL; + workspace_d = NULL; + di_d = NULL; + di_sorted_d = NULL; + region_flags_d = NULL; + num_selected_d = NULL; + forw_info_d = NULL; + grad_pos_d = NULL; + grad_col_d = NULL; + grad_rad_d = NULL; + grad_cam_d = NULL; + grad_opy_d = NULL; + grad_cam_buf_d = NULL; + n_grad_contributions_d = NULL; + }; + /** The camera for this renderer. In world-coordinates. */ + CamInfo cam; + /** + * The maximum amount of balls the renderer can handle. Resources are + * pre-allocated to account for this size. Less than this amount of balls + * can be rendered, but not more. + */ + int max_num_balls; + /** The result buffer. */ + float* result_d; + /** Closest possible intersection depth per sphere w.r.t. the camera. */ + float* min_depth_d; + /** Closest possible intersection depth per sphere, ordered ascending. */ + float* min_depth_sorted_d; + /** The intersect infos per sphere. */ + IntersectInfo* ii_d; + /** The intersect infos per sphere, ordered by their closest possible + * intersection depth (asc.). */ + IntersectInfo* ii_sorted_d; + /** Original sphere IDs. */ + int* ids_d; + /** Original sphere IDs, ordered by their closest possible intersection depth + * (asc.). */ + int* ids_sorted_d; + /** Workspace for CUB routines. */ + char* workspace_d; + /** Workspace size for CUB routines. */ + size_t workspace_size; + /** The draw information structures for each sphere. */ + DrawInfo* di_d; + /** The draw information structures sorted by closest possible intersection + * depth (asc.). */ + DrawInfo* di_sorted_d; + /** Region association buffer. */ + char* region_flags_d; + /** Num spheres in the current region. */ + size_t* num_selected_d; + /** Pointer to information from the forward pass. */ + float* forw_info_d; + /** Struct containing information about the min max pixels that contain + * rendered information in the image. */ + IntersectInfo* min_max_pixels_d; + /** Gradients w.r.t. position. */ + float3* grad_pos_d; + /** Gradients w.r.t. color. */ + float* grad_col_d; + /** Gradients w.r.t. radius. */ + float* grad_rad_d; + /** Gradients w.r.t. camera parameters. */ + float* grad_cam_d; + /** Gradients w.r.t. opacity. */ + float* grad_opy_d; + /** Camera gradient information by sphere. + * + * Here, every sphere's contribution to the camera gradients is stored. It is + * aggregated and written to grad_cam_d in a separate step. This avoids write + * conflicts when processing the spheres. + */ + CamGradInfo* grad_cam_buf_d; + /** Total of all gradient contributions for this image. */ + int* n_grad_contributions_d; + /** The number of spheres to track for backpropagation. */ + int n_track; +}; + +inline bool operator==(const Renderer& a, const Renderer& b) { + return a.cam == b.cam && a.max_num_balls == b.max_num_balls; +} + +/** + * Construct a renderer. + */ +template +void construct( + Renderer* self, + const size_t& max_num_balls, + const int& width, + const int& height, + const bool& orthogonal_projection, + const bool& right_handed_system, + const float& background_normalization_depth, + const uint& n_channels, + const uint& n_track); + +/** + * Destruct the renderer and free the associated memory. + */ +template +void destruct(Renderer* self); + +/** + * Create a selection of points inside a rectangle. + * + * This write boolen values into `region_flags_d', which can + * for example be used by a CUB function to extract the selection. + */ +template +GLOBAL void create_selector( + IntersectInfo const* const RESTRICT ii_sorted_d, + const uint num_balls, + const int min_x, + const int max_x, + const int min_y, + const int max_y, + /* Out variables. */ + char* RESTRICT region_flags_d); + +/** + * Calculate a signature for a ball. + * + * Populate the `ids_d`, `ii_d`, `di_d` and `min_depth_d` fields of the + * renderer. For spheres not visible in the image, sets the id field to -1, + * min_depth_d to MAX_FLOAT and the ii_d.min.x fields to MAX_USHORT. + */ +template +GLOBAL void calc_signature( + Renderer renderer, + float3 const* const RESTRICT vert_poss, + float const* const RESTRICT vert_cols, + float const* const RESTRICT vert_rads, + const uint num_balls); + +/** + * The block size for rendering. + * + * This should be as large as possible, but is limited due to the amount + * of variables we use and the memory required per thread. + */ +#define RENDER_BLOCK_SIZE 16 +/** + * The buffer size of spheres to be loaded and analyzed for relevance. + * + * This must be at least RENDER_BLOCK_SIZE * RENDER_BLOCK_SIZE so that + * for every iteration through the loading loop every thread could add a + * 'hit' to the buffer. + */ +#define RENDER_BUFFER_SIZE RENDER_BLOCK_SIZE* RENDER_BLOCK_SIZE * 2 +/** + * The threshold after which the spheres that are in the render buffer + * are rendered and the buffer is flushed. + * + * Must be less than RENDER_BUFFER_SIZE. + */ +#define RENDER_BUFFER_LOAD_THRESH 16 * 4 + +/** + * The render function. + * + * Assumptions: + * * the focal length is appropriately chosen, + * * ray_dir_norm.z is > EPS. + * * to be completed... + */ +template +GLOBAL void render( + size_t const* const RESTRICT + num_balls, /** Number of balls relevant for this pass. */ + IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */ + DrawInfo const* const RESTRICT di_d, /** Draw information. */ + float const* const RESTRICT min_depth_d, /** Minimum depth per sphere. */ + int const* const RESTRICT id_d, /** IDs. */ + float const* const RESTRICT op_d, /** Opacity. */ + const CamInfo cam_norm, /** Camera normalized with all vectors to be in the + * camera coordinate system. + */ + const float gamma, /** Transparency parameter. **/ + const float percent_allowed_difference, /** Maximum allowed + error in color. */ + const uint max_n_hits, + const float* bg_col_d, + const uint mode, + const int x_min, + const int y_min, + const int x_step, + const int y_step, + // Out variables. + float* const RESTRICT result_d, /** The result image. */ + float* const RESTRICT forw_info_d, /** Additional information needed for the + grad computation. */ + // Infrastructure. + const int n_track /** The number of spheres to track. */ +); + +/** + * Makes sure to paint background information. + * + * This is required as a separate post-processing step because certain + * pixels may not be processed during the forward pass if there is no + * possibility for a sphere to be present at their location. + */ +template +GLOBAL void fill_bg( + Renderer renderer, + const CamInfo norm, + float const* const bg_col_d, + const float gamma, + const uint mode); + +/** + * Rendering forward pass. + * + * Takes a renderer and sphere data as inputs and creates a rendering. + */ +template +void forward( + Renderer* self, + const float* vert_pos, + const float* vert_col, + const float* vert_rad, + const CamInfo& cam, + const float& gamma, + float percent_allowed_difference, + const uint& max_n_hits, + const float* bg_col_d, + const float* opacity_d, + const size_t& num_balls, + const uint& mode, + cudaStream_t stream); + +/** + * Normalize the camera gradients by the number of spheres that contributed. + */ +template +GLOBAL void norm_cam_gradients(Renderer renderer); + +/** + * Normalize the sphere gradients. + * + * We're assuming that the samples originate from a Monte Carlo + * sampling process and normalize by number and sphere area. + */ +template +GLOBAL void norm_sphere_gradients(Renderer renderer, const int num_balls); + +#define GRAD_BLOCK_SIZE 16 +/** Calculate the gradients. + */ +template +GLOBAL void calc_gradients( + const CamInfo cam, /** Camera in world coordinates. */ + float const* const RESTRICT grad_im, /** The gradient image. */ + const float + gamma, /** The transparency parameter used in the forward pass. */ + float3 const* const RESTRICT vert_poss, /** Vertex position vector. */ + float const* const RESTRICT vert_cols, /** Vertex color vector. */ + float const* const RESTRICT vert_rads, /** Vertex radius vector. */ + float const* const RESTRICT opacity, /** Vertex opacity. */ + const uint num_balls, /** Number of balls. */ + float const* const RESTRICT result_d, /** Result image. */ + float const* const RESTRICT forw_info_d, /** Forward pass info. */ + DrawInfo const* const RESTRICT di_d, /** Draw information. */ + IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */ + // Mode switches. + const bool calc_grad_pos, + const bool calc_grad_col, + const bool calc_grad_rad, + const bool calc_grad_cam, + const bool calc_grad_opy, + // Out variables. + float* const RESTRICT grad_rad_d, /** Radius gradients. */ + float* const RESTRICT grad_col_d, /** Color gradients. */ + float3* const RESTRICT grad_pos_d, /** Position gradients. */ + CamGradInfo* const RESTRICT grad_cam_buf_d, /** Camera gradient buffer. */ + float* const RESTRICT grad_opy_d, /** Opacity gradient buffer. */ + int* const RESTRICT + grad_contributed_d, /** Gradient contribution counter. */ + // Infrastructure. + const int n_track, + const uint offs_x = 0, + const uint offs_y = 0); + +/** + * A full backward pass. + * + * Creates the gradients for the given gradient_image and the spheres. + */ +template +void backward( + Renderer* self, + const float* grad_im, + const float* image, + const float* forw_info, + const float* vert_pos, + const float* vert_col, + const float* vert_rad, + const CamInfo& cam, + const float& gamma, + float percent_allowed_difference, + const uint& max_n_hits, + const float* vert_opy, + const size_t& num_balls, + const uint& mode, + const bool& dif_pos, + const bool& dif_col, + const bool& dif_rad, + const bool& dif_cam, + const bool& dif_opy, + cudaStream_t stream); + +/** + * A debug backward pass. + * + * This is a function to debug the gradient calculation. It calculates the + * gradients for exactly one pixel (set with pos_x and pos_y) without averaging. + * + * *Uses only the first sphere for camera gradient calculation!* + */ +template +void backward_dbg( + Renderer* self, + const float* grad_im, + const float* image, + const float* forw_info, + const float* vert_pos, + const float* vert_col, + const float* vert_rad, + const CamInfo& cam, + const float& gamma, + float percent_allowed_difference, + const uint& max_n_hits, + const float* vert_opy, + const size_t& num_balls, + const uint& mode, + const bool& dif_pos, + const bool& dif_col, + const bool& dif_rad, + const bool& dif_cam, + const bool& dif_opy, + const uint& pos_x, + const uint& pos_y, + cudaStream_t stream); + +template +void nn( + const float* ref_ptr, + const float* tar_ptr, + const uint& k, + const uint& d, + const uint& n, + float* dist_ptr, + int32_t* inds_ptr, + cudaStream_t stream); + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.norm_cam_gradients.device.h b/pytorch3d/csrc/pulsar/include/renderer.norm_cam_gradients.device.h new file mode 100644 index 00000000..14f6a21f --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.norm_cam_gradients.device.h @@ -0,0 +1,28 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_NORM_CAM_GRADIENTS_DEVICE_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_NORM_CAM_GRADIENTS_DEVICE_H_ + +#include "../global.h" +#include "./camera.device.h" +#include "./commands.h" +#include "./math.h" +#include "./renderer.h" + +namespace pulsar { +namespace Renderer { + +/** + * Normalize the camera gradients by the number of spheres that contributed. + */ +template +GLOBAL void norm_cam_gradients(Renderer renderer) { + GET_PARALLEL_IDX_1D(idx, 1); + CamGradInfo* cgi = reinterpret_cast(renderer.grad_cam_d); + *cgi = *cgi * FRCP(static_cast(*renderer.n_grad_contributions_d)); + END_PARALLEL_NORET(); +}; + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.norm_cam_gradients.instantiate.h b/pytorch3d/csrc/pulsar/include/renderer.norm_cam_gradients.instantiate.h new file mode 100644 index 00000000..77433245 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.norm_cam_gradients.instantiate.h @@ -0,0 +1,10 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "./renderer.norm_cam_gradients.device.h" + +namespace pulsar { +namespace Renderer { + +template GLOBAL void norm_cam_gradients(Renderer renderer); + +} // namespace Renderer +} // namespace pulsar diff --git a/pytorch3d/csrc/pulsar/include/renderer.norm_sphere_gradients.device.h b/pytorch3d/csrc/pulsar/include/renderer.norm_sphere_gradients.device.h new file mode 100644 index 00000000..1d68a303 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.norm_sphere_gradients.device.h @@ -0,0 +1,68 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_NORM_SPHERE_GRADIENTS_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_NORM_SPHERE_GRADIENTS_H_ + +#include "../global.h" +#include "./commands.h" +#include "./math.h" +#include "./renderer.h" + +namespace pulsar { +namespace Renderer { + +/** + * Normalize the sphere gradients. + * + * We're assuming that the samples originate from a Monte Carlo + * sampling process and normalize by number and sphere area. + */ +template +GLOBAL void norm_sphere_gradients(Renderer renderer, const int num_balls) { + GET_PARALLEL_IDX_1D(idx, num_balls); + float norm_fac = 0.f; + IntersectInfo ii; + if (renderer.ids_sorted_d[idx] > 0) { + ii = renderer.ii_d[idx]; + // Normalize the sphere gradients as averages. + // This avoids the case that there are small spheres in a scene with still + // un-converged colors whereas the big spheres already converged, just + // because their integrated learning rate is 'higher'. + norm_fac = FRCP(static_cast(renderer.ids_sorted_d[idx])); + } + PULSAR_LOG_DEV_NODE( + PULSAR_LOG_NORMALIZE, + "ids_sorted_d[idx]: %d, norm_fac: %.9f.\n", + renderer.ids_sorted_d[idx], + norm_fac); + renderer.grad_rad_d[idx] *= norm_fac; + for (uint c_idx = 0; c_idx < renderer.cam.n_channels; ++c_idx) { + renderer.grad_col_d[idx * renderer.cam.n_channels + c_idx] *= norm_fac; + } + renderer.grad_pos_d[idx] *= norm_fac; + renderer.grad_opy_d[idx] *= norm_fac; + + if (renderer.ids_sorted_d[idx] > 0) { + // For the camera, we need to be more correct and have the gradients + // be proportional to the area they cover in the image. + // This leads to a formulation very much like in monte carlo integration: + norm_fac = FRCP(static_cast(renderer.ids_sorted_d[idx])) * + (static_cast(ii.max.x) - static_cast(ii.min.x)) * + (static_cast(ii.max.y) - static_cast(ii.min.y)) * + 1e-3f; // for better numerics. + } + renderer.grad_cam_buf_d[idx].cam_pos *= norm_fac; + renderer.grad_cam_buf_d[idx].pixel_0_0_center *= norm_fac; + renderer.grad_cam_buf_d[idx].pixel_dir_x *= norm_fac; + renderer.grad_cam_buf_d[idx].pixel_dir_y *= norm_fac; + // The sphere only contributes to the camera gradients if it is + // large enough in screen space. + if (renderer.ids_sorted_d[idx] > 0 && ii.max.x >= ii.min.x + 3 && + ii.max.y >= ii.min.y + 3) + renderer.ids_sorted_d[idx] = 1; + END_PARALLEL_NORET(); +}; + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.norm_sphere_gradients.instantiate.h b/pytorch3d/csrc/pulsar/include/renderer.norm_sphere_gradients.instantiate.h new file mode 100644 index 00000000..2a48aa07 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.norm_sphere_gradients.instantiate.h @@ -0,0 +1,12 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "./renderer.norm_sphere_gradients.device.h" + +namespace pulsar { +namespace Renderer { + +template GLOBAL void norm_sphere_gradients( + Renderer renderer, + const int num_balls); + +} // namespace Renderer +} // namespace pulsar diff --git a/pytorch3d/csrc/pulsar/include/renderer.render.device.h b/pytorch3d/csrc/pulsar/include/renderer.render.device.h new file mode 100644 index 00000000..95adf9af --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.render.device.h @@ -0,0 +1,409 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_RENDER_DEVICE_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_RENDER_DEVICE_H_ + +#include "../global.h" +#include "./camera.device.h" +#include "./commands.h" +#include "./math.h" +#include "./renderer.h" + +#include "./closest_sphere_tracker.device.h" +#include "./renderer.draw.device.h" + +namespace pulsar { +namespace Renderer { + +template +GLOBAL void render( + size_t const* const RESTRICT + num_balls, /** Number of balls relevant for this pass. */ + IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */ + DrawInfo const* const RESTRICT di_d, /** Draw information. */ + float const* const RESTRICT min_depth_d, /** Minimum depth per sphere. */ + int const* const RESTRICT ids_d, /** IDs. */ + float const* const RESTRICT op_d, /** Opacity. */ + const CamInfo cam_norm, /** Camera normalized with all vectors to be in the + * camera coordinate system. + */ + const float gamma, /** Transparency parameter. **/ + const float percent_allowed_difference, /** Maximum allowed + error in color. */ + const uint max_n_hits, + const float* bg_col, + const uint mode, + const int x_min, + const int y_min, + const int x_step, + const int y_step, + // Out variables. + float* const RESTRICT result_d, /** The result image. */ + float* const RESTRICT forw_info_d, /** Additional information needed for the + grad computation. */ + const int n_track /** The number of spheres to track for backprop. */ +) { + // Do not early stop threads in this block here. They can all contribute to + // the scanning process, we just have to prevent from writing their result. + GET_PARALLEL_IDS_2D(offs_x, offs_y, x_step, y_step); + // Variable declarations and const initializations. + const float ln_pad_over_1minuspad = + FLN(percent_allowed_difference / (1.f - percent_allowed_difference)); + /** A facility to track the closest spheres to the camera + (in preparation for gradient calculation). */ + ClosestSphereTracker tracker(n_track); + const uint coord_x = x_min + offs_x; /** Ray coordinate x. */ + const uint coord_y = y_min + offs_y; /** Ray coordinate y. */ + float3 ray_dir_norm; /** Ray cast through the pixel, normalized. */ + float2 projected_ray; /** Ray intersection with the sensor. */ + if (cam_norm.orthogonal_projection) { + ray_dir_norm = cam_norm.sensor_dir_z; + projected_ray.x = static_cast(coord_x); + projected_ray.y = static_cast(coord_y); + } else { + ray_dir_norm = normalize( + cam_norm.pixel_0_0_center + coord_x * cam_norm.pixel_dir_x + + coord_y * cam_norm.pixel_dir_y); + // This is a reasonable assumption for normal focal lengths and image sizes. + PASSERT(FABS(ray_dir_norm.z) > FEPS); + projected_ray.x = ray_dir_norm.x / ray_dir_norm.z * cam_norm.focal_length; + projected_ray.y = ray_dir_norm.y / ray_dir_norm.z * cam_norm.focal_length; + } + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_RENDER_PIX, + "render|ray_dir_norm: %.9f, %.9f, %.9f. projected_ray: %.9f, %.9f.\n", + ray_dir_norm.x, + ray_dir_norm.y, + ray_dir_norm.z, + projected_ray.x, + projected_ray.y); + // Set up shared infrastructure. + /** This entire thread block. */ + cg::thread_block thread_block = cg::this_thread_block(); + /** The collaborators within a warp. */ + cg::coalesced_group thread_warp = cg::coalesced_threads(); + /** The number of loaded balls in the load buffer di_l. */ + SHARED uint n_loaded; + /** Draw information buffer. */ + SHARED DrawInfo di_l[RENDER_BUFFER_SIZE]; + /** The original sphere id of each loaded sphere. */ + SHARED uint sphere_id_l[RENDER_BUFFER_SIZE]; + /** The number of pixels in this block that are done. */ + SHARED int n_pixels_done; + /** Whether loading of balls is completed. */ + SHARED bool loading_done; + /** The number of balls loaded overall (just for statistics). */ + SHARED int n_balls_loaded; + /** The area this thread block covers. */ + SHARED IntersectInfo block_area; + if (thread_block.thread_rank() == 0) { + // Initialize the shared variables. + n_loaded = 0; + block_area.min.x = static_cast(coord_x); + block_area.max.x = static_cast(IMIN( + coord_x + blockDim.x, cam_norm.film_border_left + cam_norm.film_width)); + block_area.min.y = static_cast(coord_y); + block_area.max.y = static_cast(IMIN( + coord_y + blockDim.y, cam_norm.film_border_top + cam_norm.film_height)); + n_pixels_done = 0; + loading_done = false; + n_balls_loaded = 0; + } + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_RENDER_PIX, + "render|block_area.min: %d, %d. block_area.max: %d, %d.\n", + block_area.min.x, + block_area.min.y, + block_area.max.x, + block_area.max.y); + // Initialization of the pixel with the background color. + /** + * The result of this very pixel. + * the offset calculation might overflow if this thread is out of + * bounds of the film. However, in this case result is not + * accessed, so this is fine. + */ + float* result = result_d + + (coord_y - cam_norm.film_border_top) * cam_norm.film_width * + cam_norm.n_channels + + (coord_x - cam_norm.film_border_left) * cam_norm.n_channels; + if (coord_x >= cam_norm.film_border_left && + coord_x < cam_norm.film_border_left + cam_norm.film_width && + coord_y >= cam_norm.film_border_top && + coord_y < cam_norm.film_border_top + cam_norm.film_height) { + // Initialize the result. + if (mode == 0u) { + for (uint c_id = 0; c_id < cam_norm.n_channels; ++c_id) + result[c_id] = bg_col[c_id]; + } else { + result[0] = 0.f; + } + } + /** Normalization denominator. */ + float sm_d = 1.f; + /** Normalization tracker for stable softmax. The maximum observed value. */ + float sm_m = cam_norm.background_normalization_depth / gamma; + /** Whether this pixel has had all information needed for drawing. */ + bool done = + (coord_x < cam_norm.film_border_left || + coord_x >= cam_norm.film_border_left + cam_norm.film_width || + coord_y < cam_norm.film_border_top || + coord_y >= cam_norm.film_border_top + cam_norm.film_height); + /** The depth threshold for a new point to have at least + * `percent_allowed_difference` influence on the result color. All points that + * are further away than this are ignored. + */ + float depth_threshold = done ? -1.f : MAX_FLOAT; + /** The closest intersection possible of a ball that was hit by this pixel + * ray. */ + float max_closest_possible_intersection_hit = -1.f; + bool hit; /** Whether a sphere was hit. */ + float intersection_depth; /** The intersection_depth for a sphere at this + pixel. */ + float closest_possible_intersection; /** The closest possible intersection + for this sphere. */ + float max_closest_possible_intersection; + // Sync up threads so that everyone is similarly initialized. + thread_block.sync(); + //! Coalesced loading and intersection analysis of balls. + for (uint ball_idx = thread_block.thread_rank(); + ball_idx < iDivCeil(static_cast(*num_balls), thread_block.size()) * + thread_block.size() && + !loading_done && n_pixels_done < thread_block.size(); + ball_idx += thread_block.size()) { + if (ball_idx < static_cast(*num_balls)) { // Account for overflow. + const IntersectInfo& ii = ii_d[ball_idx]; + hit = (ii.min.x <= block_area.max.x) && (ii.max.x > block_area.min.x) && + (ii.min.y <= block_area.max.y) && (ii.max.y > block_area.min.y); + if (hit) { + uint write_idx = ATOMICADD_B(&n_loaded, 1u); + di_l[write_idx] = di_d[ball_idx]; + sphere_id_l[write_idx] = static_cast(ids_d[ball_idx]); + PULSAR_LOG_DEV_PIXB( + PULSAR_LOG_RENDER_PIX, + "render|found intersection with sphere %u.\n", + sphere_id_l[write_idx]); + } + if (ii.min.x == MAX_USHORT) + // This is an invalid sphere (out of image). These spheres have + // maximum depth. Since we ordered the spheres by earliest possible + // intersection depth we re certain that there will no other sphere + // that is relevant after this one. + loading_done = true; + } + // Reset n_pixels_done. + n_pixels_done = 0; + thread_block.sync(); // Make sure n_loaded is updated. + if (n_loaded > RENDER_BUFFER_LOAD_THRESH) { + // The load buffer is full enough. Draw. + if (thread_block.thread_rank() == 0) + n_balls_loaded += n_loaded; + max_closest_possible_intersection = 0.f; + // This excludes threads outside of the image boundary. Also, it reduces + // block artifacts. + if (!done) { + for (uint draw_idx = 0; draw_idx < n_loaded; ++draw_idx) { + intersection_depth = 0.f; + if (cam_norm.orthogonal_projection) { + // The closest possible intersection is the distance to the camera + // plane. + closest_possible_intersection = min_depth_d[sphere_id_l[draw_idx]]; + } else { + closest_possible_intersection = + di_l[draw_idx].t_center - di_l[draw_idx].radius; + } + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_RENDER_PIX, + "render|drawing sphere %u (depth: %f, " + "closest possible intersection: %f).\n", + sphere_id_l[draw_idx], + di_l[draw_idx].t_center, + closest_possible_intersection); + hit = draw( + di_l[draw_idx], // Sphere to draw. + op_d == NULL ? 1.f : op_d[sphere_id_l[draw_idx]], // Opacity. + cam_norm, // Cam. + gamma, // Gamma. + ray_dir_norm, // Ray direction. + projected_ray, // Ray intersection with the image. + // Mode switches. + true, // Draw. + false, + false, + false, + false, + false, // No gradients. + // Position info. + coord_x, + coord_y, + sphere_id_l[draw_idx], + // Optional in variables. + NULL, // intersect information. + NULL, // ray_dir. + NULL, // norm_ray_dir. + NULL, // grad_pix. + &ln_pad_over_1minuspad, + // in/out variables + &sm_d, + &sm_m, + result, + // Optional out. + &depth_threshold, + &intersection_depth, + NULL, + NULL, + NULL, + NULL, + NULL // gradients. + ); + if (hit) { + max_closest_possible_intersection_hit = FMAX( + max_closest_possible_intersection_hit, + closest_possible_intersection); + tracker.track( + sphere_id_l[draw_idx], intersection_depth, coord_x, coord_y); + } + max_closest_possible_intersection = FMAX( + max_closest_possible_intersection, closest_possible_intersection); + } + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_RENDER_PIX, + "render|max_closest_possible_intersection: %f, " + "depth_threshold: %f.\n", + max_closest_possible_intersection, + depth_threshold); + } + done = done || + (percent_allowed_difference > 0.f && + max_closest_possible_intersection > depth_threshold) || + tracker.get_n_hits() >= max_n_hits; + uint warp_done = thread_warp.ballot(done); + if (thread_warp.thread_rank() == 0) + ATOMICADD_B(&n_pixels_done, POPC(warp_done)); + // This sync is necessary to keep n_loaded until all threads are done with + // painting. + thread_block.sync(); + n_loaded = 0; + } + thread_block.sync(); + } + if (thread_block.thread_rank() == 0) + n_balls_loaded += n_loaded; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_RENDER_PIX, + "render|loaded %d balls in total.\n", + n_balls_loaded); + if (!done) { + for (uint draw_idx = 0; draw_idx < n_loaded; ++draw_idx) { + intersection_depth = 0.f; + if (cam_norm.orthogonal_projection) { + // The closest possible intersection is the distance to the camera + // plane. + closest_possible_intersection = min_depth_d[sphere_id_l[draw_idx]]; + } else { + closest_possible_intersection = + di_l[draw_idx].t_center - di_l[draw_idx].radius; + } + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_RENDER_PIX, + "render|drawing sphere %u (depth: %f, " + "closest possible intersection: %f).\n", + sphere_id_l[draw_idx], + di_l[draw_idx].t_center, + closest_possible_intersection); + hit = draw( + di_l[draw_idx], // Sphere to draw. + op_d == NULL ? 1.f : op_d[sphere_id_l[draw_idx]], // Opacity. + cam_norm, // Cam. + gamma, // Gamma. + ray_dir_norm, // Ray direction. + projected_ray, // Ray intersection with the image. + // Mode switches. + true, // Draw. + false, + false, + false, + false, + false, // No gradients. + // Logging info. + coord_x, + coord_y, + sphere_id_l[draw_idx], + // Optional in variables. + NULL, // intersect information. + NULL, // ray_dir. + NULL, // norm_ray_dir. + NULL, // grad_pix. + &ln_pad_over_1minuspad, + // in/out variables + &sm_d, + &sm_m, + result, + // Optional out. + &depth_threshold, + &intersection_depth, + NULL, + NULL, + NULL, + NULL, + NULL // gradients. + ); + if (hit) { + max_closest_possible_intersection_hit = FMAX( + max_closest_possible_intersection_hit, + closest_possible_intersection); + tracker.track( + sphere_id_l[draw_idx], intersection_depth, coord_x, coord_y); + } + } + } + if (coord_x < cam_norm.film_border_left || + coord_y < cam_norm.film_border_top || + coord_x >= cam_norm.film_border_left + cam_norm.film_width || + coord_y >= cam_norm.film_border_top + cam_norm.film_height) { + RETURN_PARALLEL(); + } + if (mode == 1u) { + // The subtractions, for example coord_y - cam_norm.film_border_left, are + // safe even though both components are uints. We checked their relation + // just above. + result_d + [(coord_y - cam_norm.film_border_top) * cam_norm.film_width * + cam_norm.n_channels + + (coord_x - cam_norm.film_border_left) * cam_norm.n_channels] = + static_cast(tracker.get_n_hits()); + } else { + float sm_d_normfac = FRCP(FMAX(sm_d, FEPS)); + for (uint c_id = 0; c_id < cam_norm.n_channels; ++c_id) + result[c_id] *= sm_d_normfac; + int write_loc = (coord_y - cam_norm.film_border_top) * cam_norm.film_width * + (3 + 2 * n_track) + + (coord_x - cam_norm.film_border_left) * (3 + 2 * n_track); + forw_info_d[write_loc] = sm_m; + forw_info_d[write_loc + 1] = sm_d; + forw_info_d[write_loc + 2] = max_closest_possible_intersection_hit; + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_RENDER_PIX, + "render|writing the %d most important ball infos.\n", + IMIN(n_track, tracker.get_n_hits())); + for (int i = 0; i < n_track; ++i) { + int sphere_id = tracker.get_closest_sphere_id(i); + IASF(sphere_id, forw_info_d[write_loc + 3 + i * 2]); + forw_info_d[write_loc + 3 + i * 2 + 1] = + tracker.get_closest_sphere_depth(i) == MAX_FLOAT + ? -1.f + : tracker.get_closest_sphere_depth(i); + PULSAR_LOG_DEV_PIX( + PULSAR_LOG_RENDER_PIX, + "render|writing %d most important: id: %d, normalized depth: %f.\n", + i, + tracker.get_closest_sphere_id(i), + tracker.get_closest_sphere_depth(i)); + } + } + END_PARALLEL_2D(); +} + +} // namespace Renderer +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/include/renderer.render.instantiate.h b/pytorch3d/csrc/pulsar/include/renderer.render.instantiate.h new file mode 100644 index 00000000..3379bae4 --- /dev/null +++ b/pytorch3d/csrc/pulsar/include/renderer.render.instantiate.h @@ -0,0 +1,39 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_RENDER_INSTANTIATE_H_ +#define PULSAR_NATIVE_INCLUDE_RENDERER_RENDER_INSTANTIATE_H_ + +#include "./renderer.render.device.h" + +namespace pulsar { +namespace Renderer { +template GLOBAL void render( + size_t const* const RESTRICT + num_balls, /** Number of balls relevant for this pass. */ + IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */ + DrawInfo const* const RESTRICT di_d, /** Draw information. */ + float const* const RESTRICT min_depth_d, /** Minimum depth per sphere. */ + int const* const RESTRICT id_d, /** IDs. */ + float const* const RESTRICT op_d, /** Opacity. */ + const CamInfo cam_norm, /** Camera normalized with all vectors to be in the + * camera coordinate system. + */ + const float gamma, /** Transparency parameter. **/ + const float percent_allowed_difference, /** Maximum allowed + error in color. */ + const uint max_n_hits, + const float* bg_col_d, + const uint mode, + const int x_min, + const int y_min, + const int x_step, + const int y_step, + // Out variables. + float* const RESTRICT result_d, /** The result image. */ + float* const RESTRICT forw_info_d, /** Additional information needed for the + grad computation. */ + const int n_track /** The number of spheres to track for backprop. */ +); +} +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/logging.h b/pytorch3d/csrc/pulsar/logging.h new file mode 100644 index 00000000..977e0c9d --- /dev/null +++ b/pytorch3d/csrc/pulsar/logging.h @@ -0,0 +1,108 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_LOGGING_H_ +#define PULSAR_LOGGING_H_ + +// #define PULSAR_LOGGING_ENABLED +/** + * Enable detailed per-operation timings. + * + * This timing scheme is not appropriate to measure batched calculations. + * Use `PULSAR_TIMINGS_BATCHED_ENABLED` for that. + */ +// #define PULSAR_TIMINGS_ENABLED +/** + * Time batched operations. + */ +// #define PULSAR_TIMINGS_BATCHED_ENABLED +#if defined(PULSAR_TIMINGS_BATCHED_ENABLED) && defined(PULSAR_TIMINGS_ENABLED) +#pragma message("Pulsar|batched and unbatched timings enabled. This will not") +#pragma message("Pulsar|create meaningful results.") +#endif + +#ifdef PULSAR_LOGGING_ENABLED + +// Control logging. +// 0: INFO, 1: WARNING, 2: ERROR, 3: FATAL (Abort after logging). +#define CAFFE2_LOG_THRESHOLD 0 +#define PULSAR_LOG_INIT false +#define PULSAR_LOG_FORWARD false +#define PULSAR_LOG_CALC_SIGNATURE false +#define PULSAR_LOG_RENDER false +#define PULSAR_LOG_RENDER_PIX false +#define PULSAR_LOG_RENDER_PIX_X 428 +#define PULSAR_LOG_RENDER_PIX_Y 669 +#define PULSAR_LOG_RENDER_PIX_ALL false +#define PULSAR_LOG_TRACKER_PIX false +#define PULSAR_LOG_TRACKER_PIX_X 428 +#define PULSAR_LOG_TRACKER_PIX_Y 669 +#define PULSAR_LOG_TRACKER_PIX_ALL false +#define PULSAR_LOG_DRAW_PIX false +#define PULSAR_LOG_DRAW_PIX_X 428 +#define PULSAR_LOG_DRAW_PIX_Y 669 +#define PULSAR_LOG_DRAW_PIX_ALL false +#define PULSAR_LOG_BACKWARD false +#define PULSAR_LOG_GRAD false +#define PULSAR_LOG_GRAD_X 509 +#define PULSAR_LOG_GRAD_Y 489 +#define PULSAR_LOG_GRAD_ALL false +#define PULSAR_LOG_NORMALIZE false +#define PULSAR_LOG_NORMALIZE_X 0 +#define PULSAR_LOG_NORMALIZE_ALL false + +#define PULSAR_LOG_DEV(ID, ...) \ + if ((ID)) { \ + printf(__VA_ARGS__); \ + } +#define PULSAR_LOG_DEV_APIX(ID, MSG, ...) \ + if ((ID) && (film_coord_x == (ID##_X) && film_coord_y == (ID##_Y)) || \ + ID##_ALL) { \ + printf( \ + "%u %u (ap %u %u)|" MSG, \ + film_coord_x, \ + film_coord_y, \ + ap_coord_x, \ + ap_coord_y, \ + __VA_ARGS__); \ + } +#define PULSAR_LOG_DEV_PIX(ID, MSG, ...) \ + if ((ID) && (coord_x == (ID##_X) && coord_y == (ID##_Y)) || ID##_ALL) { \ + printf("%u %u|" MSG, coord_x, coord_y, __VA_ARGS__); \ + } +#ifdef __CUDACC__ +#define PULSAR_LOG_DEV_PIXB(ID, MSG, ...) \ + if ((ID) && static_cast(block_area.min.x) <= (ID##_X) && \ + static_cast(block_area.max.x) > (ID##_X) && \ + static_cast(block_area.min.y) <= (ID##_Y) && \ + static_cast(block_area.max.y) > (ID##_Y)) { \ + printf("%u %u|" MSG, coord_x, coord_y, __VA_ARGS__); \ + } +#else +#define PULSAR_LOG_DEV_PIXB(ID, MSG, ...) \ + if ((ID) && coord_x == (ID##_X) && coord_y == (ID##_Y)) { \ + printf("%u %u|" MSG, coord_x, coord_y, __VA_ARGS__); \ + } +#endif +#define PULSAR_LOG_DEV_NODE(ID, MSG, ...) \ + if ((ID) && idx == (ID##_X) || (ID##_ALL)) { \ + printf("%u|" MSG, idx, __VA_ARGS__); \ + } + +#else + +#define CAFFE2_LOG_THRESHOLD 2 + +#define PULSAR_LOG_RENDER false +#define PULSAR_LOG_INIT false +#define PULSAR_LOG_FORWARD false +#define PULSAR_LOG_BACKWARD false +#define PULSAR_LOG_TRACKER_PIX false + +#define PULSAR_LOG_DEV(...) +#define PULSAR_LOG_DEV_APIX(...) +#define PULSAR_LOG_DEV_PIX(...) +#define PULSAR_LOG_DEV_PIXB(...) +#define PULSAR_LOG_DEV_NODE(...) + +#endif + +#endif diff --git a/pytorch3d/csrc/pulsar/pytorch/camera.cpp b/pytorch3d/csrc/pulsar/pytorch/camera.cpp new file mode 100644 index 00000000..3f1cfca4 --- /dev/null +++ b/pytorch3d/csrc/pulsar/pytorch/camera.cpp @@ -0,0 +1,63 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "./camera.h" +#include "../include/math.h" + +namespace pulsar { +namespace pytorch { + +CamInfo cam_info_from_params( + const torch::Tensor& cam_pos, + const torch::Tensor& pixel_0_0_center, + const torch::Tensor& pixel_vec_x, + const torch::Tensor& pixel_vec_y, + const torch::Tensor& principal_point_offset, + const float& focal_length, + const uint& width, + const uint& height, + const float& min_dist, + const float& max_dist, + const bool& right_handed) { + CamInfo res; + fill_cam_vecs( + cam_pos.detach().cpu(), + pixel_0_0_center.detach().cpu(), + pixel_vec_x.detach().cpu(), + pixel_vec_y.detach().cpu(), + principal_point_offset.detach().cpu(), + right_handed, + &res); + res.half_pixel_size = 0.5f * length(res.pixel_dir_x); + if (length(res.pixel_dir_y) * 0.5f - res.half_pixel_size > EPS) { + throw std::runtime_error("Pixel sizes must agree in x and y direction!"); + } + res.focal_length = focal_length; + res.aperture_width = + width + 2u * static_cast(abs(res.principal_point_offset_x)); + res.aperture_height = + height + 2u * static_cast(abs(res.principal_point_offset_y)); + res.pixel_0_0_center -= + res.pixel_dir_x * static_cast(abs(res.principal_point_offset_x)); + res.pixel_0_0_center -= + res.pixel_dir_y * static_cast(abs(res.principal_point_offset_y)); + res.film_width = width; + res.film_height = height; + res.film_border_left = + static_cast(std::max(0, 2 * res.principal_point_offset_x)); + res.film_border_top = + static_cast(std::max(0, 2 * res.principal_point_offset_y)); + LOG_IF(INFO, PULSAR_LOG_INIT) + << "Aperture width, height: " << res.aperture_width << ", " + << res.aperture_height; + LOG_IF(INFO, PULSAR_LOG_INIT) + << "Film width, height: " << res.film_width << ", " << res.film_height; + LOG_IF(INFO, PULSAR_LOG_INIT) + << "Film border left, top: " << res.film_border_left << ", " + << res.film_border_top; + res.min_dist = min_dist; + res.max_dist = max_dist; + res.norm_fac = 1.f / (max_dist - min_dist); + return res; +}; + +} // namespace pytorch +} // namespace pulsar diff --git a/pytorch3d/csrc/pulsar/pytorch/camera.h b/pytorch3d/csrc/pulsar/pytorch/camera.h new file mode 100644 index 00000000..f7312364 --- /dev/null +++ b/pytorch3d/csrc/pulsar/pytorch/camera.h @@ -0,0 +1,61 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_CAMERA_H_ +#define PULSAR_NATIVE_CAMERA_H_ + +#include +#include "../global.h" + +#include "../include/camera.h" + +namespace pulsar { +namespace pytorch { + +inline void fill_cam_vecs( + const torch::Tensor& pos_vec, + const torch::Tensor& pixel_0_0_center, + const torch::Tensor& pixel_dir_x, + const torch::Tensor& pixel_dir_y, + const torch::Tensor& principal_point_offset, + const bool& right_handed, + CamInfo* res) { + res->eye.x = pos_vec.data_ptr()[0]; + res->eye.y = pos_vec.data_ptr()[1]; + res->eye.z = pos_vec.data_ptr()[2]; + res->pixel_0_0_center.x = pixel_0_0_center.data_ptr()[0]; + res->pixel_0_0_center.y = pixel_0_0_center.data_ptr()[1]; + res->pixel_0_0_center.z = pixel_0_0_center.data_ptr()[2]; + res->pixel_dir_x.x = pixel_dir_x.data_ptr()[0]; + res->pixel_dir_x.y = pixel_dir_x.data_ptr()[1]; + res->pixel_dir_x.z = pixel_dir_x.data_ptr()[2]; + res->pixel_dir_y.x = pixel_dir_y.data_ptr()[0]; + res->pixel_dir_y.y = pixel_dir_y.data_ptr()[1]; + res->pixel_dir_y.z = pixel_dir_y.data_ptr()[2]; + auto sensor_dir_z = pixel_dir_y.cross(pixel_dir_x); + sensor_dir_z /= sensor_dir_z.norm(); + if (right_handed) { + sensor_dir_z *= -1.f; + } + res->sensor_dir_z.x = sensor_dir_z.data_ptr()[0]; + res->sensor_dir_z.y = sensor_dir_z.data_ptr()[1]; + res->sensor_dir_z.z = sensor_dir_z.data_ptr()[2]; + res->principal_point_offset_x = principal_point_offset.data_ptr()[0]; + res->principal_point_offset_y = principal_point_offset.data_ptr()[1]; +} + +CamInfo cam_info_from_params( + const torch::Tensor& cam_pos, + const torch::Tensor& pixel_0_0_center, + const torch::Tensor& pixel_vec_x, + const torch::Tensor& pixel_vec_y, + const torch::Tensor& principal_point_offset, + const float& focal_length, + const uint& width, + const uint& height, + const float& min_dist, + const float& max_dist, + const bool& right_handed); + +} // namespace pytorch +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/pytorch/renderer.cpp b/pytorch3d/csrc/pulsar/pytorch/renderer.cpp new file mode 100644 index 00000000..b44f6dc5 --- /dev/null +++ b/pytorch3d/csrc/pulsar/pytorch/renderer.cpp @@ -0,0 +1,1481 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "./renderer.h" +#include "../include/commands.h" +#include "./camera.h" +#include "./util.h" + +#include +#include +#include + +namespace PRE = ::pulsar::Renderer; + +namespace pulsar { +namespace pytorch { + +Renderer::Renderer( + const unsigned int& width, + const unsigned int& height, + const unsigned int& max_n_balls, + const bool& orthogonal_projection, + const bool& right_handed_system, + const float& background_normalization_depth, + const uint& n_channels, + const uint& n_track) { + LOG_IF(INFO, PULSAR_LOG_INIT) << "Initializing renderer."; + THArgCheck(width > 0, 1, "image width must be > 0!"); + THArgCheck(height > 0, 2, "image height must be > 0!"); + THArgCheck(max_n_balls > 0, 3, "max_n_balls must be > 0!"); + THArgCheck( + background_normalization_depth > 0.f && + background_normalization_depth < 1.f, + 5, + "background_normalization_depth must be in ]0., 1.["); + THArgCheck(n_channels > 0, 6, "n_channels must be > 0"); + THArgCheck( + n_track > 0 && n_track <= MAX_GRAD_SPHERES, + 7, + ("n_track must be > 0 and <" + std::to_string(MAX_GRAD_SPHERES) + + ". Is " + std::to_string(n_track) + ".") + .c_str()); + LOG_IF(INFO, PULSAR_LOG_INIT) + << "Image width: " << width << ", height: " << height; + this->renderer_vec.emplace_back(); + this->device_type = c10::DeviceType::CPU; + this->device_index = -1; + PRE::construct( + this->renderer_vec.data(), + max_n_balls, + width, + height, + orthogonal_projection, + right_handed_system, + background_normalization_depth, + n_channels, + n_track); + this->device_tracker = torch::zeros(1); +}; + +Renderer::~Renderer() { + if (this->device_type == c10::DeviceType::CUDA) { + at::cuda::CUDAGuard device_guard(this->device_tracker.device()); + for (auto nrend : this->renderer_vec) { + PRE::destruct(&nrend); + } + } else { + for (auto nrend : this->renderer_vec) { + PRE::destruct(&nrend); + } + } +} + +bool Renderer::operator==(const Renderer& rhs) const { + LOG_IF(INFO, PULSAR_LOG_INIT) << "Equality check."; + bool renderer_agrees = (this->renderer_vec[0] == rhs.renderer_vec[0]); + LOG_IF(INFO, PULSAR_LOG_INIT) << " Renderer agrees: " << renderer_agrees; + bool device_agrees = + (this->device_tracker.device() == rhs.device_tracker.device()); + LOG_IF(INFO, PULSAR_LOG_INIT) << " Device agrees: " << device_agrees; + return (renderer_agrees && device_agrees); +}; + +void Renderer::ensure_on_device(torch::Device device, bool /*non_blocking*/) { + THArgCheck( + device.type() == c10::DeviceType::CUDA || + device.type() == c10::DeviceType::CPU, + 1, + "Only CPU and CUDA device types are supported."); + if (device.type() != this->device_type || + device.index() != this->device_index) { + LOG_IF(INFO, PULSAR_LOG_INIT) + << "Transferring render buffers between devices."; + int prev_active; + cudaGetDevice(&prev_active); + if (this->device_type == c10::DeviceType::CUDA) { + LOG_IF(INFO, PULSAR_LOG_INIT) << " Destructing on CUDA."; + cudaSetDevice(this->device_index); + for (auto& nrend : this->renderer_vec) { + PRE::destruct(&nrend); + } + } else { + LOG_IF(INFO, PULSAR_LOG_INIT) << " Destructing on CPU."; + for (auto& nrend : this->renderer_vec) { + PRE::destruct(&nrend); + } + } + if (device.type() == c10::DeviceType::CUDA) { + LOG_IF(INFO, PULSAR_LOG_INIT) << " Constructing on CUDA."; + cudaSetDevice(device.index()); + for (auto& nrend : this->renderer_vec) { + PRE::construct( + &nrend, + this->renderer_vec[0].max_num_balls, + this->renderer_vec[0].cam.film_width, + this->renderer_vec[0].cam.film_height, + this->renderer_vec[0].cam.orthogonal_projection, + this->renderer_vec[0].cam.right_handed, + this->renderer_vec[0].cam.background_normalization_depth, + this->renderer_vec[0].cam.n_channels, + this->n_track()); + } + } else { + LOG_IF(INFO, PULSAR_LOG_INIT) << " Constructing on CPU."; + for (auto& nrend : this->renderer_vec) { + PRE::construct( + &nrend, + this->renderer_vec[0].max_num_balls, + this->renderer_vec[0].cam.film_width, + this->renderer_vec[0].cam.film_height, + this->renderer_vec[0].cam.orthogonal_projection, + this->renderer_vec[0].cam.right_handed, + this->renderer_vec[0].cam.background_normalization_depth, + this->renderer_vec[0].cam.n_channels, + this->n_track()); + } + } + cudaSetDevice(prev_active); + this->device_type = device.type(); + this->device_index = device.index(); + } +}; + +void Renderer::ensure_n_renderers_gte(const size_t& batch_size) { + if (this->renderer_vec.size() < batch_size) { + ptrdiff_t diff = batch_size - this->renderer_vec.size(); + LOG_IF(INFO, PULSAR_LOG_INIT) + << "Increasing render buffers by " << diff + << " to account for batch size " << batch_size; + for (ptrdiff_t i = 0; i < diff; ++i) { + this->renderer_vec.emplace_back(); + if (this->device_type == c10::DeviceType::CUDA) { + PRE::construct( + &this->renderer_vec[this->renderer_vec.size() - 1], + this->max_num_balls(), + this->width(), + this->height(), + this->renderer_vec[0].cam.orthogonal_projection, + this->renderer_vec[0].cam.right_handed, + this->renderer_vec[0].cam.background_normalization_depth, + this->renderer_vec[0].cam.n_channels, + this->n_track()); + } else { + PRE::construct( + &this->renderer_vec[this->renderer_vec.size() - 1], + this->max_num_balls(), + this->width(), + this->height(), + this->renderer_vec[0].cam.orthogonal_projection, + this->renderer_vec[0].cam.right_handed, + this->renderer_vec[0].cam.background_normalization_depth, + this->renderer_vec[0].cam.n_channels, + this->n_track()); + } + } + } +} + +std::tuple Renderer::arg_check( + const torch::Tensor& vert_pos, + const torch::Tensor& vert_col, + const torch::Tensor& vert_radii, + const torch::Tensor& cam_pos, + const torch::Tensor& pixel_0_0_center, + const torch::Tensor& pixel_vec_x, + const torch::Tensor& pixel_vec_y, + const torch::Tensor& focal_length, + const torch::Tensor& principal_point_offsets, + const float& gamma, + const float& max_depth, + float& min_depth, + const c10::optional& bg_col, + const c10::optional& opacity, + const float& percent_allowed_difference, + const uint& max_n_hits, + const uint& mode) { + LOG_IF(INFO, PULSAR_LOG_FORWARD || PULSAR_LOG_BACKWARD) << "Arg check."; + size_t batch_size = 1; + size_t n_points; + bool batch_processing = false; + if (vert_pos.ndimension() == 3) { + // Check all parameters adhere batch size. + batch_processing = true; + batch_size = vert_pos.size(0); + THArgCheck( + vert_col.ndimension() == 3 && vert_col.size(0) == batch_size, + 2, + "vert_col needs to have batch size."); + THArgCheck( + vert_radii.ndimension() == 2 && vert_radii.size(0) == batch_size, + 3, + "vert_radii must be specified per batch."); + THArgCheck( + cam_pos.ndimension() == 2 && cam_pos.size(0) == batch_size, + 4, + "cam_pos must be specified per batch and have the correct batch size."); + THArgCheck( + pixel_0_0_center.ndimension() == 2 && + pixel_0_0_center.size(0) == batch_size, + 5, + "pixel_0_0_center must be specified per batch."); + THArgCheck( + pixel_vec_x.ndimension() == 2 && pixel_vec_x.size(0) == batch_size, + 6, + "pixel_vec_x must be specified per batch."); + THArgCheck( + pixel_vec_y.ndimension() == 2 && pixel_vec_y.size(0) == batch_size, + 7, + "pixel_vec_y must be specified per batch."); + THArgCheck( + focal_length.ndimension() == 1 && focal_length.size(0) == batch_size, + 8, + "focal_length must be specified per batch."); + THArgCheck( + principal_point_offsets.ndimension() == 2 && + principal_point_offsets.size(0) == batch_size, + 9, + "principal_point_offsets must be specified per batch."); + if (opacity.has_value()) { + THArgCheck( + opacity.value().ndimension() == 2 && + opacity.value().size(0) == batch_size, + 13, + "Opacity needs to be specified batch-wise."); + } + // Check all parameters are for a matching number of points. + n_points = vert_pos.size(1); + THArgCheck( + vert_col.size(1) == n_points, + 2, + ("The number of points for vertex positions (" + + std::to_string(n_points) + ") and vertex colors (" + + std::to_string(vert_col.size(1)) + ") doesn't agree.") + .c_str()); + THArgCheck( + vert_radii.size(1) == n_points, + 3, + ("The number of points for vertex positions (" + + std::to_string(n_points) + ") and vertex radii (" + + std::to_string(vert_col.size(1)) + ") doesn't agree.") + .c_str()); + if (opacity.has_value()) { + THArgCheck( + opacity.value().size(1) == n_points, + 13, + "Opacity needs to be specified per point."); + } + // Check all parameters have the correct last dimension size. + THArgCheck( + vert_pos.size(2) == 3, + 1, + ("Vertex positions must be 3D (have shape " + + std::to_string(vert_pos.size(2)) + ")!") + .c_str()); + THArgCheck( + vert_col.size(2) == this->renderer_vec[0].cam.n_channels, + 2, + ("Vertex colors must have the right number of channels (have shape " + + std::to_string(vert_col.size(2)) + ", need " + + std::to_string(this->renderer_vec[0].cam.n_channels) + ")!") + .c_str()); + THArgCheck( + cam_pos.size(1) == 3, + 4, + ("Camera position must be 3D (has shape " + + std::to_string(cam_pos.size(1)) + ")!") + .c_str()); + THArgCheck( + pixel_0_0_center.size(1) == 3, + 5, + ("pixel_0_0_center must be 3D (has shape " + + std::to_string(pixel_0_0_center.size(1)) + ")!") + .c_str()); + THArgCheck( + pixel_vec_x.size(1) == 3, + 6, + ("pixel_vec_x must be 3D (has shape " + + std::to_string(pixel_vec_x.size(1)) + ")!") + .c_str()); + THArgCheck( + pixel_vec_y.size(1) == 3, + 7, + ("pixel_vec_y must be 3D (has shape " + + std::to_string(pixel_vec_y.size(1)) + ")!") + .c_str()); + THArgCheck( + principal_point_offsets.size(1) == 2, + 9, + "principal_point_offsets must contain x and y offsets."); + // Ensure enough renderers are available for the batch. + ensure_n_renderers_gte(batch_size); + } else { + // Check all parameters are of correct dimension. + THArgCheck( + vert_col.ndimension() == 2, 2, "vert_col needs to have dimension 2."); + THArgCheck( + vert_radii.ndimension() == 1, 3, "vert_radii must have dimension 1."); + THArgCheck(cam_pos.ndimension() == 1, 4, "cam_pos must have dimension 1."); + THArgCheck( + pixel_0_0_center.ndimension() == 1, + 5, + "pixel_0_0_center must have dimension 1."); + THArgCheck( + pixel_vec_x.ndimension() == 1, 6, "pixel_vec_x must have dimension 1."); + THArgCheck( + pixel_vec_y.ndimension() == 1, 7, "pixel_vec_y must have dimension 1."); + THArgCheck( + focal_length.ndimension() == 0, + 8, + "focal_length must have dimension 0."); + THArgCheck( + principal_point_offsets.ndimension() == 1, + 9, + "principal_point_offsets must have dimension 1."); + if (opacity.has_value()) { + THArgCheck( + opacity.value().ndimension() == 1, + 13, + "Opacity needs to be specified per sample."); + } + // Check each. + n_points = vert_pos.size(0); + THArgCheck( + vert_col.size(0) == n_points, + 2, + ("The number of points for vertex positions (" + + std::to_string(n_points) + ") and vertex colors (" + + std::to_string(vert_col.size(0)) + ") doesn't agree.") + .c_str()); + THArgCheck( + vert_radii.size(0) == n_points, + 3, + ("The number of points for vertex positions (" + + std::to_string(n_points) + ") and vertex radii (" + + std::to_string(vert_col.size(0)) + ") doesn't agree.") + .c_str()); + if (opacity.has_value()) { + THArgCheck( + opacity.value().size(0) == n_points, + 12, + "Opacity needs to be specified per point."); + } + // Check all parameters have the correct last dimension size. + THArgCheck( + vert_pos.size(1) == 3, + 1, + ("Vertex positions must be 3D (have shape " + + std::to_string(vert_pos.size(1)) + ")!") + .c_str()); + THArgCheck( + vert_col.size(1) == this->renderer_vec[0].cam.n_channels, + 2, + ("Vertex colors must have the right number of channels (have shape " + + std::to_string(vert_col.size(1)) + ", need " + + std::to_string(this->renderer_vec[0].cam.n_channels) + ")!") + .c_str()); + THArgCheck( + cam_pos.size(0) == 3, + 4, + ("Camera position must be 3D (has shape " + + std::to_string(cam_pos.size(0)) + ")!") + .c_str()); + THArgCheck( + pixel_0_0_center.size(0) == 3, + 5, + ("pixel_0_0_center must be 3D (has shape " + + std::to_string(pixel_0_0_center.size(0)) + ")!") + .c_str()); + THArgCheck( + pixel_vec_x.size(0) == 3, + 6, + ("pixel_vec_x must be 3D (has shape " + + std::to_string(pixel_vec_x.size(0)) + ")!") + .c_str()); + THArgCheck( + pixel_vec_y.size(0) == 3, + 7, + ("pixel_vec_y must be 3D (has shape " + + std::to_string(pixel_vec_y.size(0)) + ")!") + .c_str()); + THArgCheck( + principal_point_offsets.size(0) == 2, + 9, + "principal_point_offsets must have x and y component."); + } + // Check device placement. + auto dev = torch::device_of(vert_pos).value(); + THArgCheck( + dev.type() == this->device_type && dev.index() == this->device_index, + 1, + ("Vertex positions must be stored on device " + + c10::DeviceTypeName(this->device_type) + ", index " + + std::to_string(this->device_index) + "! Are stored on " + + c10::DeviceTypeName(dev.type()) + ", index " + + std::to_string(dev.index()) + ".") + .c_str()); + dev = torch::device_of(vert_col).value(); + THArgCheck( + dev.type() == this->device_type && dev.index() == this->device_index, + 2, + ("Vertex colors must be stored on device " + + c10::DeviceTypeName(this->device_type) + ", index " + + std::to_string(this->device_index) + "! Are stored on " + + c10::DeviceTypeName(dev.type()) + ", index " + + std::to_string(dev.index()) + ".") + .c_str()); + dev = torch::device_of(vert_radii).value(); + THArgCheck( + dev.type() == this->device_type && dev.index() == this->device_index, + 3, + ("Vertex radii must be stored on device " + + c10::DeviceTypeName(this->device_type) + ", index " + + std::to_string(this->device_index) + "! Are stored on " + + c10::DeviceTypeName(dev.type()) + ", index " + + std::to_string(dev.index()) + ".") + .c_str()); + dev = torch::device_of(cam_pos).value(); + THArgCheck( + dev.type() == this->device_type && dev.index() == this->device_index, + 4, + ("Camera position must be stored on device " + + c10::DeviceTypeName(this->device_type) + ", index " + + std::to_string(this->device_index) + "! Are stored on " + + c10::DeviceTypeName(dev.type()) + ", index " + + std::to_string(dev.index()) + ".") + .c_str()); + dev = torch::device_of(pixel_0_0_center).value(); + THArgCheck( + dev.type() == this->device_type && dev.index() == this->device_index, + 5, + ("pixel_0_0_center must be stored on device " + + c10::DeviceTypeName(this->device_type) + ", index " + + std::to_string(this->device_index) + "! Are stored on " + + c10::DeviceTypeName(dev.type()) + ", index " + + std::to_string(dev.index()) + ".") + .c_str()); + dev = torch::device_of(pixel_vec_x).value(); + THArgCheck( + dev.type() == this->device_type && dev.index() == this->device_index, + 6, + ("pixel_vec_x must be stored on device " + + c10::DeviceTypeName(this->device_type) + ", index " + + std::to_string(this->device_index) + "! Are stored on " + + c10::DeviceTypeName(dev.type()) + ", index " + + std::to_string(dev.index()) + ".") + .c_str()); + dev = torch::device_of(pixel_vec_y).value(); + THArgCheck( + dev.type() == this->device_type && dev.index() == this->device_index, + 7, + ("pixel_vec_y must be stored on device " + + c10::DeviceTypeName(this->device_type) + ", index " + + std::to_string(this->device_index) + "! Are stored on " + + c10::DeviceTypeName(dev.type()) + ", index " + + std::to_string(dev.index()) + ".") + .c_str()); + dev = torch::device_of(principal_point_offsets).value(); + THArgCheck( + dev.type() == this->device_type && dev.index() == this->device_index, + 9, + ("principal_point_offsets must be stored on device " + + c10::DeviceTypeName(this->device_type) + ", index " + + std::to_string(this->device_index) + "! Are stored on " + + c10::DeviceTypeName(dev.type()) + ", index " + + std::to_string(dev.index()) + ".") + .c_str()); + if (opacity.has_value()) { + dev = torch::device_of(opacity.value()).value(); + THArgCheck( + dev.type() == this->device_type && dev.index() == this->device_index, + 13, + ("opacity must be stored on device " + + c10::DeviceTypeName(this->device_type) + ", index " + + std::to_string(this->device_index) + "! Is stored on " + + c10::DeviceTypeName(dev.type()) + ", index " + + std::to_string(dev.index()) + ".") + .c_str()); + } + // Type checks. + THArgCheck( + vert_pos.scalar_type() == c10::kFloat, 1, "pulsar requires float types."); + THArgCheck( + vert_col.scalar_type() == c10::kFloat, 2, "pulsar requires float types."); + THArgCheck( + vert_radii.scalar_type() == c10::kFloat, + 3, + "pulsar requires float types."); + THArgCheck( + cam_pos.scalar_type() == c10::kFloat, 4, "pulsar requires float types."); + THArgCheck( + pixel_0_0_center.scalar_type() == c10::kFloat, + 5, + "pulsar requires float types."); + THArgCheck( + pixel_vec_x.scalar_type() == c10::kFloat, + 6, + "pulsar requires float types."); + THArgCheck( + pixel_vec_y.scalar_type() == c10::kFloat, + 7, + "pulsar requires float types."); + THArgCheck( + focal_length.scalar_type() == c10::kFloat, + 8, + "pulsar requires float types."); + THArgCheck( + // Unfortunately, the PyTorch interface is inconsistent for + // Int32: in Python, there exists an explicit int32 type, in + // C++ this is currently `c10::kInt`. + principal_point_offsets.scalar_type() == c10::kInt, + 9, + "principal_point_offsets must be provided as int32."); + if (opacity.has_value()) { + THArgCheck( + opacity.value().scalar_type() == c10::kFloat, + 13, + "opacity must be a float type."); + } + // Content checks. + THArgCheck( + (vert_radii > FEPS).all().item(), + 3, + ("Vertex radii must be > FEPS (min is " + + std::to_string(vert_radii.min().item()) + ").") + .c_str()); + if (this->orthogonal()) { + THArgCheck( + (focal_length == 0.f).all().item(), + 8, + ("for an orthogonal projection focal length must be zero (abs max: " + + std::to_string(focal_length.abs().max().item()) + ").") + .c_str()); + } else { + THArgCheck( + (focal_length > FEPS).all().item(), + 8, + ("for a perspective projection focal length must be > FEPS (min " + + std::to_string(focal_length.min().item()) + ").") + .c_str()); + } + THArgCheck( + gamma <= 1.f && gamma >= 1E-5f, + 10, + ("gamma must be in [1E-5, 1] (" + std::to_string(gamma) + ").").c_str()); + if (min_depth == 0.f) { + min_depth = focal_length.max().item() + 2.f * FEPS; + } + THArgCheck( + min_depth > focal_length.max().item(), + 12, + ("min_depth must be > focal_length (" + std::to_string(min_depth) + + " vs. " + std::to_string(focal_length.max().item()) + ").") + .c_str()); + THArgCheck( + max_depth > min_depth + FEPS, + 11, + ("max_depth must be > min_depth + FEPS (" + std::to_string(max_depth) + + " vs. " + std::to_string(min_depth + FEPS) + ").") + .c_str()); + THArgCheck( + percent_allowed_difference >= 0.f && percent_allowed_difference < 1.f, + 14, + ("percent_allowed_difference must be in [0., 1.[ (" + + std::to_string(percent_allowed_difference) + ").") + .c_str()); + THArgCheck(max_n_hits > 0, 14, "max_n_hits must be > 0!"); + THArgCheck(mode < 2, 15, "mode must be in {0, 1}."); + torch::Tensor real_bg_col; + if (bg_col.has_value()) { + THArgCheck( + bg_col.value().device().type() == this->device_type && + bg_col.value().device().index() == this->device_index, + 13, + "bg_col must be stored on the renderer device!"); + THArgCheck( + bg_col.value().ndimension() == 1 && + bg_col.value().size(0) == renderer_vec[0].cam.n_channels, + 13, + "bg_col must have the same number of channels as the image,)."); + real_bg_col = bg_col.value(); + } else { + real_bg_col = torch::ones( + {renderer_vec[0].cam.n_channels}, + c10::Device(this->device_type, this->device_index)) + .to(c10::kFloat); + } + if (opacity.has_value()) { + THArgCheck( + (opacity.value() >= 0.f).all().item(), + 13, + "opacity must be >= 0."); + THArgCheck( + (opacity.value() <= 1.f).all().item(), + 13, + "opacity must be <= 1."); + } + LOG_IF(INFO, PULSAR_LOG_FORWARD || PULSAR_LOG_BACKWARD) + << " batch_size: " << batch_size; + LOG_IF(INFO, PULSAR_LOG_FORWARD || PULSAR_LOG_BACKWARD) + << " n_points: " << n_points; + LOG_IF(INFO, PULSAR_LOG_FORWARD || PULSAR_LOG_BACKWARD) + << " batch_processing: " << batch_processing; + return std::tuple( + batch_size, n_points, batch_processing, real_bg_col); +} + +std::tuple Renderer::forward( + const torch::Tensor& vert_pos, + const torch::Tensor& vert_col, + const torch::Tensor& vert_radii, + const torch::Tensor& cam_pos, + const torch::Tensor& pixel_0_0_center, + const torch::Tensor& pixel_vec_x, + const torch::Tensor& pixel_vec_y, + const torch::Tensor& focal_length, + const torch::Tensor& principal_point_offsets, + const float& gamma, + const float& max_depth, + float min_depth, + const c10::optional& bg_col, + const c10::optional& opacity, + const float& percent_allowed_difference, + const uint& max_n_hits, + const uint& mode) { + // Parameter checks. + this->ensure_on_device(this->device_tracker.device()); + size_t batch_size; + size_t n_points; + bool batch_processing; + torch::Tensor real_bg_col; + std::tie(batch_size, n_points, batch_processing, real_bg_col) = + this->arg_check( + vert_pos, + vert_col, + vert_radii, + cam_pos, + pixel_0_0_center, + pixel_vec_x, + pixel_vec_y, + focal_length, + principal_point_offsets, + gamma, + max_depth, + min_depth, + bg_col, + opacity, + percent_allowed_difference, + max_n_hits, + mode); + LOG_IF(INFO, PULSAR_LOG_FORWARD) << "Extracting camera objects..."; + // Create the camera information. + std::vector cam_infos(batch_size); + if (batch_processing) { + for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) { + cam_infos[batch_i] = cam_info_from_params( + cam_pos[batch_i], + pixel_0_0_center[batch_i], + pixel_vec_x[batch_i], + pixel_vec_y[batch_i], + principal_point_offsets[batch_i], + focal_length[batch_i].item(), + this->renderer_vec[0].cam.film_width, + this->renderer_vec[0].cam.film_height, + min_depth, + max_depth, + this->renderer_vec[0].cam.right_handed); + } + } else { + cam_infos[0] = cam_info_from_params( + cam_pos, + pixel_0_0_center, + pixel_vec_x, + pixel_vec_y, + principal_point_offsets, + focal_length.item(), + this->renderer_vec[0].cam.film_width, + this->renderer_vec[0].cam.film_height, + min_depth, + max_depth, + this->renderer_vec[0].cam.right_handed); + } + LOG_IF(INFO, PULSAR_LOG_FORWARD) << "Processing..."; + // Let's go! + // Contiguous version of opacity, if available. We need to create this object + // in scope to keep it alive. + torch::Tensor opacity_contiguous; + float const* opacity_ptr = nullptr; + if (opacity.has_value()) { + opacity_contiguous = opacity.value().contiguous(); + opacity_ptr = opacity_contiguous.data_ptr(); + } + if (this->device_type == c10::DeviceType::CUDA) { + int prev_active; + cudaGetDevice(&prev_active); + cudaSetDevice(this->device_index); +#ifdef PULSAR_TIMINGS_BATCHED_ENABLED + START_TIME_CU(batch_forward); +#endif + if (batch_processing) { + for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) { + // These calls are non-blocking and just kick off the computations. + PRE::forward( + &this->renderer_vec[batch_i], + vert_pos[batch_i].contiguous().data_ptr(), + vert_col[batch_i].contiguous().data_ptr(), + vert_radii[batch_i].contiguous().data_ptr(), + cam_infos[batch_i], + gamma, + percent_allowed_difference, + max_n_hits, + real_bg_col.contiguous().data_ptr(), + opacity_ptr, + n_points, + mode, + at::cuda::getCurrentCUDAStream()); + } + } else { + PRE::forward( + this->renderer_vec.data(), + vert_pos.contiguous().data_ptr(), + vert_col.contiguous().data_ptr(), + vert_radii.contiguous().data_ptr(), + cam_infos[0], + gamma, + percent_allowed_difference, + max_n_hits, + real_bg_col.contiguous().data_ptr(), + opacity_ptr, + n_points, + mode, + at::cuda::getCurrentCUDAStream()); + } +#ifdef PULSAR_TIMINGS_BATCHED_ENABLED + STOP_TIME_CU(batch_forward); + float time_ms; + GET_TIME_CU(batch_forward, &time_ms); + std::cout << "Forward render batched time per example: " + << time_ms / static_cast(batch_size) << "ms" << std::endl; +#endif + cudaSetDevice(prev_active); + } else { +#ifdef PULSAR_TIMINGS_BATCHED_ENABLED + START_TIME(batch_forward); +#endif + if (batch_processing) { + for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) { + // These calls are non-blocking and just kick off the computations. + PRE::forward( + &this->renderer_vec[batch_i], + vert_pos[batch_i].contiguous().data_ptr(), + vert_col[batch_i].contiguous().data_ptr(), + vert_radii[batch_i].contiguous().data_ptr(), + cam_infos[batch_i], + gamma, + percent_allowed_difference, + max_n_hits, + real_bg_col.contiguous().data_ptr(), + opacity_ptr, + n_points, + mode, + nullptr); + } + } else { + PRE::forward( + this->renderer_vec.data(), + vert_pos.contiguous().data_ptr(), + vert_col.contiguous().data_ptr(), + vert_radii.contiguous().data_ptr(), + cam_infos[0], + gamma, + percent_allowed_difference, + max_n_hits, + real_bg_col.contiguous().data_ptr(), + opacity_ptr, + n_points, + mode, + nullptr); + } +#ifdef PULSAR_TIMINGS_BATCHED_ENABLED + STOP_TIME(batch_forward); + float time_ms; + GET_TIME(batch_forward, &time_ms); + std::cout << "Forward render batched time per example: " + << time_ms / static_cast(batch_size) << "ms" << std::endl; +#endif + } + LOG_IF(INFO, PULSAR_LOG_FORWARD) << "Extracting results..."; + // Create the results. + std::vector results(batch_size); + std::vector forw_infos(batch_size); + for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) { + results[batch_i] = from_blob( + this->renderer_vec[batch_i].result_d, + {this->renderer_vec[0].cam.film_height, + this->renderer_vec[0].cam.film_width, + this->renderer_vec[0].cam.n_channels}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + if (mode == 1) + results[batch_i] = results[batch_i].slice(2, 0, 1, 1); + forw_infos[batch_i] = from_blob( + this->renderer_vec[batch_i].forw_info_d, + {this->renderer_vec[0].cam.film_height, + this->renderer_vec[0].cam.film_width, + 3 + 2 * this->n_track()}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + } + LOG_IF(INFO, PULSAR_LOG_FORWARD) << "Forward render complete."; + if (batch_processing) { + return std::tuple( + torch::stack(results), torch::stack(forw_infos)); + } else { + return std::tuple(results[0], forw_infos[0]); + } +}; + +std::tuple< + at::optional, + at::optional, + at::optional, + at::optional, + at::optional, + at::optional, + at::optional, + at::optional> +Renderer::backward( + const torch::Tensor& grad_im, + const torch::Tensor& image, + const torch::Tensor& forw_info, + const torch::Tensor& vert_pos, + const torch::Tensor& vert_col, + const torch::Tensor& vert_radii, + const torch::Tensor& cam_pos, + const torch::Tensor& pixel_0_0_center, + const torch::Tensor& pixel_vec_x, + const torch::Tensor& pixel_vec_y, + const torch::Tensor& focal_length, + const torch::Tensor& principal_point_offsets, + const float& gamma, + const float& max_depth, + float min_depth, + const c10::optional& bg_col, + const c10::optional& opacity, + const float& percent_allowed_difference, + const uint& max_n_hits, + const uint& mode, + const bool& dif_pos, + const bool& dif_col, + const bool& dif_rad, + const bool& dif_cam, + const bool& dif_opy, + const at::optional>& dbg_pos) { + this->ensure_on_device(this->device_tracker.device()); + size_t batch_size; + size_t n_points; + bool batch_processing; + torch::Tensor real_bg_col; + std::tie(batch_size, n_points, batch_processing, real_bg_col) = + this->arg_check( + vert_pos, + vert_col, + vert_radii, + cam_pos, + pixel_0_0_center, + pixel_vec_x, + pixel_vec_y, + focal_length, + principal_point_offsets, + gamma, + max_depth, + min_depth, + bg_col, + opacity, + percent_allowed_difference, + max_n_hits, + mode); + // Additional checks for the gradient computation. + THArgCheck( + (grad_im.ndimension() == 3 + batch_processing && + static_cast(grad_im.size(0 + batch_processing)) == + this->height() && + static_cast(grad_im.size(1 + batch_processing)) == this->width() && + static_cast(grad_im.size(2 + batch_processing)) == + this->renderer_vec[0].cam.n_channels), + 1, + "The gradient image size is not correct."); + THArgCheck( + (image.ndimension() == 3 + batch_processing && + static_cast(image.size(0 + batch_processing)) == this->height() && + static_cast(image.size(1 + batch_processing)) == this->width() && + static_cast(image.size(2 + batch_processing)) == + this->renderer_vec[0].cam.n_channels), + 2, + "The result image size is not correct."); + THArgCheck( + grad_im.scalar_type() == c10::kFloat, + 1, + "The gradient image must be of float type."); + THArgCheck( + image.scalar_type() == c10::kFloat, + 2, + "The image must be of float type."); + if (dif_opy) { + THArgCheck(opacity.has_value(), 13, "dif_opy set requires opacity values."); + } + if (batch_processing) { + THArgCheck( + grad_im.size(0) == batch_size, + 1, + "Gradient image batch size must agree."); + THArgCheck(image.size(0) == batch_size, 2, "Image batch size must agree."); + THArgCheck( + forw_info.size(0) == batch_size, + 3, + "forward info must have batch size."); + } + THArgCheck( + (forw_info.ndimension() == 3 + batch_processing && + static_cast(forw_info.size(0 + batch_processing)) == + this->height() && + static_cast(forw_info.size(1 + batch_processing)) == + this->width() && + static_cast(forw_info.size(2 + batch_processing)) == + 3 + 2 * this->n_track()), + 3, + "The forward info image size is not correct."); + THArgCheck( + forw_info.scalar_type() == c10::kFloat, + 3, + "The forward info must be of float type."); + // Check device. + auto dev = torch::device_of(grad_im).value(); + THArgCheck( + dev.type() == this->device_type && dev.index() == this->device_index, + 1, + ("grad_im must be stored on device " + + c10::DeviceTypeName(this->device_type) + ", index " + + std::to_string(this->device_index) + "! Are stored on " + + c10::DeviceTypeName(dev.type()) + ", index " + + std::to_string(dev.index()) + ".") + .c_str()); + dev = torch::device_of(image).value(); + THArgCheck( + dev.type() == this->device_type && dev.index() == this->device_index, + 2, + ("image must be stored on device " + + c10::DeviceTypeName(this->device_type) + ", index " + + std::to_string(this->device_index) + "! Are stored on " + + c10::DeviceTypeName(dev.type()) + ", index " + + std::to_string(dev.index()) + ".") + .c_str()); + dev = torch::device_of(forw_info).value(); + THArgCheck( + dev.type() == this->device_type && dev.index() == this->device_index, + 3, + ("forw_info must be stored on device " + + c10::DeviceTypeName(this->device_type) + ", index " + + std::to_string(this->device_index) + "! Are stored on " + + c10::DeviceTypeName(dev.type()) + ", index " + + std::to_string(dev.index()) + ".") + .c_str()); + if (dbg_pos.has_value()) { + THArgCheck( + dbg_pos.value().first < this->width() && + dbg_pos.value().second < this->height(), + 23, + "The debug position must be within image bounds."); + } + // Prepare the return value. + std::tuple< + at::optional, + at::optional, + at::optional, + at::optional, + at::optional, + at::optional, + at::optional, + at::optional> + ret; + if (mode == 1 || (!dif_pos && !dif_col && !dif_rad && !dif_cam)) { + return ret; + } + // Create the camera information. + std::vector cam_infos(batch_size); + if (batch_processing) { + for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) { + cam_infos[batch_i] = cam_info_from_params( + cam_pos[batch_i], + pixel_0_0_center[batch_i], + pixel_vec_x[batch_i], + pixel_vec_y[batch_i], + principal_point_offsets[batch_i], + focal_length[batch_i].item(), + this->renderer_vec[0].cam.film_width, + this->renderer_vec[0].cam.film_height, + min_depth, + max_depth, + this->renderer_vec[0].cam.right_handed); + } + } else { + cam_infos[0] = cam_info_from_params( + cam_pos, + pixel_0_0_center, + pixel_vec_x, + pixel_vec_y, + principal_point_offsets, + focal_length.item(), + this->renderer_vec[0].cam.film_width, + this->renderer_vec[0].cam.film_height, + min_depth, + max_depth, + this->renderer_vec[0].cam.right_handed); + } + // Let's go! + // Contiguous version of opacity, if available. We need to create this object + // in scope to keep it alive. + torch::Tensor opacity_contiguous; + float const* opacity_ptr = nullptr; + if (opacity.has_value()) { + opacity_contiguous = opacity.value().contiguous(); + opacity_ptr = opacity_contiguous.data_ptr(); + } + if (this->device_type == c10::DeviceType::CUDA) { + int prev_active; + cudaGetDevice(&prev_active); + cudaSetDevice(this->device_index); +#ifdef PULSAR_TIMINGS_BATCHED_ENABLED + START_TIME_CU(batch_backward); +#endif + if (batch_processing) { + for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) { + // These calls are non-blocking and just kick off the computations. + if (dbg_pos.has_value()) { + PRE::backward_dbg( + &this->renderer_vec[batch_i], + grad_im[batch_i].contiguous().data_ptr(), + image[batch_i].contiguous().data_ptr(), + forw_info[batch_i].contiguous().data_ptr(), + vert_pos[batch_i].contiguous().data_ptr(), + vert_col[batch_i].contiguous().data_ptr(), + vert_radii[batch_i].contiguous().data_ptr(), + cam_infos[batch_i], + gamma, + percent_allowed_difference, + max_n_hits, + opacity_ptr, + n_points, + mode, + dif_pos, + dif_col, + dif_rad, + dif_cam, + dif_opy, + dbg_pos.value().first, + dbg_pos.value().second, + at::cuda::getCurrentCUDAStream()); + } else { + PRE::backward( + &this->renderer_vec[batch_i], + grad_im[batch_i].contiguous().data_ptr(), + image[batch_i].contiguous().data_ptr(), + forw_info[batch_i].contiguous().data_ptr(), + vert_pos[batch_i].contiguous().data_ptr(), + vert_col[batch_i].contiguous().data_ptr(), + vert_radii[batch_i].contiguous().data_ptr(), + cam_infos[batch_i], + gamma, + percent_allowed_difference, + max_n_hits, + opacity_ptr, + n_points, + mode, + dif_pos, + dif_col, + dif_rad, + dif_cam, + dif_opy, + at::cuda::getCurrentCUDAStream()); + } + } + } else { + if (dbg_pos.has_value()) { + PRE::backward_dbg( + this->renderer_vec.data(), + grad_im.contiguous().data_ptr(), + image.contiguous().data_ptr(), + forw_info.contiguous().data_ptr(), + vert_pos.contiguous().data_ptr(), + vert_col.contiguous().data_ptr(), + vert_radii.contiguous().data_ptr(), + cam_infos[0], + gamma, + percent_allowed_difference, + max_n_hits, + opacity_ptr, + n_points, + mode, + dif_pos, + dif_col, + dif_rad, + dif_cam, + dif_opy, + dbg_pos.value().first, + dbg_pos.value().second, + at::cuda::getCurrentCUDAStream()); + } else { + PRE::backward( + this->renderer_vec.data(), + grad_im.contiguous().data_ptr(), + image.contiguous().data_ptr(), + forw_info.contiguous().data_ptr(), + vert_pos.contiguous().data_ptr(), + vert_col.contiguous().data_ptr(), + vert_radii.contiguous().data_ptr(), + cam_infos[0], + gamma, + percent_allowed_difference, + max_n_hits, + opacity_ptr, + n_points, + mode, + dif_pos, + dif_col, + dif_rad, + dif_cam, + dif_opy, + at::cuda::getCurrentCUDAStream()); + } + } + cudaSetDevice(prev_active); +#ifdef PULSAR_TIMINGS_BATCHED_ENABLED + STOP_TIME_CU(batch_backward); + float time_ms; + GET_TIME_CU(batch_backward, &time_ms); + std::cout << "Backward render batched time per example: " + << time_ms / static_cast(batch_size) << "ms" << std::endl; +#endif + } else { +#ifdef PULSAR_TIMINGS_BATCHED_ENABLED + START_TIME(batch_backward); +#endif + if (batch_processing) { + for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) { + // These calls are non-blocking and just kick off the computations. + if (dbg_pos.has_value()) { + PRE::backward_dbg( + &this->renderer_vec[batch_i], + grad_im[batch_i].contiguous().data_ptr(), + image[batch_i].contiguous().data_ptr(), + forw_info[batch_i].contiguous().data_ptr(), + vert_pos[batch_i].contiguous().data_ptr(), + vert_col[batch_i].contiguous().data_ptr(), + vert_radii[batch_i].contiguous().data_ptr(), + cam_infos[batch_i], + gamma, + percent_allowed_difference, + max_n_hits, + opacity_ptr, + n_points, + mode, + dif_pos, + dif_col, + dif_rad, + dif_cam, + dif_opy, + dbg_pos.value().first, + dbg_pos.value().second, + nullptr); + } else { + PRE::backward( + &this->renderer_vec[batch_i], + grad_im[batch_i].contiguous().data_ptr(), + image[batch_i].contiguous().data_ptr(), + forw_info[batch_i].contiguous().data_ptr(), + vert_pos[batch_i].contiguous().data_ptr(), + vert_col[batch_i].contiguous().data_ptr(), + vert_radii[batch_i].contiguous().data_ptr(), + cam_infos[batch_i], + gamma, + percent_allowed_difference, + max_n_hits, + opacity_ptr, + n_points, + mode, + dif_pos, + dif_col, + dif_rad, + dif_cam, + dif_opy, + nullptr); + } + } + } else { + if (dbg_pos.has_value()) { + PRE::backward_dbg( + this->renderer_vec.data(), + grad_im.contiguous().data_ptr(), + image.contiguous().data_ptr(), + forw_info.contiguous().data_ptr(), + vert_pos.contiguous().data_ptr(), + vert_col.contiguous().data_ptr(), + vert_radii.contiguous().data_ptr(), + cam_infos[0], + gamma, + percent_allowed_difference, + max_n_hits, + opacity_ptr, + n_points, + mode, + dif_pos, + dif_col, + dif_rad, + dif_cam, + dif_opy, + dbg_pos.value().first, + dbg_pos.value().second, + nullptr); + } else { + PRE::backward( + this->renderer_vec.data(), + grad_im.contiguous().data_ptr(), + image.contiguous().data_ptr(), + forw_info.contiguous().data_ptr(), + vert_pos.contiguous().data_ptr(), + vert_col.contiguous().data_ptr(), + vert_radii.contiguous().data_ptr(), + cam_infos[0], + gamma, + percent_allowed_difference, + max_n_hits, + opacity_ptr, + n_points, + mode, + dif_pos, + dif_col, + dif_rad, + dif_cam, + dif_opy, + nullptr); + } + } +#ifdef PULSAR_TIMINGS_BATCHED_ENABLED + STOP_TIME(batch_backward); + float time_ms; + GET_TIME(batch_backward, &time_ms); + std::cout << "Backward render batched time per example: " + << time_ms / static_cast(batch_size) << "ms" << std::endl; +#endif + } + if (dif_pos) { + if (batch_processing) { + std::vector results(batch_size); + for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) { + results[batch_i] = from_blob( + reinterpret_cast(this->renderer_vec[batch_i].grad_pos_d), + {static_cast(n_points), 3}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + } + std::get<0>(ret) = torch::stack(results); + } else { + std::get<0>(ret) = from_blob( + reinterpret_cast(this->renderer_vec[0].grad_pos_d), + {static_cast(n_points), 3}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + } + } + if (dif_col) { + if (batch_processing) { + std::vector results(batch_size); + for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) { + results[batch_i] = from_blob( + reinterpret_cast(this->renderer_vec[batch_i].grad_col_d), + {static_cast(n_points), + this->renderer_vec[0].cam.n_channels}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + } + std::get<1>(ret) = torch::stack(results); + } else { + std::get<1>(ret) = from_blob( + reinterpret_cast(this->renderer_vec[0].grad_col_d), + {static_cast(n_points), + this->renderer_vec[0].cam.n_channels}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + } + } + if (dif_rad) { + if (batch_processing) { + std::vector results(batch_size); + for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) { + results[batch_i] = from_blob( + reinterpret_cast(this->renderer_vec[batch_i].grad_rad_d), + {static_cast(n_points)}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + } + std::get<2>(ret) = torch::stack(results); + } else { + std::get<2>(ret) = from_blob( + reinterpret_cast(this->renderer_vec[0].grad_rad_d), + {static_cast(n_points)}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + } + } + if (dif_cam) { + if (batch_processing) { + std::vector res_p1(batch_size); + std::vector res_p2(batch_size); + std::vector res_p3(batch_size); + std::vector res_p4(batch_size); + for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) { + res_p1[batch_i] = from_blob( + reinterpret_cast(this->renderer_vec[batch_i].grad_cam_d), + {3}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + res_p2[batch_i] = from_blob( + reinterpret_cast( + this->renderer_vec[batch_i].grad_cam_d + 3), + {3}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + res_p3[batch_i] = from_blob( + reinterpret_cast( + this->renderer_vec[batch_i].grad_cam_d + 6), + {3}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + res_p4[batch_i] = from_blob( + reinterpret_cast( + this->renderer_vec[batch_i].grad_cam_d + 9), + {3}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + } + std::get<3>(ret) = torch::stack(res_p1); + std::get<4>(ret) = torch::stack(res_p2); + std::get<5>(ret) = torch::stack(res_p3); + std::get<6>(ret) = torch::stack(res_p4); + } else { + std::get<3>(ret) = from_blob( + reinterpret_cast(this->renderer_vec[0].grad_cam_d), + {3}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + std::get<4>(ret) = from_blob( + reinterpret_cast(this->renderer_vec[0].grad_cam_d + 3), + {3}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + std::get<5>(ret) = from_blob( + reinterpret_cast(this->renderer_vec[0].grad_cam_d + 6), + {3}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + std::get<6>(ret) = from_blob( + reinterpret_cast(this->renderer_vec[0].grad_cam_d + 9), + {3}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + } + } + if (dif_opy) { + if (batch_processing) { + std::vector results(batch_size); + for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) { + results[batch_i] = from_blob( + reinterpret_cast(this->renderer_vec[batch_i].grad_opy_d), + {static_cast(n_points)}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + } + std::get<7>(ret) = torch::stack(results); + } else { + std::get<7>(ret) = from_blob( + reinterpret_cast(this->renderer_vec[0].grad_opy_d), + {static_cast(n_points)}, + this->device_type, + this->device_index, + torch::kFloat, + this->device_type == c10::DeviceType::CUDA + ? at::cuda::getCurrentCUDAStream() + : (cudaStream_t) nullptr); + } + } + return ret; +}; + +} // namespace pytorch +} // namespace pulsar diff --git a/pytorch3d/csrc/pulsar/pytorch/renderer.h b/pytorch3d/csrc/pulsar/pytorch/renderer.h new file mode 100644 index 00000000..a5144d29 --- /dev/null +++ b/pytorch3d/csrc/pulsar/pytorch/renderer.h @@ -0,0 +1,167 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_PYTORCH_RENDERER_H_ +#define PULSAR_NATIVE_PYTORCH_RENDERER_H_ + +#include "../global.h" +#include "../include/renderer.h" + +namespace pulsar { +namespace pytorch { + +struct Renderer { + public: + /** + * Pytorch Pulsar differentiable rendering module. + */ + explicit Renderer( + const unsigned int& width, + const unsigned int& height, + const uint& max_n_balls, + const bool& orthogonal_projection, + const bool& right_handed_system, + const float& background_normalization_depth, + const uint& n_channels, + const uint& n_track); + ~Renderer(); + + std::tuple forward( + const torch::Tensor& vert_pos, + const torch::Tensor& vert_col, + const torch::Tensor& vert_radii, + const torch::Tensor& cam_pos, + const torch::Tensor& pixel_0_0_center, + const torch::Tensor& pixel_vec_x, + const torch::Tensor& pixel_vec_y, + const torch::Tensor& focal_length, + const torch::Tensor& principal_point_offsets, + const float& gamma, + const float& max_depth, + float min_depth, + const c10::optional& bg_col, + const c10::optional& opacity, + const float& percent_allowed_difference, + const uint& max_n_hits, + const uint& mode); + + std::tuple< + at::optional, + at::optional, + at::optional, + at::optional, + at::optional, + at::optional, + at::optional, + at::optional> + backward( + const torch::Tensor& grad_im, + const torch::Tensor& image, + const torch::Tensor& forw_info, + const torch::Tensor& vert_pos, + const torch::Tensor& vert_col, + const torch::Tensor& vert_radii, + const torch::Tensor& cam_pos, + const torch::Tensor& pixel_0_0_center, + const torch::Tensor& pixel_vec_x, + const torch::Tensor& pixel_vec_y, + const torch::Tensor& focal_length, + const torch::Tensor& principal_point_offsets, + const float& gamma, + const float& max_depth, + float min_depth, + const c10::optional& bg_col, + const c10::optional& opacity, + const float& percent_allowed_difference, + const uint& max_n_hits, + const uint& mode, + const bool& dif_pos, + const bool& dif_col, + const bool& dif_rad, + const bool& dif_cam, + const bool& dif_opy, + const at::optional>& dbg_pos); + + // Infrastructure. + /** + * Ensure that the renderer is placed on this device. + * Is nearly a no-op if the device is correct. + */ + void ensure_on_device(torch::Device device, bool non_blocking = false); + + /** + * Ensure that at least n renderers are available. + */ + void ensure_n_renderers_gte(const size_t& batch_size); + + /** + * Check the parameters. + */ + std::tuple arg_check( + const torch::Tensor& vert_pos, + const torch::Tensor& vert_col, + const torch::Tensor& vert_radii, + const torch::Tensor& cam_pos, + const torch::Tensor& pixel_0_0_center, + const torch::Tensor& pixel_vec_x, + const torch::Tensor& pixel_vec_y, + const torch::Tensor& focal_length, + const torch::Tensor& principal_point_offsets, + const float& gamma, + const float& max_depth, + float& min_depth, + const c10::optional& bg_col, + const c10::optional& opacity, + const float& percent_allowed_difference, + const uint& max_n_hits, + const uint& mode); + + bool operator==(const Renderer& rhs) const; + inline friend std::ostream& operator<<( + std::ostream& stream, + const Renderer& self) { + stream << "pulsar::Renderer["; + // Device info. + stream << self.device_type; + if (self.device_index != -1) + stream << ", ID " << self.device_index; + stream << "]"; + return stream; + } + + inline uint width() const { + return this->renderer_vec[0].cam.film_width; + } + inline uint height() const { + return this->renderer_vec[0].cam.film_height; + } + inline int max_num_balls() const { + return this->renderer_vec[0].max_num_balls; + } + inline bool orthogonal() const { + return this->renderer_vec[0].cam.orthogonal_projection; + } + inline bool right_handed() const { + return this->renderer_vec[0].cam.right_handed; + } + inline uint n_track() const { + return static_cast(this->renderer_vec[0].n_track); + } + + /** A tensor that is registered as a buffer with this Module to track its + * device placement. Unfortunately, pytorch doesn't offer tracking Module + * device placement in a better way as of now. + */ + torch::Tensor device_tracker; + + protected: + /** The device type for this renderer. */ + c10::DeviceType device_type; + /** The device index for this renderer. */ + c10::DeviceIndex device_index; + /** Pointer to the underlying pulsar renderers. */ + std::vector renderer_vec; +}; + +} // namespace pytorch +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/pytorch/tensor_util.cpp b/pytorch3d/csrc/pulsar/pytorch/tensor_util.cpp new file mode 100644 index 00000000..c1e9e108 --- /dev/null +++ b/pytorch3d/csrc/pulsar/pytorch/tensor_util.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include +#include +#include + +#include "./tensor_util.h" + +namespace pulsar { +namespace pytorch { + +torch::Tensor sphere_ids_from_result_info_nograd( + const torch::Tensor& forw_info) { + torch::Tensor result = torch::zeros( + {forw_info.size(0), + forw_info.size(1), + forw_info.size(2), + (forw_info.size(3) - 3) / 2}, + torch::TensorOptions().device(forw_info.device()).dtype(torch::kInt32)); + // Get the relevant slice, contiguous. + torch::Tensor tmp = + forw_info + .slice( + /*dim=*/3, /*start=*/3, /*end=*/forw_info.size(3), /*step=*/2) + .contiguous(); + if (forw_info.device().type() == c10::DeviceType::CUDA) { + cudaMemcpyAsync( + result.data_ptr(), + tmp.data_ptr(), + sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) * + tmp.size(3), + cudaMemcpyDeviceToDevice, + at::cuda::getCurrentCUDAStream()); + } else { + memcpy( + result.data_ptr(), + tmp.data_ptr(), + sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) * + tmp.size(3)); + } + // `tmp` is freed after this, the memory might get reallocated. However, + // only kernels in the same stream should ever be able to write to this + // memory, which are executed only after the memcpy is complete. That's + // why we can just continue. + return result; +} + +} // namespace pytorch +} // namespace pulsar diff --git a/pytorch3d/csrc/pulsar/pytorch/tensor_util.h b/pytorch3d/csrc/pulsar/pytorch/tensor_util.h new file mode 100644 index 00000000..b98f7e50 --- /dev/null +++ b/pytorch3d/csrc/pulsar/pytorch/tensor_util.h @@ -0,0 +1,16 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_PYTORCH_TENSOR_UTIL_H_ +#define PULSAR_NATIVE_PYTORCH_TENSOR_UTIL_H_ + +#include + +namespace pulsar { +namespace pytorch { + +torch::Tensor sphere_ids_from_result_info_nograd( + const torch::Tensor& forw_info); + +} +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/pytorch/util.cpp b/pytorch3d/csrc/pulsar/pytorch/util.cpp new file mode 100644 index 00000000..847e697e --- /dev/null +++ b/pytorch3d/csrc/pulsar/pytorch/util.cpp @@ -0,0 +1,24 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include + +namespace pulsar { +namespace pytorch { + +void cudaDevToDev( + void* trg, + const void* src, + const int& size, + const cudaStream_t& stream) { + cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream); +} + +void cudaDevToHost( + void* trg, + const void* src, + const int& size, + const cudaStream_t& stream) { + cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream); +} + +} // namespace pytorch +} // namespace pulsar diff --git a/pytorch3d/csrc/pulsar/pytorch/util.h b/pytorch3d/csrc/pulsar/pytorch/util.h new file mode 100644 index 00000000..bab41678 --- /dev/null +++ b/pytorch3d/csrc/pulsar/pytorch/util.h @@ -0,0 +1,59 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#ifndef PULSAR_NATIVE_PYTORCH_UTIL_H_ +#define PULSAR_NATIVE_PYTORCH_UTIL_H_ + +#include +#include "../global.h" + +namespace pulsar { +namespace pytorch { + +void cudaDevToDev( + void* trg, + const void* src, + const int& size, + const cudaStream_t& stream); +void cudaDevToHost( + void* trg, + const void* src, + const int& size, + const cudaStream_t& stream); + +/** + * This method takes a memory pointer and wraps it into a pytorch tensor. + * + * This is preferred over `torch::from_blob`, since that requires a CUDA + * managed pointer. However, working with these for high performance + * operations is slower. Most of the rendering operations should stay + * local to the respective GPU anyways, so unmanaged pointers are + * preferred. + */ +template +torch::Tensor from_blob( + const T* ptr, + const torch::IntArrayRef& shape, + const c10::DeviceType& device_type, + const c10::DeviceIndex& device_index, + const torch::Dtype& dtype, + const cudaStream_t& stream) { + torch::Tensor ret = torch::zeros( + shape, torch::device({device_type, device_index}).dtype(dtype)); + const int num_elements = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies{}); + if (device_type == c10::DeviceType::CUDA) { + cudaDevToDev( + ret.data_ptr(), + static_cast(ptr), + sizeof(T) * num_elements, + stream); + // TODO: check for synchronization. + } else { + memcpy(ret.data_ptr(), ptr, sizeof(T) * num_elements); + } + return ret; +}; + +} // namespace pytorch +} // namespace pulsar + +#endif diff --git a/pytorch3d/csrc/pulsar/warnings.cpp b/pytorch3d/csrc/pulsar/warnings.cpp new file mode 100644 index 00000000..0a875b2a --- /dev/null +++ b/pytorch3d/csrc/pulsar/warnings.cpp @@ -0,0 +1,14 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +#include "./global.h" +#include "./logging.h" + +/** + * A compilation unit to provide warnings about the code and avoid + * repeated messages. + */ +#ifdef PULSAR_ASSERTIONS +#pragma message("WARNING: assertions are enabled in Pulsar.") +#endif +#ifdef PULSAR_LOGGING_ENABLED +#pragma message("WARNING: logging is enabled in Pulsar.") +#endif diff --git a/pytorch3d/renderer/points/pulsar/__init__.py b/pytorch3d/renderer/points/pulsar/__init__.py new file mode 100644 index 00000000..383929f9 --- /dev/null +++ b/pytorch3d/renderer/points/pulsar/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from .renderer import Renderer # noqa: F401 diff --git a/pytorch3d/renderer/points/pulsar/renderer.py b/pytorch3d/renderer/points/pulsar/renderer.py new file mode 100644 index 00000000..0e86d6aa --- /dev/null +++ b/pytorch3d/renderer/points/pulsar/renderer.py @@ -0,0 +1,692 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +"""pulsar renderer PyTorch integration. + +Proper Python support for pytorch requires creating a torch.autograd.function +(independent of whether this is being done within the C++ module). This is done +here and a torch.nn.Module is exposed for the use in more complex models. +""" +import logging +import math +import warnings +from typing import Optional, Tuple, Union + +import torch + +# pyre-fixme[21]: Could not find a name `_C` defined in module `pytorch3d`. +from pytorch3d import _C +from pytorch3d.transforms import axis_angle_to_matrix, rotation_6d_to_matrix + + +LOGGER = logging.getLogger(__name__) +GAMMA_WARNING_EMITTED = False +AXANGLE_WARNING_EMITTED = False + + +class _Render(torch.autograd.Function): + """ + Differentiable rendering function for the Pulsar renderer. + + Usually this will be used through the `Renderer` module, which takes care of + setting up the buffers and putting them on the correct device. If you use + the function directly, you will have to do this manually. + + The steps for this are two-fold: first, you need to create a native Renderer + object to provide the required buffers. This is the `native_renderer` parameter + for this function. You can create it by creating a `pytorch3d._C.PulsarRenderer` + object (with parameters for width, height and maximum number of balls it should + be able to render). This object by default resides on the CPU. If you want to + shift the buffers to a different device, just assign an empty tensor on the target + device to its property `device_tracker`. + + To convert camera parameters from a more convenient representation to the + required vectors as in this function, you can use the static + function `pytorch3d.renderer.points.pulsar.Renderer._transform_cam_params`. + + Args: + * ctx: Pytorch context. + * vert_pos: vertex positions. [Bx]Nx3 tensor of positions in 3D space. + * vert_col: vertex colors. [Bx]NxK tensor of channels. + * vert_rad: vertex radii. [Bx]N tensor of radiuses, >0. + * cam_pos: camera position(s). [Bx]3 tensor in 3D coordinates. + * pixel_0_0_center: [Bx]3 tensor center(s) of the upper left pixel(s) in + world coordinates. + * pixel_vec_x: [Bx]3 tensor from one pixel center to the next in image x + direction in world coordinates. + * pixel_vec_y: [Bx]3 tensor from one pixel center to the next in image y + direction in world coordinates. + * focal_length: [Bx]1 tensor of focal lengths in world coordinates. + * principal_point_offsets: [Bx]2 tensor of principal point offsets in pixels. + * gamma: sphere transparency in [1.,1E-5], with 1 being mostly transparent. + [Bx]1. + * max_depth: maximum depth for spheres to render. Set this as tighly + as possible to have good numerical accuracy for gradients. + * native_renderer: a `pytorch3d._C.PulsarRenderer` object. + * min_depth: a float with the minimum depth a sphere must have to be renderer. + Must be 0. or > max(focal_length). + * bg_col: K tensor with a background color to use or None (uses all ones). + * opacity: [Bx]N tensor of opacity values in [0., 1.] or None (uses all ones). + * percent_allowed_difference: a float in [0., 1.[ with the maximum allowed + difference in color space. This is used to speed up the + computation. Default: 0.01. + * max_n_hits: a hard limit on the number of hits per ray. Default: max int. + * mode: render mode in {0, 1}. 0: render an image; 1: render the hit map. + * return_forward_info: whether to return a second map. This second map contains + 13 channels: first channel contains sm_m (the maximum exponent factor + observed), the second sm_d (the normalization denominator, the sum of all + coefficients), the third the maximum closest possible intersection for a + hit. The following channels alternate with the float encoded integer index + of a sphere and its weight. They are the five spheres with the highest + color contribution to this pixel color, ordered descending. + + Returns: + * image: [Bx]HxWxK float tensor with the resulting image. + * forw_info: [Bx]HxWx13 float forward information as described above, + if enabled. + """ + + @staticmethod + def forward( + ctx, + vert_pos, + vert_col, + vert_rad, + cam_pos, + pixel_0_0_center, + pixel_vec_x, + pixel_vec_y, + focal_length, + principal_point_offsets, + gamma, + max_depth, + native_renderer, + min_depth=0.0, + bg_col=None, + opacity=None, + percent_allowed_difference=0.01, + max_n_hits=_C.MAX_UINT, + mode=0, + return_forward_info=False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if mode != 0: + assert not return_forward_info, ( + "You are using a non-standard rendering mode. This does " + "not provide gradients, and also no `forward_info`. Please " + "set `return_forward_info` to `False`." + ) + ctx.gamma = gamma + ctx.max_depth = max_depth + ctx.min_depth = min_depth + ctx.percent_allowed_difference = percent_allowed_difference + ctx.max_n_hits = max_n_hits + ctx.mode = mode + ctx.native_renderer = native_renderer + image, info = ctx.native_renderer.forward( + vert_pos, + vert_col, + vert_rad, + cam_pos, + pixel_0_0_center, + pixel_vec_x, + pixel_vec_y, + focal_length, + principal_point_offsets, + gamma, + max_depth, + min_depth, + bg_col, + opacity, + percent_allowed_difference, + max_n_hits, + mode, + ) + if mode != 0: + # Backprop not possible! + info = None + # Prepare for backprop. + ctx.save_for_backward( + vert_pos, + vert_col, + vert_rad, + cam_pos, + pixel_0_0_center, + pixel_vec_x, + pixel_vec_y, + focal_length, + principal_point_offsets, + bg_col, + opacity, + image, + info, + ) + if return_forward_info: + return image, info + else: + return image + + @staticmethod + def backward(ctx, grad_im, *args): + global GAMMA_WARNING_EMITTED + ( + vert_pos, + vert_col, + vert_rad, + cam_pos, + pixel_0_0_center, + pixel_vec_x, + pixel_vec_y, + focal_length, + principal_point_offsets, + bg_col, + opacity, + image, + info, + ) = ctx.saved_tensors + if ( + ( + ctx.needs_input_grad[0] + or ctx.needs_input_grad[2] + or ctx.needs_input_grad[3] + or ctx.needs_input_grad[4] + or ctx.needs_input_grad[5] + or ctx.needs_input_grad[6] + or ctx.needs_input_grad[7] + ) + and ctx.gamma < 1e-3 + and not GAMMA_WARNING_EMITTED + ): + warnings.warn( + "Optimizing for non-color parameters and having a gamma value < 1E-3! " + "This is probably not going to produce usable gradients." + ) + GAMMA_WARNING_EMITTED = True + if ctx.mode == 0: + ( + grad_pos, + grad_col, + grad_rad, + grad_cam_pos, + grad_pixel_0_0_center, + grad_pixel_vec_x, + grad_pixel_vec_y, + grad_opacity, + ) = ctx.native_renderer.backward( + grad_im, + image, + info, + vert_pos, + vert_col, + vert_rad, + cam_pos, + pixel_0_0_center, + pixel_vec_x, + pixel_vec_y, + focal_length, + principal_point_offsets, + ctx.gamma, + ctx.max_depth, + ctx.min_depth, + bg_col, + opacity, + ctx.percent_allowed_difference, + ctx.max_n_hits, + ctx.mode, + ctx.needs_input_grad[0], + ctx.needs_input_grad[1], + ctx.needs_input_grad[2], + ctx.needs_input_grad[3] + or ctx.needs_input_grad[4] + or ctx.needs_input_grad[5] + or ctx.needs_input_grad[6], + ctx.needs_input_grad[13], + None, # No debug information provided. + ) + else: + raise ValueError( + "Performing a backward pass for a " + "rendering with `mode != 0`! This is not possible." + ) + return ( + grad_pos, + grad_col, + grad_rad, + grad_cam_pos, + grad_pixel_0_0_center, + grad_pixel_vec_x, + grad_pixel_vec_y, + None, # focal_length + None, # principal_point_offsets + None, # gamma + None, # max_depth + None, # native_renderer + None, # min_depth + None, # bg_col + grad_opacity, + None, # percent_allowed_difference + None, # max_n_hits + None, # mode + None, # return_forward_info + ) + + +class Renderer(torch.nn.Module): + """ + Differentiable rendering module for the Pulsar renderer. + + Set the maximum number of balls to a reasonable value. It is used to determine + several buffer sizes. It is no problem to render less balls than this number, + but never more. + + When optimizing for sphere positions, sphere radiuses or camera parameters you + have to use higher gamma values (closer to one) and larger sphere sizes: spheres + can only 'move' to areas that they cover, and only with higher gamma values exists + a gradient w.r.t. their color depending on their position. + + Args: + * width: result image width in pixels. + * height: result image height in pixels. + * max_num_balls: the maximum number of balls this renderer will handle. + * orthogonal_projection: use an orthogonal instead of perspective projection. + Default: False. + * right_handed_system: use a right-handed instead of a left-handed coordinate + system. This is relevant for compatibility with other drawing or scanning + systems. Pulsar by default assumes a left-handed world and camera coordinate + system as known from mathematics with x-axis to the right, y axis up and z + axis for increasing depth along the optical axis. In the image coordinate + system, only the y axis is pointing down, leading still to a left-handed + system. If you set this to True, it is assuming a right-handed world and + camera coordinate system with x axis to the right, y axis to the top and + z axis decreasing along the optical axis. Again, the image coordinate + system has a flipped y axis, remaining a right-handed system. + Default: False. + * background_normalized_depth: the normalized depth the background is placed + at. + This is on a scale from 0. to 1. between the specified min and max depth + (see the forward function). The value 0. is the most furthest depth whereas + 1. is the closest. Be careful when setting the background too far front - it + may hide elements in your scene. Default: EPS. + * n_channels: the number of image content channels to use. This is usually three + for regular color representations, but can be a higher or lower number. + Default: 3. + * n_track: the number of spheres to track for gradient calculation per pixel. + Only the closest n_track spheres will receive gradients. Default: 5. + """ + + def __init__( + self, + width: int, + height: int, + max_num_balls: int, + orthogonal_projection: bool = False, + right_handed_system: bool = False, + background_normalized_depth: float = _C.EPS, + n_channels: int = 3, + n_track: int = 5, + ): + super(Renderer, self).__init__() + # pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`. + self._renderer = _C.PulsarRenderer( + width, + height, + max_num_balls, + orthogonal_projection, + right_handed_system, + background_normalized_depth, + n_channels, + n_track, + ) + self.register_buffer("device_tracker", torch.zeros(1)) + + @staticmethod + def sphere_ids_from_result_info_nograd(result_info: torch.Tensor) -> torch.Tensor: + """ + Get the sphere IDs from a result info tensor. + """ + if result_info.ndim == 3: + return Renderer.sphere_ids_from_result_info_nograd(result_info[None, ...]) + # pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`. + return _C.pulsar_sphere_ids_from_result_info_nograd(result_info) + + @staticmethod + def depth_map_from_result_info_nograd(result_info: torch.Tensor) -> torch.Tensor: + """ + Get the depth map from a result info tensor. + + This returns a map of the same size as the image with just one channel + containing the closest intersection value at that position. Gradients + are not available for this tensor, but do note that you can use + `sphere_ids_from_result_info_nograd` to get the IDs of the spheres at + each position and directly create a loss on their depth if required. + + The depth map contains -1. at positions where no intersection has + been detected. + """ + return result_info[..., 4] + + @staticmethod + def _transform_cam_params( + cam_params: torch.Tensor, + width: int, + height: int, + orthogonal: bool, + right_handed: bool, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + """ + Transform 8 component camera parameter vector(s) to the internal camera + representation. + + The input vectors consists of: + * 3 components for camera position, + * 3 components for camera rotation (three rotation angles) or + 6 components as described in "On the Continuity of Rotation + Representations in Neural Networks" (Zhou et al.), + * focal length, + * the sensor width in world coordinates, + * [optional] the principal point offset in x and y. + + The sensor height is inferred by pixel size and sensor width to obtain + quadratic pixels. + + Args: + * cam_params: [Bx]{8, 10, 11, 13}, input tensors as described above. + * width: number of pixels in x direction. + * height: number of pixels in y direction. + * orthogonal: bool, whether an orthogonal projection is used + (does not use focal length). + * right_handed: bool, whether to use a right handed system + (negative z in camera direction). + + Returns: + * pos_vec: the position vector in 3D, + * pixel_0_0_center: the center of the upper left pixel in world coordinates, + * pixel_vec_x: the step to move one pixel on the image x axis + in world coordinates, + * pixel_vec_y: the step to move one pixel on the image y axis + in world coordinates, + * focal_length: the focal lengths, + * principal_point_offsets: the principal point offsets in x, y. + """ + global AXANGLE_WARNING_EMITTED + # Set up all direction vectors, i.e., the sensor direction of all axes. + assert width > 0 + assert height > 0 + batch_processing = True + if cam_params.ndimension() == 1: + batch_processing = False + cam_params = cam_params[None, :] + batch_size = cam_params.size(0) + continuous_rep = True + if cam_params.shape[1] in [8, 10]: + if cam_params.requires_grad and not AXANGLE_WARNING_EMITTED: + warnings.warn( + "Using an axis angle representation for camera rotations. " + "This has discontinuities and should not be used for optimization. " + "Alternatively, use a six-component representation as described in " + "'On the Continuity of Rotation Representations in Neural Networks'" + " (Zhou et al.). " + "The `pytorch3d.transforms` module provides " + "facilities for using this representation." + ) + AXANGLE_WARNING_EMITTED = True + continuous_rep = False + else: + assert cam_params.shape[1] in [11, 13] + pos_vec: torch.Tensor = cam_params[:, :3] + principal_point_offsets: torch.Tensor = torch.zeros( + (cam_params.shape[0], 2), dtype=torch.int32, device=cam_params.device + ) + if continuous_rep: + rot_vec = cam_params[:, 3:9] + focal_length: torch.Tensor = cam_params[:, 9:10] + sensor_size_x = cam_params[:, 10:11] + if cam_params.shape[1] == 13: + principal_point_offsets: torch.Tensor = cam_params[:, 11:13].to( + torch.int32 + ) + else: + rot_vec = cam_params[:, 3:6] + focal_length: torch.Tensor = cam_params[:, 6:7] + sensor_size_x = cam_params[:, 7:8] + if cam_params.shape[1] == 10: + principal_point_offsets: torch.Tensor = cam_params[:, 8:10].to( + torch.int32 + ) + # Always get quadratic pixels. + pixel_size_x = sensor_size_x / float(width) + sensor_size_y = height * pixel_size_x + LOGGER.debug( + "Camera position: %s, rotation: %s. Focal length: %s.", + str(pos_vec), + str(rot_vec), + str(focal_length), + ) + if continuous_rep: + rot_mat = rotation_6d_to_matrix(rot_vec) + else: + rot_mat = axis_angle_to_matrix(rot_vec) + sensor_dir_x = torch.matmul( + rot_mat, + torch.tensor( + [1.0, 0.0, 0.0], dtype=torch.float32, device=rot_mat.device + ).repeat(batch_size, 1)[:, :, None], + )[:, :, 0] + sensor_dir_y = torch.matmul( + rot_mat, + torch.tensor( + [0.0, -1.0, 0.0], dtype=torch.float32, device=rot_mat.device + ).repeat(batch_size, 1)[:, :, None], + )[:, :, 0] + sensor_dir_z = torch.matmul( + rot_mat, + torch.tensor( + [0.0, 0.0, 1.0], dtype=torch.float32, device=rot_mat.device + ).repeat(batch_size, 1)[:, :, None], + )[:, :, 0] + if right_handed: + sensor_dir_z *= -1 + LOGGER.debug( + "Sensor direction vectors: %s, %s, %s.", + str(sensor_dir_x), + str(sensor_dir_y), + str(sensor_dir_z), + ) + if orthogonal: + sensor_center = pos_vec + else: + sensor_center = pos_vec + focal_length * sensor_dir_z + LOGGER.debug("Sensor center: %s.", str(sensor_center)) + sensor_luc = ( # Sensor left upper corner. + sensor_center + - sensor_dir_x * (sensor_size_x / 2.0) + - sensor_dir_y * (sensor_size_y / 2.0) + ) + LOGGER.debug("Sensor luc: %s.", str(sensor_luc)) + pixel_size_x = sensor_size_x / float(width) + pixel_size_y = sensor_size_y / float(height) + LOGGER.debug( + "Pixel sizes (x): %s, (y) %s.", str(pixel_size_x), str(pixel_size_y) + ) + pixel_vec_x: torch.Tensor = sensor_dir_x * pixel_size_x + pixel_vec_y: torch.Tensor = sensor_dir_y * pixel_size_y + pixel_0_0_center = sensor_luc + 0.5 * pixel_vec_x + 0.5 * pixel_vec_y + LOGGER.debug( + "Pixel 0 centers: %s, vec x: %s, vec y: %s.", + str(pixel_0_0_center), + str(pixel_vec_x), + str(pixel_vec_y), + ) + if not orthogonal: + LOGGER.debug( + "Camera horizontal fovs: %s deg.", + str( + 2.0 + * torch.atan(0.5 * sensor_size_x / focal_length) + / math.pi + * 180.0 + ), + ) + LOGGER.debug( + "Camera vertical fovs: %s deg.", + str( + 2.0 + * torch.atan(0.5 * sensor_size_y / focal_length) + / math.pi + * 180.0 + ), + ) + # Reduce dimension. + focal_length: torch.Tensor = focal_length[:, 0] + if batch_processing: + return ( + pos_vec, + pixel_0_0_center, + pixel_vec_x, + pixel_vec_y, + focal_length, + principal_point_offsets, + ) + else: + return ( + pos_vec[0], + pixel_0_0_center[0], + pixel_vec_x[0], + pixel_vec_y[0], + focal_length[0], + principal_point_offsets[0], + ) + + def forward( + self, + vert_pos: torch.Tensor, + vert_col: torch.Tensor, + vert_rad: torch.Tensor, + cam_params: torch.Tensor, + gamma: float, + max_depth: float, + min_depth: float = 0.0, + bg_col: Optional[torch.Tensor] = None, + opacity: Optional[torch.Tensor] = None, + percent_allowed_difference: float = 0.01, + max_n_hits: int = _C.MAX_UINT, + mode: int = 0, + return_forward_info: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + """ + Rendering pass to create an image from the provided spheres and camera + parameters. + + Args: + * vert_pos: vertex positions. [Bx]Nx3 tensor of positions in 3D space. + * vert_col: vertex colors. [Bx]NxK tensor of channels. + * vert_rad: vertex radii. [Bx]N tensor of radiuses, >0. + * cam_params: camera parameter(s). [Bx]8 tensor, consisting of: + - 3 components for camera position, + - 3 components for camera rotation (axis angle representation) or + 6 components as described in "On the Continuity of Rotation + Representations in Neural Networks" (Zhou et al.), + - focal length, + - the sensor width in world coordinates, + - [optional] an offset for the principal point in x, y (no gradients). + * gamma: sphere transparency in [1.,1E-5], with 1 being mostly transparent. + [Bx]1. + * max_depth: maximum depth for spheres to render. Set this as tightly + as possible to have good numerical accuracy for gradients. + float > min_depth + eps. + * min_depth: a float with the minimum depth a sphere must have to be + rendered. Must be 0. or > max(focal_length) + eps. + * bg_col: K tensor with a background color to use or None (uses all ones). + * opacity: [Bx]N tensor of opacity values in [0., 1.] or None (uses all + ones). + * percent_allowed_difference: a float in [0., 1.[ with the maximum allowed + difference in color space. This is used to speed up the + computation. Default: 0.01. + * max_n_hits: a hard limit on the number of hits per ray. Default: max int. + * mode: render mode in {0, 1}. 0: render an image; 1: render the hit map. + * return_forward_info: whether to return a second map. This second map + contains 13 channels: first channel contains sm_m (the maximum + exponent factor observed), the second sm_d (the normalization + denominator, the sum of all coefficients), the third the maximum closest + possible intersection for a hit. The following channels alternate with + the float encoded integer index of a sphere and its weight. They are the + five spheres with the highest color contribution to this pixel color, + ordered descending. Default: False. + + Returns: + * image: [Bx]HxWx3 float tensor with the resulting image. + * forw_info: [Bx]HxWx13 float forward information as described above, if + enabled. + """ + # The device tracker is registered as buffer. + # pyre-fixme[16]: `Renderer` has no attribute `device_tracker`. + self._renderer.device_tracker = self.device_tracker + ( + pos_vec, + pixel_0_0_center, + pixel_vec_x, + pixel_vec_y, + focal_lengths, + principal_point_offsets, + ) = Renderer._transform_cam_params( + cam_params, + self._renderer.width, + self._renderer.height, + self._renderer.orthogonal, + self._renderer.right_handed, + ) + if ( + focal_lengths.min().item() > 0.0 + and max_depth > 10_000.0 * focal_lengths.min().item() + ): + warnings.warn( + ( + "Extreme ratio of `max_depth` vs. focal length detected " + "(%f vs. %f, ratio: %f). This will likely lead to " + "artifacts due to numerical instabilities." + ) + % ( + max_depth, + focal_lengths.min().item(), + max_depth / focal_lengths.min().item(), + ) + ) + # pyre-fixme[16]: `_Render` has no attribute `apply`. + ret_res = _Render.apply( + vert_pos, + vert_col, + vert_rad, + pos_vec, + pixel_0_0_center, + pixel_vec_x, + pixel_vec_y, + # Focal length and sensor size don't need gradients other than through + # `pixel_vec_x` and `pixel_vec_y`. The focal length is only used in the + # renderer to determine the projection areas of the balls. + focal_lengths, + # principal_point_offsets does not receive gradients. + principal_point_offsets, + gamma, + max_depth, + self._renderer, + min_depth, + bg_col, + opacity, + percent_allowed_difference, + max_n_hits, + mode, + (mode == 0) and return_forward_info, + ) + if return_forward_info and mode != 0: + return ret_res, None + return ret_res + + def extra_repr(self) -> str: + """Extra information to print in pytorch graphs.""" + return "width={}, height={}, max_num_balls={}".format( + self._renderer.width, self._renderer.height, self._renderer.max_num_balls + ) diff --git a/pytorch3d/transforms/__init__.py b/pytorch3d/transforms/__init__.py index b30b7fa1..2f0a0301 100644 --- a/pytorch3d/transforms/__init__.py +++ b/pytorch3d/transforms/__init__.py @@ -1,9 +1,13 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +from .external.kornia_angle_axis_to_rotation_matrix import ( + angle_axis_to_rotation_matrix as axis_angle_to_matrix, +) from .rotation_conversions import ( euler_angles_to_matrix, matrix_to_euler_angles, matrix_to_quaternion, + matrix_to_rotation_6d, quaternion_apply, quaternion_invert, quaternion_multiply, @@ -12,6 +16,7 @@ from .rotation_conversions import ( random_quaternions, random_rotation, random_rotations, + rotation_6d_to_matrix, standardize_quaternion, ) from .so3 import ( diff --git a/pytorch3d/transforms/external/__init__.py b/pytorch3d/transforms/external/__init__.py new file mode 100644 index 00000000..40539064 --- /dev/null +++ b/pytorch3d/transforms/external/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. diff --git a/pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py b/pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py new file mode 100644 index 00000000..1269813a --- /dev/null +++ b/pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python3 +""" +This file contains the great angle axis to rotation matrix conversion +from kornia (https://github.com/arraiyopensource/kornia). The license +can be found in kornia_license.txt. + +The method is used unchanged; the documentation has been adjusted +to match our doc format. +""" +import torch + + +def angle_axis_to_rotation_matrix(angle_axis): + """Convert 3d vector of axis-angle rotation to 4x4 rotation matrix + + Args: + angle_axis (Tensor): tensor of 3d vector of axis-angle rotations. + + Returns: + Tensor: tensor of 3x3 rotation matrix. + + Shape: + - Input: :math:`(N, 3)` + - Output: :math:`(N, 3, 3)` + + Example: + + ..code-block::python + + >>> input = torch.rand(1, 3) # Nx3 + >>> output = tgm.angle_axis_to_rotation_matrix(input) # Nx3x3 + >>> output = tgm.angle_axis_to_rotation_matrix(input) # Nx3x3 + """ + + def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6): + # We want to be careful to only evaluate the square root if the + # norm of the angle_axis vector is greater than zero. Otherwise + # we get a division by zero. + k_one = 1.0 + theta = torch.sqrt(theta2) + wxyz = angle_axis / (theta + eps) + wx, wy, wz = torch.chunk(wxyz, 3, dim=1) + cos_theta = torch.cos(theta) + sin_theta = torch.sin(theta) + + r00 = cos_theta + wx * wx * (k_one - cos_theta) + r10 = wz * sin_theta + wx * wy * (k_one - cos_theta) + r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta) + r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta + r11 = cos_theta + wy * wy * (k_one - cos_theta) + r21 = wx * sin_theta + wy * wz * (k_one - cos_theta) + r02 = wy * sin_theta + wx * wz * (k_one - cos_theta) + r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta) + r22 = cos_theta + wz * wz * (k_one - cos_theta) + rotation_matrix = torch.cat( + [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1 + ) + return rotation_matrix.view(-1, 3, 3) + + def _compute_rotation_matrix_taylor(angle_axis): + rx, ry, rz = torch.chunk(angle_axis, 3, dim=1) + k_one = torch.ones_like(rx) + rotation_matrix = torch.cat( + [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1 + ) + return rotation_matrix.view(-1, 3, 3) + + # stolen from ceres/rotation.h + + _angle_axis = torch.unsqueeze(angle_axis + 1e-6, dim=1) + # _angle_axis.register_hook(lambda grad: pdb.set_trace()) + # _angle_axis = 1e-6 + theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2)) + theta2 = torch.squeeze(theta2, dim=1) + + # compute rotation matrices + rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2) + rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis) + + # create mask to handle both cases + eps = 1e-6 + mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device) + mask_pos = (mask).type_as(theta2) + mask_neg = (mask == False).type_as(theta2) # noqa + + # create output pose matrix + batch_size = angle_axis.shape[0] + rotation_matrix = torch.eye(3).to(angle_axis.device).type_as(angle_axis) + rotation_matrix = rotation_matrix.view(1, 3, 3).repeat(batch_size, 1, 1) + # fill output matrix with masked values + rotation_matrix[..., :3, :3] = ( + mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor + ) + return rotation_matrix.to(angle_axis.device).type_as(angle_axis) # Nx4x4 diff --git a/pytorch3d/transforms/external/kornia_license.txt b/pytorch3d/transforms/external/kornia_license.txt new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/pytorch3d/transforms/external/kornia_license.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/tests/bm_barycentric_clipping.py b/tests/bm_barycentric_clipping.py index df72987e..d55c3c0c 100644 --- a/tests/bm_barycentric_clipping.py +++ b/tests/bm_barycentric_clipping.py @@ -110,3 +110,7 @@ def bm_barycentric_clip() -> None: benchmark(baryclip_cuda, "BARY_CLIP_CUDA", kwargs_list, warmup_iters=1) benchmark(baryclip_pytorch, "BARY_CLIP_PYTORCH", kwargs_list, warmup_iters=1) + + +if __name__ == "__main__": + bm_barycentric_clip() diff --git a/tests/bm_blending.py b/tests/bm_blending.py index 16aa11bc..7febc7ad 100644 --- a/tests/bm_blending.py +++ b/tests/bm_blending.py @@ -42,3 +42,7 @@ def bm_blending() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_blending() diff --git a/tests/bm_cameras_alignment.py b/tests/bm_cameras_alignment.py index 4d0c0397..128cc9c8 100644 --- a/tests/bm_cameras_alignment.py +++ b/tests/bm_cameras_alignment.py @@ -22,3 +22,7 @@ def bm_cameras_alignment() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_cameras_alignment() diff --git a/tests/bm_chamfer.py b/tests/bm_chamfer.py index 4de5829e..5cdee080 100644 --- a/tests/bm_chamfer.py +++ b/tests/bm_chamfer.py @@ -8,6 +8,8 @@ from test_chamfer import TestChamfer def bm_chamfer() -> None: + # Currently disabled. + return devices = ["cpu"] if torch.cuda.is_available(): devices.append("cuda:0") @@ -53,3 +55,7 @@ def bm_chamfer() -> None: } ) benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1) + + +if __name__ == "__main__": + bm_chamfer() diff --git a/tests/bm_cubify.py b/tests/bm_cubify.py index 239b1e69..632a1751 100644 --- a/tests/bm_cubify.py +++ b/tests/bm_cubify.py @@ -11,3 +11,7 @@ def bm_cubify() -> None: {"batch_size": 16, "V": 32}, ] benchmark(TestCubify.cubify_with_init, "CUBIFY", kwargs_list, warmup_iters=1) + + +if __name__ == "__main__": + bm_cubify() diff --git a/tests/bm_face_areas_normals.py b/tests/bm_face_areas_normals.py index 0a01441f..66f85c23 100644 --- a/tests/bm_face_areas_normals.py +++ b/tests/bm_face_areas_normals.py @@ -37,3 +37,7 @@ def bm_face_areas_normals() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_face_areas_normals() diff --git a/tests/bm_graph_conv.py b/tests/bm_graph_conv.py index 404c44a4..bf77b8b5 100644 --- a/tests/bm_graph_conv.py +++ b/tests/bm_graph_conv.py @@ -40,3 +40,7 @@ def bm_graph_conv() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_graph_conv() diff --git a/tests/bm_interpolate_face_attributes.py b/tests/bm_interpolate_face_attributes.py index 60fb8d1c..f721eead 100644 --- a/tests/bm_interpolate_face_attributes.py +++ b/tests/bm_interpolate_face_attributes.py @@ -74,3 +74,7 @@ def bm_interpolate_face_attribues() -> None: kwargs_list.append({"N": N, "S": S, "K": K, "F": F, "D": D, "impl": impl}) benchmark(_bm_forward, "FORWARD", kwargs_list, warmup_iters=3) benchmark(_bm_forward_backward, "FORWARD+BACKWARD", kwargs_list, warmup_iters=3) + + +if __name__ == "__main__": + bm_interpolate_face_attribues() diff --git a/tests/bm_knn.py b/tests/bm_knn.py index 4a96d64b..5a894a22 100644 --- a/tests/bm_knn.py +++ b/tests/bm_knn.py @@ -24,3 +24,7 @@ def bm_knn() -> None: benchmark(TestKNN.knn_square, "KNN_SQUARE", kwargs_list, warmup_iters=1) benchmark(TestKNN.knn_ragged, "KNN_RAGGED", kwargs_list, warmup_iters=1) + + +if __name__ == "__main__": + bm_knn() diff --git a/tests/bm_lighting.py b/tests/bm_lighting.py index a98d0ed1..cc9fcb88 100644 --- a/tests/bm_lighting.py +++ b/tests/bm_lighting.py @@ -45,3 +45,7 @@ def bm_lighting() -> None: kwargs_list.append({"N": N, "S": S, "K": K}) benchmark(_bm_diffuse_cuda_with_init, "DIFFUSE", kwargs_list, warmup_iters=3) benchmark(_bm_specular_cuda_with_init, "SPECULAR", kwargs_list, warmup_iters=3) + + +if __name__ == "__main__": + bm_lighting() diff --git a/tests/bm_main.py b/tests/bm_main.py index f178ef98..26665092 100755 --- a/tests/bm_main.py +++ b/tests/bm_main.py @@ -2,8 +2,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import glob -import importlib -from os.path import basename, dirname, isfile, join, sys +import os +import subprocess +import sys +from os.path import dirname, isfile, join if __name__ == "__main__": @@ -11,20 +13,22 @@ if __name__ == "__main__": if len(sys.argv) > 1: # Parse from flags. # pyre-ignore[16] - module_names = [n for n in sys.argv if n.startswith("bm_")] + file_names = [ + join(dirname(__file__), n) for n in sys.argv if n.startswith("bm_") + ] else: # Get all the benchmark files (starting with "bm_"). bm_files = glob.glob(join(dirname(__file__), "bm_*.py")) - module_names = [ - basename(f)[:-3] - for f in bm_files - if isfile(f) and not f.endswith("bm_main.py") - ] + file_names = sorted( + f for f in bm_files if isfile(f) and not f.endswith("bm_main.py") + ) - for module_name in module_names: - module = importlib.import_module(module_name) - for attr in dir(module): - # Run all the functions with names "bm_*" in the module. - if attr.startswith("bm_"): - print("Running benchmarks for " + module_name + "/" + attr + "...") - getattr(module, attr)() + # Forward all important path information to the subprocesses through the + # environment. + os.environ["PATH"] = sys.path[0] + ":" + os.environ.get("PATH", "") + os.environ["LD_LIBRARY_PATH"] = ( + sys.path[0] + ":" + os.environ.get("LD_LIBRARY_PATH", "") + ) + os.environ["PYTHONPATH"] = ":".join(sys.path) + for file_name in file_names: + subprocess.check_call([sys.executable, file_name]) diff --git a/tests/bm_mesh_edge_loss.py b/tests/bm_mesh_edge_loss.py index b7a9566b..4410b0fb 100644 --- a/tests/bm_mesh_edge_loss.py +++ b/tests/bm_mesh_edge_loss.py @@ -19,3 +19,7 @@ def bm_mesh_edge_loss() -> None: benchmark( TestMeshEdgeLoss.mesh_edge_loss, "MESH_EDGE_LOSS", kwargs_list, warmup_iters=1 ) + + +if __name__ == "__main__": + bm_mesh_edge_loss() diff --git a/tests/bm_mesh_io.py b/tests/bm_mesh_io.py index 15719813..a8f43be2 100644 --- a/tests/bm_mesh_io.py +++ b/tests/bm_mesh_io.py @@ -95,3 +95,7 @@ def bm_save_load() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_save_load() diff --git a/tests/bm_mesh_laplacian_smoothing.py b/tests/bm_mesh_laplacian_smoothing.py index 44eeec2a..7a3bd337 100644 --- a/tests/bm_mesh_laplacian_smoothing.py +++ b/tests/bm_mesh_laplacian_smoothing.py @@ -30,3 +30,7 @@ def bm_mesh_laplacian_smoothing() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_mesh_laplacian_smoothing() diff --git a/tests/bm_mesh_normal_consistency.py b/tests/bm_mesh_normal_consistency.py index 2d69c76d..f6f48699 100644 --- a/tests/bm_mesh_normal_consistency.py +++ b/tests/bm_mesh_normal_consistency.py @@ -27,3 +27,7 @@ def bm_mesh_normal_consistency() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_mesh_normal_consistency() diff --git a/tests/bm_mesh_rasterizer_transform.py b/tests/bm_mesh_rasterizer_transform.py index 0d875f3f..c21346b2 100644 --- a/tests/bm_mesh_rasterizer_transform.py +++ b/tests/bm_mesh_rasterizer_transform.py @@ -43,3 +43,7 @@ def bm_mesh_rasterizer_transform() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_mesh_rasterizer_transform() diff --git a/tests/bm_meshes.py b/tests/bm_meshes.py index 66c4178e..bd14340a 100644 --- a/tests/bm_meshes.py +++ b/tests/bm_meshes.py @@ -33,3 +33,7 @@ def bm_compute_packed_padded_meshes() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_compute_packed_padded_meshes() diff --git a/tests/bm_packed_to_padded.py b/tests/bm_packed_to_padded.py index ff597a21..5a8f9122 100644 --- a/tests/bm_packed_to_padded.py +++ b/tests/bm_packed_to_padded.py @@ -38,3 +38,7 @@ def bm_packed_to_padded() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_packed_to_padded() diff --git a/tests/bm_perspective_n_points.py b/tests/bm_perspective_n_points.py index 75d77a37..c8f3e1b6 100644 --- a/tests/bm_perspective_n_points.py +++ b/tests/bm_perspective_n_points.py @@ -23,3 +23,7 @@ def bm_perspective_n_points() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_perspective_n_points() diff --git a/tests/bm_point_mesh_distance.py b/tests/bm_point_mesh_distance.py index 2f96b461..494806c7 100644 --- a/tests/bm_point_mesh_distance.py +++ b/tests/bm_point_mesh_distance.py @@ -34,3 +34,7 @@ def bm_point_mesh_distance() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_point_mesh_distance() diff --git a/tests/bm_pointclouds.py b/tests/bm_pointclouds.py index a214ce2f..86362514 100644 --- a/tests/bm_pointclouds.py +++ b/tests/bm_pointclouds.py @@ -28,3 +28,7 @@ def bm_compute_packed_padded_pointclouds() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_compute_packed_padded_pointclouds() diff --git a/tests/bm_points_alignment.py b/tests/bm_points_alignment.py index 39e5bb9a..559b1c8c 100644 --- a/tests/bm_points_alignment.py +++ b/tests/bm_points_alignment.py @@ -69,3 +69,8 @@ def bm_corresponding_points_alignment() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_corresponding_points_alignment() + bm_iterative_closest_point() diff --git a/tests/bm_pulsar.py b/tests/bm_pulsar.py new file mode 100755 index 00000000..6dfbe9f9 --- /dev/null +++ b/tests/bm_pulsar.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +"""Test render speed.""" +import logging +import sys +from os import path + +import torch +from fvcore.common.benchmark import benchmark +from pytorch3d.renderer.points.pulsar import Renderer +from torch.autograd import Variable + + +# Making sure you can run this, even if pulsar hasn't been installed yet. +sys.path.insert(0, path.join(path.dirname(__file__), "..")) +LOGGER = logging.getLogger(__name__) + + +"""Measure the execution speed of the rendering. + +This measures a very pessimistic upper bound on speed, because synchronization +points have to be introduced in Python. On a pure PyTorch execution pipeline, +results should be significantly faster. You can get pure CUDA timings through +C++ by activating `PULSAR_TIMINGS_BATCHED_ENABLED` in the file +`pytorch3d/csrc/pulsar/logging.h` or defining it for your compiler. +""" + + +def _bm_pulsar(): + n_points = 1_000_000 + width = 1_000 + height = 1_000 + renderer = Renderer(width, height, n_points) + # Generate sample data. + torch.manual_seed(1) + vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0 + vert_pos[:, 2] += 25.0 + vert_pos[:, :2] -= 5.0 + vert_col = torch.rand(n_points, 3, dtype=torch.float32) + vert_rad = torch.rand(n_points, dtype=torch.float32) + cam_params = torch.tensor( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32 + ) + device = torch.device("cuda") + vert_pos = vert_pos.to(device) + vert_col = vert_col.to(device) + vert_rad = vert_rad.to(device) + cam_params = cam_params.to(device) + renderer = renderer.to(device) + vert_pos_var = Variable(vert_pos, requires_grad=False) + vert_col_var = Variable(vert_col, requires_grad=False) + vert_rad_var = Variable(vert_rad, requires_grad=False) + cam_params_var = Variable(cam_params, requires_grad=False) + + def bm_closure(): + renderer.forward( + vert_pos_var, + vert_col_var, + vert_rad_var, + cam_params_var, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + ) + torch.cuda.synchronize() + + return bm_closure + + +def _bm_pulsar_backward(): + n_points = 1_000_000 + width = 1_000 + height = 1_000 + renderer = Renderer(width, height, n_points) + # Generate sample data. + torch.manual_seed(1) + vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0 + vert_pos[:, 2] += 25.0 + vert_pos[:, :2] -= 5.0 + vert_col = torch.rand(n_points, 3, dtype=torch.float32) + vert_rad = torch.rand(n_points, dtype=torch.float32) + cam_params = torch.tensor( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32 + ) + device = torch.device("cuda") + vert_pos = vert_pos.to(device) + vert_col = vert_col.to(device) + vert_rad = vert_rad.to(device) + cam_params = cam_params.to(device) + renderer = renderer.to(device) + vert_pos_var = Variable(vert_pos, requires_grad=True) + vert_col_var = Variable(vert_col, requires_grad=True) + vert_rad_var = Variable(vert_rad, requires_grad=True) + cam_params_var = Variable(cam_params, requires_grad=True) + res = renderer.forward( + vert_pos_var, + vert_col_var, + vert_rad_var, + cam_params_var, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + ) + loss = res.sum() + + def bm_closure(): + loss.backward(retain_graph=True) + torch.cuda.synchronize() + + return bm_closure + + +def bm_pulsar() -> None: + if not torch.cuda.is_available(): + return + + benchmark(_bm_pulsar, "PULSAR_FORWARD", [{}], warmup_iters=3) + benchmark(_bm_pulsar_backward, "PULSAR_BACKWARD", [{}], warmup_iters=3) + + +if __name__ == "__main__": + bm_pulsar() diff --git a/tests/bm_rasterize_meshes.py b/tests/bm_rasterize_meshes.py index 68832c1c..fe596e5d 100644 --- a/tests/bm_rasterize_meshes.py +++ b/tests/bm_rasterize_meshes.py @@ -85,3 +85,7 @@ def bm_rasterize_meshes() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_rasterize_meshes() diff --git a/tests/bm_rasterize_points.py b/tests/bm_rasterize_points.py index 0deb3ef2..d00e45ac 100644 --- a/tests/bm_rasterize_points.py +++ b/tests/bm_rasterize_points.py @@ -80,3 +80,7 @@ def bm_python_vs_cpu_vs_cuda() -> None: benchmark( _bm_rasterize_points_with_init, "RASTERIZE_CUDA", kwargs_list, warmup_iters=1 ) + + +if __name__ == "__main__": + bm_python_vs_cpu_vs_cuda() diff --git a/tests/bm_sample_points_from_meshes.py b/tests/bm_sample_points_from_meshes.py index 0b8dbadd..630b4a76 100644 --- a/tests/bm_sample_points_from_meshes.py +++ b/tests/bm_sample_points_from_meshes.py @@ -36,3 +36,7 @@ def bm_sample_points() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_sample_points() diff --git a/tests/bm_so3.py b/tests/bm_so3.py index 9d7ebaa0..6762e2e5 100644 --- a/tests/bm_so3.py +++ b/tests/bm_so3.py @@ -13,3 +13,7 @@ def bm_so3() -> None: ] benchmark(TestSO3.so3_expmap, "SO3_EXP", kwargs_list, warmup_iters=1) benchmark(TestSO3.so3_logmap, "SO3_LOG", kwargs_list, warmup_iters=1) + + +if __name__ == "__main__": + bm_so3() diff --git a/tests/bm_subdivide_meshes.py b/tests/bm_subdivide_meshes.py index c4e5b2bc..e74fc09c 100644 --- a/tests/bm_subdivide_meshes.py +++ b/tests/bm_subdivide_meshes.py @@ -21,3 +21,7 @@ def bm_subdivide() -> None: kwargs_list, warmup_iters=1, ) + + +if __name__ == "__main__": + bm_subdivide() diff --git a/tests/bm_vert_align.py b/tests/bm_vert_align.py index 9b695428..092d5dd2 100644 --- a/tests/bm_vert_align.py +++ b/tests/bm_vert_align.py @@ -27,3 +27,7 @@ def bm_vert_align() -> None: benchmark( TestVertAlign.vert_align_with_init, "VERT_ALIGN", kwargs_list, warmup_iters=1 ) + + +if __name__ == "__main__": + bm_vert_align() diff --git a/tests/pulsar/__init__.py b/tests/pulsar/__init__.py new file mode 100644 index 00000000..40539064 --- /dev/null +++ b/tests/pulsar/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. diff --git a/tests/pulsar/create_multiview.py b/tests/pulsar/create_multiview.py new file mode 100644 index 00000000..e060c69c --- /dev/null +++ b/tests/pulsar/create_multiview.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +"""Create multiview data.""" +import sys +from os import path + + +# Making sure you can run this, even if pulsar hasn't been installed yet. +sys.path.insert(0, path.join(path.dirname(__file__), "..", "..")) + + +def create_multiview(): + """Test multiview optimization.""" + from pytorch3d.renderer.points.pulsar import Renderer + import torch + from torch import nn + import imageio + from torch.autograd import Variable + + # import cv2 + # import skvideo.io + import numpy as np + + # Constructor. + n_points = 10 + width = 1000 + height = 1000 + + class Model(nn.Module): + """A dummy model to test the integration into a stacked model.""" + + def __init__(self): + super(Model, self).__init__() + self.gamma = 0.1 + self.renderer = Renderer(width, height, n_points) + + def forward(self, vp, vc, vr, cam_params): + # self.gamma *= 0.995 + # print("gamma: ", self.gamma) + return self.renderer.forward(vp, vc, vr, cam_params, self.gamma, 45.0) + + # Generate sample data. + torch.manual_seed(1) + vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0 + vert_pos[:, 2] += 25.0 + vert_pos[:, :2] -= 5.0 + # print(vert_pos[0]) + vert_col = torch.rand(n_points, 3, dtype=torch.float32) + vert_rad = torch.rand(n_points, dtype=torch.float32) + + # Distortion. + # vert_pos[:, 1] += 0.5 + vert_col *= 0.5 + # vert_rad *= 0.7 + + for device in [torch.device("cuda")]: + model = Model().to(device) + vert_pos = vert_pos.to(device) + vert_col = vert_col.to(device) + vert_rad = vert_rad.to(device) + for angle_idx, angle in enumerate([-1.5, -0.8, -0.4, -0.1, 0.1, 0.4, 0.8, 1.5]): + vert_pos_v = Variable(vert_pos, requires_grad=False) + vert_col_v = Variable(vert_col, requires_grad=False) + vert_rad_v = Variable(vert_rad, requires_grad=False) + cam_params = torch.tensor( + [ + np.sin(angle) * 35.0, + 0.0, + 30.0 - np.cos(angle) * 35.0, + 0.0, + -angle, + 0.0, + 5.0, + 2.0, + ], + dtype=torch.float32, + ).to(device) + cam_params_v = Variable(cam_params, requires_grad=False) + result = model.forward(vert_pos_v, vert_col_v, vert_rad_v, cam_params_v) + result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8) + imageio.imsave( + "reference/examples_TestRenderer_test_multiview_%d.png" % (angle_idx), + result_im, + ) + + +if __name__ == "__main__": + create_multiview() diff --git a/tests/pulsar/reference/examples_TestRenderer_test_cam.png b/tests/pulsar/reference/examples_TestRenderer_test_cam.png new file mode 100644 index 00000000..836aaa4b Binary files /dev/null and b/tests/pulsar/reference/examples_TestRenderer_test_cam.png differ diff --git a/tests/pulsar/reference/examples_TestRenderer_test_cam_ortho.png b/tests/pulsar/reference/examples_TestRenderer_test_cam_ortho.png new file mode 100644 index 00000000..33b70485 Binary files /dev/null and b/tests/pulsar/reference/examples_TestRenderer_test_cam_ortho.png differ diff --git a/tests/pulsar/reference/examples_TestRenderer_test_multiview_0.png b/tests/pulsar/reference/examples_TestRenderer_test_multiview_0.png new file mode 100644 index 00000000..f26e27ad Binary files /dev/null and b/tests/pulsar/reference/examples_TestRenderer_test_multiview_0.png differ diff --git a/tests/pulsar/reference/examples_TestRenderer_test_multiview_1.png b/tests/pulsar/reference/examples_TestRenderer_test_multiview_1.png new file mode 100644 index 00000000..3973792b Binary files /dev/null and b/tests/pulsar/reference/examples_TestRenderer_test_multiview_1.png differ diff --git a/tests/pulsar/reference/examples_TestRenderer_test_multiview_2.png b/tests/pulsar/reference/examples_TestRenderer_test_multiview_2.png new file mode 100644 index 00000000..9d6c1e12 Binary files /dev/null and b/tests/pulsar/reference/examples_TestRenderer_test_multiview_2.png differ diff --git a/tests/pulsar/reference/examples_TestRenderer_test_multiview_3.png b/tests/pulsar/reference/examples_TestRenderer_test_multiview_3.png new file mode 100644 index 00000000..e9fabd78 Binary files /dev/null and b/tests/pulsar/reference/examples_TestRenderer_test_multiview_3.png differ diff --git a/tests/pulsar/reference/examples_TestRenderer_test_multiview_4.png b/tests/pulsar/reference/examples_TestRenderer_test_multiview_4.png new file mode 100644 index 00000000..d882ccf1 Binary files /dev/null and b/tests/pulsar/reference/examples_TestRenderer_test_multiview_4.png differ diff --git a/tests/pulsar/reference/examples_TestRenderer_test_multiview_5.png b/tests/pulsar/reference/examples_TestRenderer_test_multiview_5.png new file mode 100644 index 00000000..525550bd Binary files /dev/null and b/tests/pulsar/reference/examples_TestRenderer_test_multiview_5.png differ diff --git a/tests/pulsar/reference/examples_TestRenderer_test_multiview_6.png b/tests/pulsar/reference/examples_TestRenderer_test_multiview_6.png new file mode 100644 index 00000000..ba28c124 Binary files /dev/null and b/tests/pulsar/reference/examples_TestRenderer_test_multiview_6.png differ diff --git a/tests/pulsar/reference/examples_TestRenderer_test_multiview_7.png b/tests/pulsar/reference/examples_TestRenderer_test_multiview_7.png new file mode 100644 index 00000000..9af131e4 Binary files /dev/null and b/tests/pulsar/reference/examples_TestRenderer_test_multiview_7.png differ diff --git a/tests/pulsar/reference/examples_TestRenderer_test_smallopt.png b/tests/pulsar/reference/examples_TestRenderer_test_smallopt.png new file mode 100644 index 00000000..c0f26837 Binary files /dev/null and b/tests/pulsar/reference/examples_TestRenderer_test_smallopt.png differ diff --git a/tests/pulsar/reference/nr0000-in.pth b/tests/pulsar/reference/nr0000-in.pth new file mode 100644 index 00000000..c89dd9f1 Binary files /dev/null and b/tests/pulsar/reference/nr0000-in.pth differ diff --git a/tests/pulsar/reference/nr0000-out.pth b/tests/pulsar/reference/nr0000-out.pth new file mode 100644 index 00000000..92b8ecc2 Binary files /dev/null and b/tests/pulsar/reference/nr0000-out.pth differ diff --git a/tests/pulsar/test_channels.py b/tests/pulsar/test_channels.py new file mode 100644 index 00000000..9c6922f7 --- /dev/null +++ b/tests/pulsar/test_channels.py @@ -0,0 +1,149 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +"""Test number of channels.""" +import logging +import sys +import unittest +from os import path + +import torch + + +# fmt: off +# Make the mixin available. +sys.path.insert(0, path.join(path.dirname(__file__), "..")) +from common_testing import TestCaseMixin # isort:skip # noqa: E402 +# fmt: on + + +sys.path.insert(0, path.join(path.dirname(__file__), "..", "..")) +devices = [torch.device("cuda"), torch.device("cpu")] + + +class TestChannels(TestCaseMixin, unittest.TestCase): + """Test different numbers of channels.""" + + def test_basic(self): + """Basic forward test.""" + from pytorch3d.renderer.points.pulsar import Renderer + import torch + + n_points = 10 + width = 1_000 + height = 1_000 + renderer_1 = Renderer(width, height, n_points, n_channels=1) + renderer_3 = Renderer(width, height, n_points, n_channels=3) + renderer_8 = Renderer(width, height, n_points, n_channels=8) + # Generate sample data. + torch.manual_seed(1) + vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0 + vert_pos[:, 2] += 25.0 + vert_pos[:, :2] -= 5.0 + vert_col = torch.rand(n_points, 8, dtype=torch.float32) + vert_rad = torch.rand(n_points, dtype=torch.float32) + cam_params = torch.tensor( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32 + ) + for device in devices: + vert_pos = vert_pos.to(device) + vert_col = vert_col.to(device) + vert_rad = vert_rad.to(device) + cam_params = cam_params.to(device) + renderer_1 = renderer_1.to(device) + renderer_3 = renderer_3.to(device) + renderer_8 = renderer_8.to(device) + result_1 = ( + renderer_1.forward( + vert_pos, + vert_col[:, :1], + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + ) + .cpu() + .detach() + .numpy() + ) + hits_1 = ( + renderer_1.forward( + vert_pos, + vert_col[:, :1], + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + mode=1, + ) + .cpu() + .detach() + .numpy() + ) + result_3 = ( + renderer_3.forward( + vert_pos, + vert_col[:, :3], + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + ) + .cpu() + .detach() + .numpy() + ) + hits_3 = ( + renderer_3.forward( + vert_pos, + vert_col[:, :3], + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + mode=1, + ) + .cpu() + .detach() + .numpy() + ) + result_8 = ( + renderer_8.forward( + vert_pos, + vert_col, + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + ) + .cpu() + .detach() + .numpy() + ) + hits_8 = ( + renderer_8.forward( + vert_pos, + vert_col, + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + mode=1, + ) + .cpu() + .detach() + .numpy() + ) + self.assertClose(result_1, result_3[:, :, :1]) + self.assertClose(result_3, result_8[:, :, :3]) + self.assertClose(hits_1, hits_3) + self.assertClose(hits_8, hits_3) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + unittest.main() diff --git a/tests/pulsar/test_depth.py b/tests/pulsar/test_depth.py new file mode 100644 index 00000000..82cee449 --- /dev/null +++ b/tests/pulsar/test_depth.py @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +"""Test the sorting of the closest spheres.""" +import logging +import os +import sys +import unittest +from os import path + +import imageio +import numpy as np +import torch + + +# fmt: off +# Make the mixin available. +sys.path.insert(0, path.join(path.dirname(__file__), "..")) +from common_testing import TestCaseMixin # isort:skip # noqa: E402 +# fmt: on + +# Making sure you can run this, even if pulsar hasn't been installed yet. +sys.path.insert(0, path.join(path.dirname(__file__), "..", "..")) + +devices = [torch.device("cuda"), torch.device("cpu")] +IN_REF_FP = path.join(path.dirname(__file__), "reference", "nr0000-in.pth") +OUT_REF_FP = path.join(path.dirname(__file__), "reference", "nr0000-out.pth") + + +class TestDepth(TestCaseMixin, unittest.TestCase): + """Test different numbers of channels.""" + + def test_basic(self): + from pytorch3d.renderer.points.pulsar import Renderer + + for device in devices: + gamma = 1e-5 + max_depth = 15.0 + min_depth = 5.0 + renderer = Renderer( + 256, + 256, + 10000, + orthogonal_projection=True, + right_handed_system=False, + n_channels=1, + ).to(device) + data = torch.load(IN_REF_FP, map_location="cpu") + # data["pos"] = torch.rand_like(data["pos"]) + # data["pos"][:, 0] = data["pos"][:, 0] * 2. - 1. + # data["pos"][:, 1] = data["pos"][:, 1] * 2. - 1. + # data["pos"][:, 2] = data["pos"][:, 2] + 9.5 + result, result_info = renderer.forward( + data["pos"].to(device), + data["col"].to(device), + data["rad"].to(device), + data["cam_params"].to(device), + gamma, + min_depth=min_depth, + max_depth=max_depth, + return_forward_info=True, + bg_col=torch.zeros(1, device=device, dtype=torch.float32), + percent_allowed_difference=0.01, + ) + sphere_ids = Renderer.sphere_ids_from_result_info_nograd(result_info) + depth_map = Renderer.depth_map_from_result_info_nograd(result_info) + depth_vis = (depth_map - depth_map[depth_map > 0].min()) * 200 / ( + depth_map.max() - depth_map[depth_map > 0.0].min() + ) + 50 + if not os.environ.get("FB_TEST", False): + imageio.imwrite( + path.join( + path.dirname(__file__), + "test_out", + "test_depth_test_basic_depth.png", + ), + depth_vis.cpu().numpy().astype(np.uint8), + ) + # torch.save( + # data, path.join(path.dirname(__file__), "reference", "nr0000-in.pth") + # ) + # torch.save( + # {"sphere_ids": sphere_ids, "depth_map": depth_map}, + # path.join(path.dirname(__file__), "reference", "nr0000-out.pth"), + # ) + # sys.exit(0) + reference = torch.load(OUT_REF_FP, map_location="cpu") + self.assertTrue( + torch.sum( + reference["sphere_ids"][..., 0].to(device) == sphere_ids[..., 0] + ) + > 65530 + ) + self.assertClose(reference["depth_map"].to(device), depth_map) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + unittest.main() diff --git a/tests/pulsar/test_forward.py b/tests/pulsar/test_forward.py new file mode 100644 index 00000000..44d175f0 --- /dev/null +++ b/tests/pulsar/test_forward.py @@ -0,0 +1,353 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +"""Basic rendering test.""" +import logging +import os +import sys +import unittest +from os import path + +import imageio +import numpy as np +import torch + + +# Making sure you can run this, even if pulsar hasn't been installed yet. +sys.path.insert(0, path.join(path.dirname(__file__), "..", "..")) +LOGGER = logging.getLogger(__name__) +devices = [torch.device("cuda"), torch.device("cpu")] + + +class TestForward(unittest.TestCase): + """Rendering tests.""" + + def test_bg_weight(self): + """Test background reweighting.""" + from pytorch3d.renderer.points.pulsar import Renderer + + LOGGER.info("Setting up rendering test for 3 channels...") + n_points = 1 + width = 1_000 + height = 1_000 + renderer = Renderer(width, height, n_points, background_normalized_depth=0.999) + vert_pos = torch.tensor([[0.0, 0.0, 25.0]], dtype=torch.float32) + vert_col = torch.tensor([[0.3, 0.5, 0.7]], dtype=torch.float32) + vert_rad = torch.tensor([1.0], dtype=torch.float32) + cam_params = torch.tensor( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32 + ) + for device in devices: + vert_pos = vert_pos.to(device) + vert_col = vert_col.to(device) + vert_rad = vert_rad.to(device) + cam_params = cam_params.to(device) + renderer = renderer.to(device) + LOGGER.info("Rendering...") + # Measurements. + result = renderer.forward( + vert_pos, vert_col, vert_rad, cam_params, 1.0e-1, 45.0 + ) + hits = renderer.forward( + vert_pos, + vert_col, + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + mode=1, + ) + if not os.environ.get("FB_TEST", False): + imageio.imsave( + path.join( + path.dirname(__file__), + "test_out", + "test_forward_TestForward_test_bg_weight.png", + ), + (result * 255.0).cpu().to(torch.uint8).numpy(), + ) + imageio.imsave( + path.join( + path.dirname(__file__), + "test_out", + "test_forward_TestForward_test_bg_weight_hits.png", + ), + (hits * 255.0).cpu().to(torch.uint8).numpy(), + ) + self.assertEqual(hits[500, 500, 0].item(), 1.0) + self.assertTrue( + np.allclose( + result[500, 500, :].cpu().numpy(), + [1.0, 1.0, 1.0], + rtol=1e-2, + atol=1e-2, + ) + ) + + def test_basic_3chan(self): + """Test rendering one image with one sphere, 3 channels.""" + from pytorch3d.renderer.points.pulsar import Renderer + + LOGGER.info("Setting up rendering test for 3 channels...") + n_points = 1 + width = 1_000 + height = 1_000 + renderer = Renderer(width, height, n_points) + vert_pos = torch.tensor([[0.0, 0.0, 25.0]], dtype=torch.float32) + vert_col = torch.tensor([[0.3, 0.5, 0.7]], dtype=torch.float32) + vert_rad = torch.tensor([1.0], dtype=torch.float32) + cam_params = torch.tensor( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32 + ) + for device in devices: + vert_pos = vert_pos.to(device) + vert_col = vert_col.to(device) + vert_rad = vert_rad.to(device) + cam_params = cam_params.to(device) + renderer = renderer.to(device) + LOGGER.info("Rendering...") + # Measurements. + result = renderer.forward( + vert_pos, vert_col, vert_rad, cam_params, 1.0e-1, 45.0 + ) + hits = renderer.forward( + vert_pos, + vert_col, + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + mode=1, + ) + if not os.environ.get("FB_TEST", False): + imageio.imsave( + path.join( + path.dirname(__file__), + "test_out", + "test_forward_TestForward_test_basic_3chan.png", + ), + (result * 255.0).cpu().to(torch.uint8).numpy(), + ) + imageio.imsave( + path.join( + path.dirname(__file__), + "test_out", + "test_forward_TestForward_test_basic_3chan_hits.png", + ), + (hits * 255.0).cpu().to(torch.uint8).numpy(), + ) + self.assertEqual(hits[500, 500, 0].item(), 1.0) + self.assertTrue( + np.allclose( + result[500, 500, :].cpu().numpy(), + [0.3, 0.5, 0.7], + rtol=1e-2, + atol=1e-2, + ) + ) + + def test_basic_1chan(self): + """Test rendering one image with one sphere, 1 channel.""" + from pytorch3d.renderer.points.pulsar import Renderer + + LOGGER.info("Setting up rendering test for 1 channel...") + n_points = 1 + width = 1_000 + height = 1_000 + renderer = Renderer(width, height, n_points, n_channels=1) + vert_pos = torch.tensor([[0.0, 0.0, 25.0]], dtype=torch.float32) + vert_col = torch.tensor([[0.3]], dtype=torch.float32) + vert_rad = torch.tensor([1.0], dtype=torch.float32) + cam_params = torch.tensor( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32 + ) + for device in devices: + vert_pos = vert_pos.to(device) + vert_col = vert_col.to(device) + vert_rad = vert_rad.to(device) + cam_params = cam_params.to(device) + renderer = renderer.to(device) + LOGGER.info("Rendering...") + # Measurements. + result = renderer.forward( + vert_pos, vert_col, vert_rad, cam_params, 1.0e-1, 45.0 + ) + hits = renderer.forward( + vert_pos, + vert_col, + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + mode=1, + ) + if not os.environ.get("FB_TEST", False): + imageio.imsave( + path.join( + path.dirname(__file__), + "test_out", + "test_forward_TestForward_test_basic_1chan.png", + ), + (result * 255.0).cpu().to(torch.uint8).numpy(), + ) + imageio.imsave( + path.join( + path.dirname(__file__), + "test_out", + "test_forward_TestForward_test_basic_1chan_hits.png", + ), + (hits * 255.0).cpu().to(torch.uint8).numpy(), + ) + self.assertEqual(hits[500, 500, 0].item(), 1.0) + self.assertTrue( + np.allclose( + result[500, 500, :].cpu().numpy(), [0.3], rtol=1e-2, atol=1e-2 + ) + ) + + def test_basic_8chan(self): + """Test rendering one image with one sphere, 8 channels.""" + from pytorch3d.renderer.points.pulsar import Renderer + + LOGGER.info("Setting up rendering test for 8 channels...") + n_points = 1 + width = 1_000 + height = 1_000 + renderer = Renderer(width, height, n_points, n_channels=8) + vert_pos = torch.tensor([[0.0, 0.0, 25.0]], dtype=torch.float32) + vert_col = torch.tensor( + [[1.0, 1.0, 1.0, 1.0, 1.0, 0.3, 0.5, 0.7]], dtype=torch.float32 + ) + vert_rad = torch.tensor([1.0], dtype=torch.float32) + cam_params = torch.tensor( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32 + ) + for device in devices: + vert_pos = vert_pos.to(device) + vert_col = vert_col.to(device) + vert_rad = vert_rad.to(device) + cam_params = cam_params.to(device) + renderer = renderer.to(device) + LOGGER.info("Rendering...") + # Measurements. + result = renderer.forward( + vert_pos, vert_col, vert_rad, cam_params, 1.0e-1, 45.0 + ) + hits = renderer.forward( + vert_pos, + vert_col, + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + mode=1, + ) + if not os.environ.get("FB_TEST", False): + imageio.imsave( + path.join( + path.dirname(__file__), + "test_out", + "test_forward_TestForward_test_basic_8chan.png", + ), + (result[:, :, 5:8] * 255.0).cpu().to(torch.uint8).numpy(), + ) + imageio.imsave( + path.join( + path.dirname(__file__), + "test_out", + "test_forward_TestForward_test_basic_8chan_hits.png", + ), + (hits * 255.0).cpu().to(torch.uint8).numpy(), + ) + self.assertEqual(hits[500, 500, 0].item(), 1.0) + self.assertTrue( + np.allclose( + result[500, 500, 5:8].cpu().numpy(), + [0.3, 0.5, 0.7], + rtol=1e-2, + atol=1e-2, + ) + ) + self.assertTrue( + np.allclose( + result[500, 500, :5].cpu().numpy(), 1.0, rtol=1e-2, atol=1e-2 + ) + ) + + def test_principal_point(self): + """Test shifting the principal point.""" + from pytorch3d.renderer.points.pulsar import Renderer + + LOGGER.info("Setting up rendering test for shifted principal point...") + n_points = 1 + width = 1_000 + height = 1_000 + renderer = Renderer(width, height, n_points, n_channels=1) + vert_pos = torch.tensor([[0.0, 0.0, 25.0]], dtype=torch.float32) + vert_col = torch.tensor([[0.0]], dtype=torch.float32) + vert_rad = torch.tensor([1.0], dtype=torch.float32) + cam_params = torch.tensor( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0, 0.0, 0.0], dtype=torch.float32 + ) + for device in devices: + vert_pos = vert_pos.to(device) + vert_col = vert_col.to(device) + vert_rad = vert_rad.to(device) + cam_params = cam_params.to(device) + cam_params[-2] = -250.0 + cam_params[-1] = -250.0 + renderer = renderer.to(device) + LOGGER.info("Rendering...") + # Measurements. + result = renderer.forward( + vert_pos, vert_col, vert_rad, cam_params, 1.0e-1, 45.0 + ) + if not os.environ.get("FB_TEST", False): + imageio.imsave( + path.join( + path.dirname(__file__), + "test_out", + "test_forward_TestForward_test_principal_point.png", + ), + (result * 255.0).cpu().to(torch.uint8).numpy(), + ) + self.assertTrue( + np.allclose( + result[750, 750, :].cpu().numpy(), [0.0], rtol=1e-2, atol=1e-2 + ) + ) + for device in devices: + vert_pos = vert_pos.to(device) + vert_col = vert_col.to(device) + vert_rad = vert_rad.to(device) + cam_params = cam_params.to(device) + cam_params[-2] = 250.0 + cam_params[-1] = 250.0 + renderer = renderer.to(device) + LOGGER.info("Rendering...") + # Measurements. + result = renderer.forward( + vert_pos, vert_col, vert_rad, cam_params, 1.0e-1, 45.0 + ) + if not os.environ.get("FB_TEST", False): + imageio.imsave( + path.join( + path.dirname(__file__), + "test_out", + "test_forward_TestForward_test_principal_point.png", + ), + (result * 255.0).cpu().to(torch.uint8).numpy(), + ) + self.assertTrue( + np.allclose( + result[250, 250, :].cpu().numpy(), [0.0], rtol=1e-2, atol=1e-2 + ) + ) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + logging.getLogger("pulsar.renderer").setLevel(logging.WARN) + unittest.main() diff --git a/tests/pulsar/test_hands.py b/tests/pulsar/test_hands.py new file mode 100644 index 00000000..259bfe99 --- /dev/null +++ b/tests/pulsar/test_hands.py @@ -0,0 +1,120 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +"""Test right hand/left hand system compatibility.""" +import logging +import sys +import unittest +from os import path + +import torch + + +# fmt: off +# Make the mixin available. +sys.path.insert(0, path.join(path.dirname(__file__), "..")) +from common_testing import TestCaseMixin # isort:skip # noqa: E402 +# fmt: on + + +# Making sure you can run this, even if pulsar hasn't been installed yet. +sys.path.insert(0, path.join(path.dirname(__file__), "..", "..")) +devices = [torch.device("cuda"), torch.device("cpu")] + + +class TestHands(TestCaseMixin, unittest.TestCase): + """Test right hand/left hand system compatibility.""" + + def test_basic(self): + """Basic forward test.""" + from pytorch3d.renderer.points.pulsar import Renderer + + n_points = 10 + width = 1000 + height = 1000 + renderer_left = Renderer(width, height, n_points, right_handed_system=False) + renderer_right = Renderer(width, height, n_points, right_handed_system=True) + # Generate sample data. + torch.manual_seed(1) + vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0 + vert_pos[:, 2] += 25.0 + vert_pos[:, :2] -= 5.0 + vert_pos_neg = vert_pos.clone() + vert_pos_neg[:, 2] *= -1.0 + vert_col = torch.rand(n_points, 3, dtype=torch.float32) + vert_rad = torch.rand(n_points, dtype=torch.float32) + cam_params = torch.tensor( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32 + ) + for device in devices: + vert_pos = vert_pos.to(device) + vert_pos_neg = vert_pos_neg.to(device) + vert_col = vert_col.to(device) + vert_rad = vert_rad.to(device) + cam_params = cam_params.to(device) + renderer_left = renderer_left.to(device) + renderer_right = renderer_right.to(device) + result_left = ( + renderer_left.forward( + vert_pos, + vert_col, + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + ) + .cpu() + .detach() + .numpy() + ) + hits_left = ( + renderer_left.forward( + vert_pos, + vert_col, + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + mode=1, + ) + .cpu() + .detach() + .numpy() + ) + result_right = ( + renderer_right.forward( + vert_pos_neg, + vert_col, + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + ) + .cpu() + .detach() + .numpy() + ) + hits_right = ( + renderer_right.forward( + vert_pos_neg, + vert_col, + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + mode=1, + ) + .cpu() + .detach() + .numpy() + ) + self.assertClose(result_left, result_right) + self.assertClose(hits_left, hits_right) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + logging.getLogger("pulsar.renderer").setLevel(logging.WARN) + unittest.main() diff --git a/tests/pulsar/test_ortho.py b/tests/pulsar/test_ortho.py new file mode 100644 index 00000000..64f377a1 --- /dev/null +++ b/tests/pulsar/test_ortho.py @@ -0,0 +1,126 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +"""Tests for the orthogonal projection.""" +import logging +import sys +import unittest +from os import path + +import numpy as np +import torch + + +# Making sure you can run this, even if pulsar hasn't been installed yet. +sys.path.insert(0, path.join(path.dirname(__file__), "..")) +devices = [torch.device("cuda"), torch.device("cpu")] + + +class TestOrtho(unittest.TestCase): + """Test the orthogonal projection.""" + + def test_basic(self): + """Basic forward test of the orthogonal projection.""" + from pytorch3d.renderer.points.pulsar import Renderer + + n_points = 10 + width = 1000 + height = 1000 + renderer_left = Renderer( + width, + height, + n_points, + right_handed_system=False, + orthogonal_projection=True, + ) + renderer_right = Renderer( + width, + height, + n_points, + right_handed_system=True, + orthogonal_projection=True, + ) + # Generate sample data. + torch.manual_seed(1) + vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0 + vert_pos[:, 2] += 25.0 + vert_pos[:, :2] -= 5.0 + vert_pos_neg = vert_pos.clone() + vert_pos_neg[:, 2] *= -1.0 + vert_col = torch.rand(n_points, 3, dtype=torch.float32) + vert_rad = torch.rand(n_points, dtype=torch.float32) + cam_params = torch.tensor( + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 20.0], dtype=torch.float32 + ) + for device in devices: + vert_pos = vert_pos.to(device) + vert_pos_neg = vert_pos_neg.to(device) + vert_col = vert_col.to(device) + vert_rad = vert_rad.to(device) + cam_params = cam_params.to(device) + renderer_left = renderer_left.to(device) + renderer_right = renderer_right.to(device) + result_left = ( + renderer_left.forward( + vert_pos, + vert_col, + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + ) + .cpu() + .detach() + .numpy() + ) + hits_left = ( + renderer_left.forward( + vert_pos, + vert_col, + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + mode=1, + ) + .cpu() + .detach() + .numpy() + ) + result_right = ( + renderer_right.forward( + vert_pos_neg, + vert_col, + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + ) + .cpu() + .detach() + .numpy() + ) + hits_right = ( + renderer_right.forward( + vert_pos_neg, + vert_col, + vert_rad, + cam_params, + 1.0e-1, + 45.0, + percent_allowed_difference=0.01, + mode=1, + ) + .cpu() + .detach() + .numpy() + ) + self.assertTrue(np.allclose(result_left, result_right)) + self.assertTrue(np.allclose(hits_left, hits_right)) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + logging.getLogger("pulsar.renderer").setLevel(logging.WARN) + unittest.main() diff --git a/tests/pulsar/test_out/empty.txt b/tests/pulsar/test_out/empty.txt new file mode 100644 index 00000000..e69de29b diff --git a/tests/pulsar/test_small_spheres.py b/tests/pulsar/test_small_spheres.py new file mode 100644 index 00000000..2cefccff --- /dev/null +++ b/tests/pulsar/test_small_spheres.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +"""Test right hand/left hand system compatibility.""" +import sys +import unittest +from os import path + +import numpy as np +import torch +from torch import nn + + +sys.path.insert(0, path.join(path.dirname(__file__), "..")) +devices = [torch.device("cuda"), torch.device("cpu")] + + +n_points = 10 +width = 1_000 +height = 1_000 + + +class SceneModel(nn.Module): + """A simple model to demonstrate use in Modules.""" + + def __init__(self): + super(SceneModel, self).__init__() + from pytorch3d.renderer.points.pulsar import Renderer + + self.gamma = 1.0 + # Points. + torch.manual_seed(1) + vert_pos = torch.rand((1, n_points, 3), dtype=torch.float32) * 10.0 + vert_pos[:, :, 2] += 25.0 + vert_pos[:, :, :2] -= 5.0 + self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=False)) + self.register_parameter( + "vert_col", + nn.Parameter( + torch.zeros(1, n_points, 3, dtype=torch.float32), requires_grad=True + ), + ) + self.register_parameter( + "vert_rad", + nn.Parameter( + torch.ones(1, n_points, dtype=torch.float32) * 0.001, + requires_grad=False, + ), + ) + self.register_parameter( + "vert_opy", + nn.Parameter( + torch.ones(1, n_points, dtype=torch.float32), requires_grad=False + ), + ) + self.register_buffer( + "cam_params", + torch.tensor( + [ + [ + np.sin(angle) * 35.0, + 0.0, + 30.0 - np.cos(angle) * 35.0, + 0.0, + -angle, + 0.0, + 5.0, + 2.0, + ] + for angle in [-1.5, -0.8, -0.4, -0.1, 0.1, 0.4, 0.8, 1.5] + ], + dtype=torch.float32, + ), + ) + self.renderer = Renderer(width, height, n_points) + + def forward(self, cam=None): + if cam is None: + cam = self.cam_params + n_views = 8 + else: + n_views = 1 + return self.renderer.forward( + self.vert_pos.expand(n_views, -1, -1), + self.vert_col.expand(n_views, -1, -1), + self.vert_rad.expand(n_views, -1), + cam, + self.gamma, + 45.0, + return_forward_info=True, + ) + + +class TestSmallSpheres(unittest.TestCase): + """Test small sphere rendering and gradients.""" + + def test_basic(self): + for device in devices: + # Set up model. + model = SceneModel().to(device) + angle = 0.0 + for _ in range(50): + cam_control = torch.tensor( + [ + [ + np.sin(angle) * 35.0, + 0.0, + 30.0 - np.cos(angle) * 35.0, + 0.0, + -angle, + 0.0, + 5.0, + 2.0, + ] + ], + dtype=torch.float32, + ).to(device) + result, forw_info = model(cam=cam_control) + sphere_ids = model.renderer.sphere_ids_from_result_info_nograd( + forw_info + ) + # Assert all spheres are rendered. + for idx in range(n_points): + self.assertTrue( + (sphere_ids == idx).sum() > 0, "Sphere ID %d missing!" % (idx) + ) + # Visualize. + # result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8) + # cv2.imshow("res", result_im[0, :, :, ::-1]) + # cv2.waitKey(0) + # Back-propagate some dummy gradients. + loss = ((result - torch.ones_like(result)).abs()).sum() + loss.backward() + # Now check whether the gradient arrives at every sphere. + self.assertTrue(torch.all(model.vert_col.grad[:, :, 0].abs() > 0.0)) + angle += 0.15 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_build.py b/tests/test_build.py index 1c47c5de..1b30607b 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -27,28 +27,6 @@ class TestBuild(unittest.TestCase): for k, v in counter.items(): self.assertEqual(v, 1, f"Too many files with stem {k}.") - @unittest.skipIf(in_conda_build, "In conda build") - def test_deprecated_usage(self): - # Check certain expressions do not occur in the csrc code - test_dir = Path(__file__).resolve().parent - source_dir = test_dir.parent / "pytorch3d" / "csrc" - - files = sorted(source_dir.glob("**/*.*")) - self.assertGreater(len(files), 4) - - patterns = [".type()", ".data()"] - - for file in files: - with open(file) as f: - text = f.read() - for pattern in patterns: - found = pattern in text - msg = ( - f"{pattern} found in {file.name}" - + ", this has been deprecated." - ) - self.assertFalse(found, msg) - @unittest.skipIf(in_conda_build, "In conda build") def test_copyright(self): test_dir = Path(__file__).resolve().parent @@ -63,6 +41,13 @@ class TestBuild(unittest.TestCase): for extension in extensions: for i in root_dir.glob(f"**/*.{extension}"): + print(i) + if str(i).endswith( + "pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py" + ): + continue + if str(i).endswith("pytorch3d/csrc/pulsar/include/fastermath.h"): + continue with open(i) as f: firstline = f.readline() if firstline.startswith(("# -*-", "#!")):