mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
pulsar integration.
Summary: This diff integrates the pulsar renderer source code into PyTorch3D as an alternative backend for the PyTorch3D point renderer. This diff is the first of a series of three diffs to complete that migration and focuses on the packaging and integration of the source code. For more information about the pulsar backend, see the release notes and the paper (https://arxiv.org/abs/2004.07484). For information on how to use the backend, see the point cloud rendering notebook and the examples in the folder `docs/examples`. Tasks addressed in the following diffs: * Add the PyTorch3D interface, * Add notebook examples and documentation (or adapt the existing ones to feature both interfaces). Reviewed By: nikhilaravi Differential Revision: D23947736 fbshipit-source-id: a5e77b53e6750334db22aefa89b4c079cda1b443
This commit is contained in:
parent
d565032399
commit
b19fe1de2f
50
docs/examples/pulsar_basic.py
Executable file
50
docs/examples/pulsar_basic.py
Executable file
@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
"""
|
||||
This example demonstrates the most trivial, direct interface of the pulsar
|
||||
sphere renderer. It renders and saves an image with 10 random spheres.
|
||||
Output: basic.png.
|
||||
"""
|
||||
from os import path
|
||||
|
||||
import imageio
|
||||
import torch
|
||||
from pytorch3d.renderer.points.pulsar import Renderer
|
||||
|
||||
|
||||
n_points = 10
|
||||
width = 1_000
|
||||
height = 1_000
|
||||
device = torch.device("cuda")
|
||||
renderer = Renderer(width, height, n_points).to(device)
|
||||
# Generate sample data.
|
||||
vert_pos = torch.rand(n_points, 3, dtype=torch.float32, device=device) * 10.0
|
||||
vert_pos[:, 2] += 25.0
|
||||
vert_pos[:, :2] -= 5.0
|
||||
vert_col = torch.rand(n_points, 3, dtype=torch.float32, device=device)
|
||||
vert_rad = torch.rand(n_points, dtype=torch.float32, device=device)
|
||||
cam_params = torch.tensor(
|
||||
[
|
||||
0.0,
|
||||
0.0,
|
||||
0.0, # Position 0, 0, 0 (x, y, z).
|
||||
0.0,
|
||||
0.0,
|
||||
0.0, # Rotation 0, 0, 0 (in axis-angle format).
|
||||
5.0, # Focal length in world size.
|
||||
2.0, # Sensor size in world size.
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
# Render.
|
||||
image = renderer(
|
||||
vert_pos,
|
||||
vert_col,
|
||||
vert_rad,
|
||||
cam_params,
|
||||
1.0e-1, # Renderer blending parameter gamma, in [1., 1e-5].
|
||||
45.0, # Maximum depth.
|
||||
)
|
||||
print("Writing image to `%s`." % (path.abspath("basic.png")))
|
||||
imageio.imsave("basic.png", (image.cpu().detach() * 255.0).to(torch.uint8).numpy())
|
158
docs/examples/pulsar_cam.py
Executable file
158
docs/examples/pulsar_cam.py
Executable file
@ -0,0 +1,158 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
"""
|
||||
This example demonstrates camera parameter optimization with the plain
|
||||
pulsar interface. For this, a reference image has been pre-generated
|
||||
(you can find it at `../../tests/pulsar/reference/examples_TestRenderer_test_cam.png`).
|
||||
The same scene parameterization is loaded and the camera parameters
|
||||
distorted. Gradient-based optimization is used to converge towards the
|
||||
original camera parameters.
|
||||
"""
|
||||
from os import path
|
||||
|
||||
import cv2
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.renderer.points.pulsar import Renderer
|
||||
from torch import nn, optim
|
||||
|
||||
|
||||
n_points = 20
|
||||
width = 1_000
|
||||
height = 1_000
|
||||
device = torch.device("cuda")
|
||||
|
||||
|
||||
class SceneModel(nn.Module):
|
||||
"""
|
||||
A simple scene model to demonstrate use of pulsar in PyTorch modules.
|
||||
|
||||
The scene model is parameterized with sphere locations (vert_pos),
|
||||
channel content (vert_col), radiuses (vert_rad), camera position (cam_pos),
|
||||
camera rotation (cam_rot) and sensor focal length and width (cam_sensor).
|
||||
|
||||
The forward method of the model renders this scene description. Any
|
||||
of these parameters could instead be passed as inputs to the forward
|
||||
method and come from a different model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(SceneModel, self).__init__()
|
||||
self.gamma = 0.1
|
||||
# Points.
|
||||
torch.manual_seed(1)
|
||||
vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0
|
||||
vert_pos[:, 2] += 25.0
|
||||
vert_pos[:, :2] -= 5.0
|
||||
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=False))
|
||||
self.register_parameter(
|
||||
"vert_col",
|
||||
nn.Parameter(
|
||||
torch.rand(n_points, 3, dtype=torch.float32), requires_grad=False
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"vert_rad",
|
||||
nn.Parameter(
|
||||
torch.rand(n_points, dtype=torch.float32), requires_grad=False
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"cam_pos",
|
||||
nn.Parameter(
|
||||
torch.tensor([0.1, 0.1, 0.0], dtype=torch.float32), requires_grad=True
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"cam_rot",
|
||||
nn.Parameter(
|
||||
torch.tensor(
|
||||
[
|
||||
# We're using the 6D rot. representation for better gradients.
|
||||
0.9995,
|
||||
0.0300445,
|
||||
-0.0098482,
|
||||
-0.0299445,
|
||||
0.9995,
|
||||
0.0101482,
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=True,
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"cam_sensor",
|
||||
nn.Parameter(
|
||||
torch.tensor([4.8, 1.8], dtype=torch.float32), requires_grad=True
|
||||
),
|
||||
)
|
||||
self.renderer = Renderer(width, height, n_points)
|
||||
|
||||
def forward(self):
|
||||
return self.renderer.forward(
|
||||
self.vert_pos,
|
||||
self.vert_col,
|
||||
self.vert_rad,
|
||||
torch.cat([self.cam_pos, self.cam_rot, self.cam_sensor]),
|
||||
self.gamma,
|
||||
45.0,
|
||||
)
|
||||
|
||||
|
||||
# Load reference.
|
||||
ref = (
|
||||
torch.from_numpy(
|
||||
imageio.imread(
|
||||
"../../tests/pulsar/reference/examples_TestRenderer_test_cam.png"
|
||||
)
|
||||
).to(torch.float32)
|
||||
/ 255.0
|
||||
).to(device)
|
||||
# Set up model.
|
||||
model = SceneModel().to(device)
|
||||
# Optimizer.
|
||||
optimizer = optim.SGD(
|
||||
[
|
||||
{"params": [model.cam_pos], "lr": 1e-4}, # 1e-3
|
||||
{"params": [model.cam_rot], "lr": 5e-6},
|
||||
{"params": [model.cam_sensor], "lr": 1e-4},
|
||||
]
|
||||
)
|
||||
|
||||
print("Writing video to `%s`." % (path.abspath("cam.gif")))
|
||||
writer = imageio.get_writer("cam.gif", format="gif", fps=25)
|
||||
|
||||
# Optimize.
|
||||
for i in range(300):
|
||||
optimizer.zero_grad()
|
||||
result = model()
|
||||
# Visualize.
|
||||
result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
|
||||
cv2.imshow("opt", result_im[:, :, ::-1])
|
||||
writer.append_data(result_im)
|
||||
overlay_img = np.ascontiguousarray(
|
||||
((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[
|
||||
:, :, ::-1
|
||||
]
|
||||
)
|
||||
overlay_img = cv2.putText(
|
||||
overlay_img,
|
||||
"Step %d" % (i),
|
||||
(10, 40),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1,
|
||||
(0, 0, 0),
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
False,
|
||||
)
|
||||
cv2.imshow("overlay", overlay_img)
|
||||
cv2.waitKey(1)
|
||||
# Update.
|
||||
loss = ((result - ref) ** 2).sum()
|
||||
print("loss {}: {}".format(i, loss.item()))
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
writer.close()
|
201
docs/examples/pulsar_multiview.py
Executable file
201
docs/examples/pulsar_multiview.py
Executable file
@ -0,0 +1,201 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
"""
|
||||
This example demonstrates multiview 3D reconstruction using the plain
|
||||
pulsar interface. For this, reference images have been pre-generated
|
||||
(you can find them at `../../tests/pulsar/reference/examples_TestRenderer_test_multiview_%d.png`).
|
||||
The camera parameters are assumed given. The scene is initialized with
|
||||
random spheres. Gradient-based optimization is used to optimize sphere
|
||||
parameters and prune spheres to converge to a 3D representation.
|
||||
"""
|
||||
from os import path
|
||||
|
||||
import cv2
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.renderer.points.pulsar import Renderer
|
||||
from torch import nn, optim
|
||||
|
||||
|
||||
n_points = 400_000
|
||||
width = 1_000
|
||||
height = 1_000
|
||||
visualize_ids = [0, 1]
|
||||
device = torch.device("cuda")
|
||||
|
||||
|
||||
class SceneModel(nn.Module):
|
||||
"""
|
||||
A simple scene model to demonstrate use of pulsar in PyTorch modules.
|
||||
|
||||
The scene model is parameterized with sphere locations (vert_pos),
|
||||
channel content (vert_col), radiuses (vert_rad), camera position (cam_pos),
|
||||
camera rotation (cam_rot) and sensor focal length and width (cam_sensor).
|
||||
|
||||
The forward method of the model renders this scene description. Any
|
||||
of these parameters could instead be passed as inputs to the forward
|
||||
method and come from a different model. Optionally, camera parameters can
|
||||
be provided to the forward method in which case the scene is rendered
|
||||
using those parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(SceneModel, self).__init__()
|
||||
self.gamma = 1.0
|
||||
# Points.
|
||||
torch.manual_seed(1)
|
||||
vert_pos = torch.rand((1, n_points, 3), dtype=torch.float32) * 10.0
|
||||
vert_pos[:, :, 2] += 25.0
|
||||
vert_pos[:, :, :2] -= 5.0
|
||||
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True))
|
||||
self.register_parameter(
|
||||
"vert_col",
|
||||
nn.Parameter(
|
||||
torch.ones(1, n_points, 3, dtype=torch.float32) * 0.5,
|
||||
requires_grad=True,
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"vert_rad",
|
||||
nn.Parameter(
|
||||
torch.ones(1, n_points, dtype=torch.float32) * 0.05, requires_grad=True
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"vert_opy",
|
||||
nn.Parameter(
|
||||
torch.ones(1, n_points, dtype=torch.float32), requires_grad=True
|
||||
),
|
||||
)
|
||||
self.register_buffer(
|
||||
"cam_params",
|
||||
torch.tensor(
|
||||
[
|
||||
[
|
||||
np.sin(angle) * 35.0,
|
||||
0.0,
|
||||
30.0 - np.cos(angle) * 35.0,
|
||||
0.0,
|
||||
-angle,
|
||||
0.0,
|
||||
5.0,
|
||||
2.0,
|
||||
]
|
||||
for angle in [-1.5, -0.8, -0.4, -0.1, 0.1, 0.4, 0.8, 1.5]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
)
|
||||
self.renderer = Renderer(width, height, n_points)
|
||||
|
||||
def forward(self, cam=None):
|
||||
if cam is None:
|
||||
cam = self.cam_params
|
||||
n_views = 8
|
||||
else:
|
||||
n_views = 1
|
||||
return self.renderer.forward(
|
||||
self.vert_pos.expand(n_views, -1, -1),
|
||||
self.vert_col.expand(n_views, -1, -1),
|
||||
self.vert_rad.expand(n_views, -1),
|
||||
cam,
|
||||
self.gamma,
|
||||
45.0,
|
||||
)
|
||||
|
||||
|
||||
# Load reference.
|
||||
ref = torch.stack(
|
||||
[
|
||||
torch.from_numpy(
|
||||
imageio.imread(
|
||||
"../../tests/pulsar/reference/examples_TestRenderer_test_multiview_%d.png"
|
||||
% idx
|
||||
)
|
||||
).to(torch.float32)
|
||||
/ 255.0
|
||||
for idx in range(8)
|
||||
]
|
||||
).to(device)
|
||||
# Set up model.
|
||||
model = SceneModel().to(device)
|
||||
# Optimizer.
|
||||
optimizer = optim.SGD(
|
||||
[
|
||||
{"params": [model.vert_col], "lr": 1e-1},
|
||||
{"params": [model.vert_rad], "lr": 1e-3},
|
||||
{"params": [model.vert_pos], "lr": 1e-3},
|
||||
]
|
||||
)
|
||||
|
||||
# For visualization.
|
||||
angle = 0.0
|
||||
print("Writing video to `%s`." % (path.abspath("multiview.avi")))
|
||||
writer = imageio.get_writer("multiview.gif", format="gif", fps=25)
|
||||
|
||||
# Optimize.
|
||||
for i in range(300):
|
||||
optimizer.zero_grad()
|
||||
result = model()
|
||||
# Visualize.
|
||||
result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
|
||||
cv2.imshow("opt", result_im[0, :, :, ::-1])
|
||||
overlay_img = np.ascontiguousarray(
|
||||
((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[
|
||||
0, :, :, ::-1
|
||||
]
|
||||
)
|
||||
overlay_img = cv2.putText(
|
||||
overlay_img,
|
||||
"Step %d" % (i),
|
||||
(10, 40),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1,
|
||||
(0, 0, 0),
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
False,
|
||||
)
|
||||
cv2.imshow("overlay", overlay_img)
|
||||
cv2.waitKey(1)
|
||||
# Update.
|
||||
loss = ((result - ref) ** 2).sum()
|
||||
print("loss {}: {}".format(i, loss.item()))
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# Cleanup.
|
||||
with torch.no_grad():
|
||||
model.vert_col.data = torch.clamp(model.vert_col.data, 0.0, 1.0)
|
||||
# Remove points.
|
||||
model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0
|
||||
model.vert_rad.data[model.vert_rad < 0.001] = 0.0001
|
||||
vd = (
|
||||
(model.vert_col - torch.ones(1, 1, 3, dtype=torch.float32).to(device))
|
||||
.abs()
|
||||
.sum(dim=2)
|
||||
)
|
||||
model.vert_pos.data[vd <= 0.2] = -1000.0
|
||||
# Rotating visualization.
|
||||
cam_control = torch.tensor(
|
||||
[
|
||||
[
|
||||
np.sin(angle) * 35.0,
|
||||
0.0,
|
||||
30.0 - np.cos(angle) * 35.0,
|
||||
0.0,
|
||||
-angle,
|
||||
0.0,
|
||||
5.0,
|
||||
2.0,
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
).to(device)
|
||||
with torch.no_grad():
|
||||
result = model.forward(cam=cam_control)[0]
|
||||
result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
|
||||
cv2.imshow("vis", result_im[:, :, ::-1])
|
||||
writer.append_data(result_im)
|
||||
angle += 0.05
|
||||
writer.close()
|
140
docs/examples/pulsar_optimization.py
Executable file
140
docs/examples/pulsar_optimization.py
Executable file
@ -0,0 +1,140 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
"""
|
||||
This example demonstrates scene optimization with the plain
|
||||
pulsar interface. For this, a reference image has been pre-generated
|
||||
(you can find it at `../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png`).
|
||||
The scene is initialized with random spheres. Gradient-based
|
||||
optimization is used to converge towards a faithful
|
||||
scene representation.
|
||||
"""
|
||||
import cv2
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.renderer.points.pulsar import Renderer
|
||||
from torch import nn, optim
|
||||
|
||||
|
||||
n_points = 10_000
|
||||
width = 1_000
|
||||
height = 1_000
|
||||
device = torch.device("cuda")
|
||||
|
||||
|
||||
class SceneModel(nn.Module):
|
||||
"""
|
||||
A simple scene model to demonstrate use of pulsar in PyTorch modules.
|
||||
|
||||
The scene model is parameterized with sphere locations (vert_pos),
|
||||
channel content (vert_col), radiuses (vert_rad), camera position (cam_pos),
|
||||
camera rotation (cam_rot) and sensor focal length and width (cam_sensor).
|
||||
|
||||
The forward method of the model renders this scene description. Any
|
||||
of these parameters could instead be passed as inputs to the forward
|
||||
method and come from a different model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(SceneModel, self).__init__()
|
||||
self.gamma = 1.0
|
||||
# Points.
|
||||
torch.manual_seed(1)
|
||||
vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0
|
||||
vert_pos[:, 2] += 25.0
|
||||
vert_pos[:, :2] -= 5.0
|
||||
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True))
|
||||
self.register_parameter(
|
||||
"vert_col",
|
||||
nn.Parameter(
|
||||
torch.ones(n_points, 3, dtype=torch.float32) * 0.5, requires_grad=True
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"vert_rad",
|
||||
nn.Parameter(
|
||||
torch.ones(n_points, dtype=torch.float32) * 0.3, requires_grad=True
|
||||
),
|
||||
)
|
||||
self.register_buffer(
|
||||
"cam_params",
|
||||
torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32),
|
||||
)
|
||||
# The volumetric optimization works better with a higher number of tracked
|
||||
# intersections per ray.
|
||||
self.renderer = Renderer(width, height, n_points, n_track=32)
|
||||
|
||||
def forward(self):
|
||||
return self.renderer.forward(
|
||||
self.vert_pos,
|
||||
self.vert_col,
|
||||
self.vert_rad,
|
||||
self.cam_params,
|
||||
self.gamma,
|
||||
45.0,
|
||||
return_forward_info=True,
|
||||
)
|
||||
|
||||
|
||||
# Load reference.
|
||||
ref = (
|
||||
torch.from_numpy(
|
||||
imageio.imread(
|
||||
"../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png"
|
||||
)
|
||||
).to(torch.float32)
|
||||
/ 255.0
|
||||
).to(device)
|
||||
# Set up model.
|
||||
model = SceneModel().to(device)
|
||||
# Optimizer.
|
||||
optimizer = optim.SGD(
|
||||
[
|
||||
{"params": [model.vert_col], "lr": 1e0},
|
||||
{"params": [model.vert_rad], "lr": 5e-3},
|
||||
{"params": [model.vert_pos], "lr": 1e-2},
|
||||
]
|
||||
)
|
||||
|
||||
# Optimize.
|
||||
for i in range(500):
|
||||
optimizer.zero_grad()
|
||||
result, result_info = model()
|
||||
# Visualize.
|
||||
result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
|
||||
cv2.imshow("opt", result_im[:, :, ::-1])
|
||||
overlay_img = np.ascontiguousarray(
|
||||
((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[
|
||||
:, :, ::-1
|
||||
]
|
||||
)
|
||||
overlay_img = cv2.putText(
|
||||
overlay_img,
|
||||
"Step %d" % (i),
|
||||
(10, 40),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1,
|
||||
(0, 0, 0),
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
False,
|
||||
)
|
||||
cv2.imshow("overlay", overlay_img)
|
||||
cv2.waitKey(1)
|
||||
# Update.
|
||||
loss = ((result - ref) ** 2).sum()
|
||||
print("loss {}: {}".format(i, loss.item()))
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# Cleanup.
|
||||
with torch.no_grad():
|
||||
model.vert_col.data = torch.clamp(model.vert_col.data, 0.0, 1.0)
|
||||
# Remove points.
|
||||
model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0
|
||||
model.vert_rad.data[model.vert_rad < 0.001] = 0.0001
|
||||
vd = (
|
||||
(model.vert_col - torch.ones(3, dtype=torch.float32).to(device))
|
||||
.abs()
|
||||
.sum(dim=1)
|
||||
)
|
||||
model.vert_pos.data[vd <= 0.2] = -1000.0
|
@ -1,6 +1,11 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
// clang-format off
|
||||
#include "./pulsar/global.h" // Include before <torch/extension.h>.
|
||||
#include <torch/extension.h>
|
||||
// clang-format on
|
||||
#include "./pulsar/pytorch/renderer.h"
|
||||
#include "./pulsar/pytorch/tensor_util.h"
|
||||
#include "blending/sigmoid_alpha_blend.h"
|
||||
#include "compositing/alpha_composite.h"
|
||||
#include "compositing/norm_weighted_sum.h"
|
||||
@ -65,4 +70,90 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("face_point_dist_backward", &FacePointDistanceBackward);
|
||||
m.def("point_face_array_dist_forward", &PointFaceArrayDistanceForward);
|
||||
m.def("point_face_array_dist_backward", &PointFaceArrayDistanceBackward);
|
||||
|
||||
// Pulsar.
|
||||
#ifdef PULSAR_LOGGING_ENABLED
|
||||
c10::ShowLogInfoToStderr();
|
||||
#endif
|
||||
py::class_<
|
||||
pulsar::pytorch::Renderer,
|
||||
std::shared_ptr<pulsar::pytorch::Renderer>>(m, "PulsarRenderer")
|
||||
.def(py::init<
|
||||
const uint&,
|
||||
const uint&,
|
||||
const uint&,
|
||||
const bool&,
|
||||
const bool&,
|
||||
const float&,
|
||||
const uint&,
|
||||
const uint&>())
|
||||
.def(
|
||||
"__eq__",
|
||||
[](const pulsar::pytorch::Renderer& a,
|
||||
const pulsar::pytorch::Renderer& b) { return a == b; },
|
||||
py::is_operator())
|
||||
.def(
|
||||
"__ne__",
|
||||
[](const pulsar::pytorch::Renderer& a,
|
||||
const pulsar::pytorch::Renderer& b) { return !(a == b); },
|
||||
py::is_operator())
|
||||
.def(
|
||||
"__repr__",
|
||||
[](const pulsar::pytorch::Renderer& self) {
|
||||
std::stringstream ss;
|
||||
ss << self;
|
||||
return ss.str();
|
||||
})
|
||||
.def(
|
||||
"forward",
|
||||
&pulsar::pytorch::Renderer::forward,
|
||||
py::arg("vert_pos"),
|
||||
py::arg("vert_col"),
|
||||
py::arg("vert_radii"),
|
||||
|
||||
py::arg("cam_pos"),
|
||||
py::arg("pixel_0_0_center"),
|
||||
py::arg("pixel_vec_x"),
|
||||
py::arg("pixel_vec_y"),
|
||||
py::arg("focal_length"),
|
||||
py::arg("principal_point_offsets"),
|
||||
|
||||
py::arg("gamma"),
|
||||
py::arg("max_depth"),
|
||||
py::arg("min_depth") /* = 0.f*/,
|
||||
py::arg(
|
||||
"bg_col") /* = at::nullopt not exposed properly in pytorch 1.1. */
|
||||
,
|
||||
py::arg("opacity") /* = at::nullopt ... */,
|
||||
py::arg("percent_allowed_difference") = 0.01f,
|
||||
py::arg("max_n_hits") = MAX_UINT,
|
||||
py::arg("mode") = 0)
|
||||
.def("backward", &pulsar::pytorch::Renderer::backward)
|
||||
.def_property(
|
||||
"device_tracker",
|
||||
[](const pulsar::pytorch::Renderer& self) {
|
||||
return self.device_tracker;
|
||||
},
|
||||
[](pulsar::pytorch::Renderer& self, const torch::Tensor& val) {
|
||||
self.device_tracker = val;
|
||||
})
|
||||
.def_property_readonly("width", &pulsar::pytorch::Renderer::width)
|
||||
.def_property_readonly("height", &pulsar::pytorch::Renderer::height)
|
||||
.def_property_readonly(
|
||||
"max_num_balls", &pulsar::pytorch::Renderer::max_num_balls)
|
||||
.def_property_readonly(
|
||||
"orthogonal", &pulsar::pytorch::Renderer::orthogonal)
|
||||
.def_property_readonly(
|
||||
"right_handed", &pulsar::pytorch::Renderer::right_handed)
|
||||
.def_property_readonly("n_track", &pulsar::pytorch::Renderer::n_track);
|
||||
m.def(
|
||||
"pulsar_sphere_ids_from_result_info_nograd",
|
||||
&pulsar::pytorch::sphere_ids_from_result_info_nograd);
|
||||
// Constants.
|
||||
m.attr("EPS") = py::float_(EPS);
|
||||
m.attr("MAX_FLOAT") = py::float_(MAX_FLOAT);
|
||||
m.attr("MAX_INT") = py::int_(MAX_INT);
|
||||
m.attr("MAX_UINT") = py::int_(MAX_UINT);
|
||||
m.attr("MAX_USHORT") = py::int_(MAX_USHORT);
|
||||
m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES);
|
||||
}
|
||||
|
12
pytorch3d/csrc/pulsar/constants.h
Normal file
12
pytorch3d/csrc/pulsar/constants.h
Normal file
@ -0,0 +1,12 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_CONSTANTS_H_
|
||||
#define PULSAR_NATIVE_CONSTANTS_H_
|
||||
|
||||
#define EPS 1E-6
|
||||
#define FEPS 1E-6f
|
||||
#define MAX_FLOAT 3.4E38f
|
||||
#define MAX_INT 2147483647
|
||||
#define MAX_UINT 4294967295u
|
||||
#define MAX_USHORT 65535u
|
||||
|
||||
#endif
|
5
pytorch3d/csrc/pulsar/cuda/README.md
Normal file
5
pytorch3d/csrc/pulsar/cuda/README.md
Normal file
@ -0,0 +1,5 @@
|
||||
# CUDA device compilation units
|
||||
|
||||
This folder contains `.cu` files to create compilation units
|
||||
for device-specific functions. See `../include/README.md` for
|
||||
more information.
|
501
pytorch3d/csrc/pulsar/cuda/commands.h
Normal file
501
pytorch3d/csrc/pulsar/cuda/commands.h
Normal file
@ -0,0 +1,501 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_CUDA_COMMANDS_H_
|
||||
#define PULSAR_NATIVE_CUDA_COMMANDS_H_
|
||||
|
||||
// Definitions for GPU commands.
|
||||
#include <cooperative_groups.h>
|
||||
#include <cub/cub.cuh>
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
#ifdef __DRIVER_TYPES_H__
|
||||
#ifndef DEVICE_RESET
|
||||
#define DEVICE_RESET cudaDeviceReset();
|
||||
#endif
|
||||
#else
|
||||
#ifndef DEVICE_RESET
|
||||
#define DEVICE_RESET
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#define HANDLECUDA(CMD) CMD
|
||||
// handleCudaError((CMD), __FILE__, __LINE__)
|
||||
inline void
|
||||
handleCudaError(const cudaError_t err, const char* file, const int line) {
|
||||
if (err != cudaSuccess) {
|
||||
#ifndef __NVCC__
|
||||
fprintf(
|
||||
stderr,
|
||||
"%s(%i) : getLastCudaError() CUDA error :"
|
||||
" (%d) %s.\n",
|
||||
file,
|
||||
line,
|
||||
static_cast<int>(err),
|
||||
cudaGetErrorString(err));
|
||||
DEVICE_RESET
|
||||
exit(1);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
inline void
|
||||
getLastCudaError(const char* errorMessage, const char* file, const int line) {
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (cudaSuccess != err) {
|
||||
fprintf(stderr, "Error: %s.", errorMessage);
|
||||
handleCudaError(err, file, line);
|
||||
}
|
||||
}
|
||||
|
||||
#define ALIGN(VAL) __align__(VAL)
|
||||
#define SYNC() HANDLECUDE(cudaDeviceSynchronize())
|
||||
#define THREADFENCE_B() __threadfence_block()
|
||||
#define SHFL_SYNC(a, b, c) __shfl_sync((a), (b), (c))
|
||||
#define SHARED __shared__
|
||||
#define ACTIVEMASK() __activemask()
|
||||
#define BALLOT(mask, val) __ballot_sync((mask), val)
|
||||
/**
|
||||
* Find the cumulative sum within a warp up to the current
|
||||
* thread lane, with each mask thread contributing base.
|
||||
*/
|
||||
template <typename T>
|
||||
DEVICE T
|
||||
WARP_CUMSUM(const cg::coalesced_group& group, const uint& mask, const T& base) {
|
||||
T ret = base;
|
||||
T shfl_val;
|
||||
shfl_val = __shfl_down_sync(mask, ret, 1u); // Deactivate the rightmost lane.
|
||||
ret += (group.thread_rank() < 31) * shfl_val;
|
||||
shfl_val = __shfl_down_sync(mask, ret, 2u);
|
||||
ret += (group.thread_rank() < 30) * shfl_val;
|
||||
shfl_val = __shfl_down_sync(mask, ret, 4u); // ...4
|
||||
ret += (group.thread_rank() < 28) * shfl_val;
|
||||
shfl_val = __shfl_down_sync(mask, ret, 8u); // ...8
|
||||
ret += (group.thread_rank() < 24) * shfl_val;
|
||||
shfl_val = __shfl_down_sync(mask, ret, 16u); // ...16
|
||||
ret += (group.thread_rank() < 16) * shfl_val;
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DEVICE T
|
||||
WARP_MAX(const cg::coalesced_group& group, const uint& mask, const T& base) {
|
||||
T ret = base;
|
||||
ret = max(ret, __shfl_down_sync(mask, ret, 16u));
|
||||
ret = max(ret, __shfl_down_sync(mask, ret, 8u));
|
||||
ret = max(ret, __shfl_down_sync(mask, ret, 4u));
|
||||
ret = max(ret, __shfl_down_sync(mask, ret, 2u));
|
||||
ret = max(ret, __shfl_down_sync(mask, ret, 1u));
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DEVICE T
|
||||
WARP_SUM(const cg::coalesced_group& group, const uint& mask, const T& base) {
|
||||
T ret = base;
|
||||
ret = ret + __shfl_down_sync(mask, ret, 16u);
|
||||
ret = ret + __shfl_down_sync(mask, ret, 8u);
|
||||
ret = ret + __shfl_down_sync(mask, ret, 4u);
|
||||
ret = ret + __shfl_down_sync(mask, ret, 2u);
|
||||
ret = ret + __shfl_down_sync(mask, ret, 1u);
|
||||
return ret;
|
||||
}
|
||||
|
||||
INLINE DEVICE float3 WARP_SUM_FLOAT3(
|
||||
const cg::coalesced_group& group,
|
||||
const uint& mask,
|
||||
const float3& base) {
|
||||
float3 ret = base;
|
||||
ret.x = WARP_SUM(group, mask, base.x);
|
||||
ret.y = WARP_SUM(group, mask, base.y);
|
||||
ret.z = WARP_SUM(group, mask, base.z);
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Floating point.
|
||||
// #define FMUL(a, b) __fmul_rn((a), (b))
|
||||
#define FMUL(a, b) ((a) * (b))
|
||||
#define FDIV(a, b) __fdiv_rn((a), (b))
|
||||
// #define FSUB(a, b) __fsub_rn((a), (b))
|
||||
#define FSUB(a, b) ((a) - (b))
|
||||
#define FADD(a, b) __fadd_rn((a), (b))
|
||||
#define FSQRT(a) __fsqrt_rn(a)
|
||||
#define FEXP(a) fasterexp(a)
|
||||
#define FLN(a) fasterlog(a)
|
||||
#define FPOW(a, b) __powf((a), (b))
|
||||
#define FMAX(a, b) fmax((a), (b))
|
||||
#define FMIN(a, b) fmin((a), (b))
|
||||
#define FCEIL(a) ceilf(a)
|
||||
#define FFLOOR(a) floorf(a)
|
||||
#define FROUND(x) nearbyintf(x)
|
||||
#define FSATURATE(x) __saturatef(x)
|
||||
#define FABS(a) abs(a)
|
||||
#define IASF(a, loc) (loc) = __int_as_float(a)
|
||||
#define FASI(a, loc) (loc) = __float_as_int(a)
|
||||
#define FABSLEQAS(a, b, c) \
|
||||
((a) <= (b) ? FSUB((b), (a)) <= (c) : FSUB((a), (b)) < (c))
|
||||
/** Calculates x*y+z. */
|
||||
#define FMA(x, y, z) __fmaf_rn((x), (y), (z))
|
||||
#define I2F(a) __int2float_rn(a)
|
||||
#define FRCP(x) __frcp_rn(x)
|
||||
__device__ static float atomicMax(float* address, float val) {
|
||||
int* address_as_i = (int*)address;
|
||||
int old = *address_as_i, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = ::atomicCAS(
|
||||
address_as_i,
|
||||
assumed,
|
||||
__float_as_int(::fmaxf(val, __int_as_float(assumed))));
|
||||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
}
|
||||
__device__ static float atomicMin(float* address, float val) {
|
||||
int* address_as_i = (int*)address;
|
||||
int old = *address_as_i, assumed;
|
||||
do {
|
||||
assumed = old;
|
||||
old = ::atomicCAS(
|
||||
address_as_i,
|
||||
assumed,
|
||||
__float_as_int(::fminf(val, __int_as_float(assumed))));
|
||||
} while (assumed != old);
|
||||
return __int_as_float(old);
|
||||
}
|
||||
#define DMAX(a, b) FMAX(a, b)
|
||||
#define DMIN(a, b) FMIN(a, b)
|
||||
#define DSQRT(a) sqrt(a)
|
||||
#define DSATURATE(a) DMIN(1., DMAX(0., (a)))
|
||||
// half
|
||||
#define HADD(a, b) __hadd((a), (b))
|
||||
#define HSUB2(a, b) __hsub2((a), (b))
|
||||
#define HMUL2(a, b) __hmul2((a), (b))
|
||||
#define HSQRT(a) hsqrt(a)
|
||||
|
||||
// uint.
|
||||
#define CLZ(VAL) __clz(VAL)
|
||||
#define POPC(a) __popc(a)
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
#define ATOMICADD(PTR, VAL) atomicAdd((PTR), (VAL))
|
||||
#define ATOMICADD_F3(PTR, VAL) \
|
||||
ATOMICADD(&((PTR)->x), VAL.x); \
|
||||
ATOMICADD(&((PTR)->y), VAL.y); \
|
||||
ATOMICADD(&((PTR)->z), VAL.z);
|
||||
#if (CUDART_VERSION >= 10000)
|
||||
#define ATOMICADD_B(PTR, VAL) atomicAdd_block((PTR), (VAL))
|
||||
#else
|
||||
#define ATOMICADD_B(PTR, VAL) ATOMICADD(PTR, VAL)
|
||||
#endif
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
// int.
|
||||
#define IMIN(a, b) min((a), (b))
|
||||
#define IMAX(a, b) max((a), (b))
|
||||
#define IABS(a) abs(a)
|
||||
|
||||
// Checks.
|
||||
#define CHECKOK THCudaCheck
|
||||
#define ARGCHECK THArgCheck
|
||||
|
||||
// Math.
|
||||
#define NORM3DF(x, y, z) norm3df(x, y, z)
|
||||
#define RNORM3DF(x, y, z) rnorm3df(x, y, z)
|
||||
|
||||
// High level.
|
||||
INLINE DEVICE void prefetch_l1(unsigned long addr) {
|
||||
asm(" prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
|
||||
}
|
||||
#define PREFETCH(PTR) prefetch_l1((unsigned long)(PTR))
|
||||
#define GET_SORT_WS_SIZE(RES_PTR, KEY_TYPE, VAL_TYPE, NUM_OBJECTS) \
|
||||
cub::DeviceRadixSort::SortPairsDescending( \
|
||||
(void*)NULL, \
|
||||
*(RES_PTR), \
|
||||
reinterpret_cast<KEY_TYPE*>(NULL), \
|
||||
reinterpret_cast<KEY_TYPE*>(NULL), \
|
||||
reinterpret_cast<VAL_TYPE*>(NULL), \
|
||||
reinterpret_cast<VAL_TYPE*>(NULL), \
|
||||
(NUM_OBJECTS));
|
||||
#define GET_REDUCE_WS_SIZE(RES_PTR, TYPE, REDUCE_OP, NUM_OBJECTS) \
|
||||
{ \
|
||||
TYPE init = TYPE(); \
|
||||
cub::DeviceReduce::Reduce( \
|
||||
(void*)NULL, \
|
||||
*(RES_PTR), \
|
||||
(TYPE*)NULL, \
|
||||
(TYPE*)NULL, \
|
||||
(NUM_OBJECTS), \
|
||||
(REDUCE_OP), \
|
||||
init); \
|
||||
}
|
||||
#define GET_SELECT_WS_SIZE( \
|
||||
RES_PTR, TYPE_SELECTOR, TYPE_SELECTION, NUM_OBJECTS) \
|
||||
{ \
|
||||
cub::DeviceSelect::Flagged( \
|
||||
(void*)NULL, \
|
||||
*(RES_PTR), \
|
||||
(TYPE_SELECTION*)NULL, \
|
||||
(TYPE_SELECTOR*)NULL, \
|
||||
(TYPE_SELECTION*)NULL, \
|
||||
(int*)NULL, \
|
||||
(NUM_OBJECTS)); \
|
||||
}
|
||||
#define GET_SUM_WS_SIZE(RES_PTR, TYPE_SUM, NUM_OBJECTS) \
|
||||
{ \
|
||||
cub::DeviceReduce::Sum( \
|
||||
(void*)NULL, \
|
||||
*(RES_PTR), \
|
||||
(TYPE_SUM*)NULL, \
|
||||
(TYPE_SUM*)NULL, \
|
||||
NUM_OBJECTS); \
|
||||
}
|
||||
#define GET_MM_WS_SIZE(RES_PTR, TYPE, NUM_OBJECTS) \
|
||||
{ \
|
||||
TYPE init = TYPE(); \
|
||||
cub::DeviceReduce::Max( \
|
||||
(void*)NULL, *(RES_PTR), (TYPE*)NULL, (TYPE*)NULL, (NUM_OBJECTS)); \
|
||||
}
|
||||
#define SORT_DESCENDING( \
|
||||
TMPN1, SORT_PTR, SORTED_PTR, VAL_PTR, VAL_SORTED_PTR, NUM_OBJECTS) \
|
||||
void* TMPN1 = NULL; \
|
||||
size_t TMPN1##_bytes = 0; \
|
||||
cub::DeviceRadixSort::SortPairsDescending( \
|
||||
TMPN1, \
|
||||
TMPN1##_bytes, \
|
||||
(SORT_PTR), \
|
||||
(SORTED_PTR), \
|
||||
(VAL_PTR), \
|
||||
(VAL_SORTED_PTR), \
|
||||
(NUM_OBJECTS)); \
|
||||
HANDLECUDA(cudaMalloc(&TMPN1, TMPN1##_bytes)); \
|
||||
cub::DeviceRadixSort::SortPairsDescending( \
|
||||
TMPN1, \
|
||||
TMPN1##_bytes, \
|
||||
(SORT_PTR), \
|
||||
(SORTED_PTR), \
|
||||
(VAL_PTR), \
|
||||
(VAL_SORTED_PTR), \
|
||||
(NUM_OBJECTS)); \
|
||||
HANDLECUDA(cudaFree(TMPN1));
|
||||
#define SORT_DESCENDING_WS( \
|
||||
TMPN1, \
|
||||
SORT_PTR, \
|
||||
SORTED_PTR, \
|
||||
VAL_PTR, \
|
||||
VAL_SORTED_PTR, \
|
||||
NUM_OBJECTS, \
|
||||
WORKSPACE_PTR, \
|
||||
WORKSPACE_BYTES) \
|
||||
cub::DeviceRadixSort::SortPairsDescending( \
|
||||
(WORKSPACE_PTR), \
|
||||
(WORKSPACE_BYTES), \
|
||||
(SORT_PTR), \
|
||||
(SORTED_PTR), \
|
||||
(VAL_PTR), \
|
||||
(VAL_SORTED_PTR), \
|
||||
(NUM_OBJECTS));
|
||||
#define SORT_ASCENDING_WS( \
|
||||
SORT_PTR, \
|
||||
SORTED_PTR, \
|
||||
VAL_PTR, \
|
||||
VAL_SORTED_PTR, \
|
||||
NUM_OBJECTS, \
|
||||
WORKSPACE_PTR, \
|
||||
WORKSPACE_BYTES, \
|
||||
STREAM) \
|
||||
cub::DeviceRadixSort::SortPairs( \
|
||||
(WORKSPACE_PTR), \
|
||||
(WORKSPACE_BYTES), \
|
||||
(SORT_PTR), \
|
||||
(SORTED_PTR), \
|
||||
(VAL_PTR), \
|
||||
(VAL_SORTED_PTR), \
|
||||
(NUM_OBJECTS), \
|
||||
0, \
|
||||
sizeof(*(SORT_PTR)) * 8, \
|
||||
(STREAM));
|
||||
#define SUM_WS( \
|
||||
SUM_PTR, OUT_PTR, NUM_OBJECTS, WORKSPACE_PTR, WORKSPACE_BYTES, STREAM) \
|
||||
cub::DeviceReduce::Sum( \
|
||||
(WORKSPACE_PTR), \
|
||||
(WORKSPACE_BYTES), \
|
||||
(SUM_PTR), \
|
||||
(OUT_PTR), \
|
||||
(NUM_OBJECTS), \
|
||||
(STREAM));
|
||||
#define MIN_WS( \
|
||||
MIN_PTR, OUT_PTR, NUM_OBJECTS, WORKSPACE_PTR, WORKSPACE_BYTES, STREAM) \
|
||||
cub::DeviceReduce::Min( \
|
||||
(WORKSPACE_PTR), \
|
||||
(WORKSPACE_BYTES), \
|
||||
(MIN_PTR), \
|
||||
(OUT_PTR), \
|
||||
(NUM_OBJECTS), \
|
||||
(STREAM));
|
||||
#define MAX_WS( \
|
||||
MAX_PTR, OUT_PTR, NUM_OBJECTS, WORKSPACE_PTR, WORKSPACE_BYTES, STREAM) \
|
||||
cub::DeviceReduce::Min( \
|
||||
(WORKSPACE_PTR), \
|
||||
(WORKSPACE_BYTES), \
|
||||
(MAX_PTR), \
|
||||
(OUT_PTR), \
|
||||
(NUM_OBJECTS), \
|
||||
(STREAM));
|
||||
//
|
||||
//
|
||||
//
|
||||
// TODO: rewrite using nested contexts instead of temporary names.
|
||||
#define REDUCE(REDUCE_PTR, RESULT_PTR, NUM_ITEMS, REDUCE_OP, REDUCE_INIT) \
|
||||
cub::DeviceReduce::Reduce( \
|
||||
TMPN1, \
|
||||
TMPN1##_bytes, \
|
||||
(REDUCE_PTR), \
|
||||
(RESULT_PTR), \
|
||||
(NUM_ITEMS), \
|
||||
(REDUCE_OP), \
|
||||
(REDUCE_INIT)); \
|
||||
HANDLECUDA(cudaMalloc(&TMPN1, TMPN1##_bytes)); \
|
||||
cub::DeviceReduce::Reduce( \
|
||||
TMPN1, \
|
||||
TMPN1##_bytes, \
|
||||
(REDUCE_PTR), \
|
||||
(RESULT_PTR), \
|
||||
(NUM_ITEMS), \
|
||||
(REDUCE_OP), \
|
||||
(REDUCE_INIT)); \
|
||||
HANDLECUDA(cudaFree(TMPN1));
|
||||
#define REDUCE_WS( \
|
||||
REDUCE_PTR, \
|
||||
RESULT_PTR, \
|
||||
NUM_ITEMS, \
|
||||
REDUCE_OP, \
|
||||
REDUCE_INIT, \
|
||||
WORKSPACE_PTR, \
|
||||
WORSPACE_BYTES, \
|
||||
STREAM) \
|
||||
cub::DeviceReduce::Reduce( \
|
||||
(WORKSPACE_PTR), \
|
||||
(WORSPACE_BYTES), \
|
||||
(REDUCE_PTR), \
|
||||
(RESULT_PTR), \
|
||||
(NUM_ITEMS), \
|
||||
(REDUCE_OP), \
|
||||
(REDUCE_INIT), \
|
||||
(STREAM));
|
||||
#define SELECT_FLAGS_WS( \
|
||||
FLAGS_PTR, \
|
||||
ITEM_PTR, \
|
||||
OUT_PTR, \
|
||||
NUM_SELECTED_PTR, \
|
||||
NUM_ITEMS, \
|
||||
WORKSPACE_PTR, \
|
||||
WORSPACE_BYTES, \
|
||||
STREAM) \
|
||||
cub::DeviceSelect::Flagged( \
|
||||
(WORKSPACE_PTR), \
|
||||
(WORSPACE_BYTES), \
|
||||
(ITEM_PTR), \
|
||||
(FLAGS_PTR), \
|
||||
(OUT_PTR), \
|
||||
(NUM_SELECTED_PTR), \
|
||||
(NUM_ITEMS), \
|
||||
stream = (STREAM));
|
||||
|
||||
#define COPY_HOST_DEV(PTR_D, PTR_H, TYPE, SIZE) \
|
||||
HANDLECUDA(cudaMemcpy( \
|
||||
(PTR_D), (PTR_H), sizeof(TYPE) * (SIZE), cudaMemcpyHostToDevice))
|
||||
#define COPY_DEV_HOST(PTR_H, PTR_D, TYPE, SIZE) \
|
||||
HANDLECUDA(cudaMemcpy( \
|
||||
(PTR_H), (PTR_D), sizeof(TYPE) * (SIZE), cudaMemcpyDeviceToHost))
|
||||
#define COPY_DEV_DEV(PTR_T, PTR_S, TYPE, SIZE) \
|
||||
HANDLECUDA(cudaMemcpy( \
|
||||
(PTR_T), (PTR_S), sizeof(TYPE) * (SIZE), cudaMemcpyDeviceToDevice))
|
||||
//
|
||||
// We *must* use cudaMallocManaged for pointers on device that should
|
||||
// interact with pytorch. However, this comes at a significant speed penalty.
|
||||
// We're using plain CUDA pointers for the rendering operations and
|
||||
// explicitly copy results to managed pointers wrapped for pytorch (see
|
||||
// pytorch/util.h).
|
||||
#define MALLOC(VAR, TYPE, SIZE) cudaMalloc(&(VAR), sizeof(TYPE) * (SIZE))
|
||||
#define FREE(PTR) HANDLECUDA(cudaFree(PTR))
|
||||
#define MEMSET(VAR, VAL, TYPE, SIZE, STREAM) \
|
||||
HANDLECUDA(cudaMemsetAsync((VAR), (VAL), sizeof(TYPE) * (SIZE), (STREAM)))
|
||||
|
||||
#define LAUNCH_MAX_PARALLEL_1D(FUNC, N, STREAM, ...) \
|
||||
{ \
|
||||
int64_t max_threads = \
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock; \
|
||||
uint num_threads = min((N), max_threads); \
|
||||
uint num_blocks = iDivCeil((N), num_threads); \
|
||||
FUNC<<<num_blocks, num_threads, 0, (STREAM)>>>(__VA_ARGS__); \
|
||||
}
|
||||
#define LAUNCH_PARALLEL_1D(FUNC, N, TN, STREAM, ...) \
|
||||
{ \
|
||||
uint num_threads = min(static_cast<int>(N), static_cast<int>(TN)); \
|
||||
uint num_blocks = iDivCeil((N), num_threads); \
|
||||
FUNC<<<num_blocks, num_threads, 0, (STREAM)>>>(__VA_ARGS__); \
|
||||
}
|
||||
#define LAUNCH_MAX_PARALLEL_2D(FUNC, NX, NY, STREAM, ...) \
|
||||
{ \
|
||||
int64_t max_threads = \
|
||||
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock; \
|
||||
int64_t max_threads_sqrt = static_cast<int64_t>(sqrt(max_threads)); \
|
||||
dim3 num_threads, num_blocks; \
|
||||
num_threads.x = min((NX), max_threads_sqrt); \
|
||||
num_blocks.x = iDivCeil((NX), num_threads.x); \
|
||||
num_threads.y = min((NY), max_threads_sqrt); \
|
||||
num_blocks.y = iDivCeil((NY), num_threads.y); \
|
||||
num_threads.z = 1; \
|
||||
num_blocks.z = 1; \
|
||||
FUNC<<<num_blocks, num_threads, 0, (STREAM)>>>(__VA_ARGS__); \
|
||||
}
|
||||
#define LAUNCH_PARALLEL_2D(FUNC, NX, NY, TX, TY, STREAM, ...) \
|
||||
{ \
|
||||
dim3 num_threads, num_blocks; \
|
||||
num_threads.x = min((NX), (TX)); \
|
||||
num_blocks.x = iDivCeil((NX), num_threads.x); \
|
||||
num_threads.y = min((NY), (TY)); \
|
||||
num_blocks.y = iDivCeil((NY), num_threads.y); \
|
||||
num_threads.z = 1; \
|
||||
num_blocks.z = 1; \
|
||||
FUNC<<<num_blocks, num_threads, 0, (STREAM)>>>(__VA_ARGS__); \
|
||||
}
|
||||
|
||||
#define GET_PARALLEL_IDX_1D(VARNAME, N) \
|
||||
const uint VARNAME = __mul24(blockIdx.x, blockDim.x) + threadIdx.x; \
|
||||
if (VARNAME >= (N)) { \
|
||||
return; \
|
||||
}
|
||||
#define GET_PARALLEL_IDS_2D(VAR_X, VAR_Y, WIDTH, HEIGHT) \
|
||||
const uint VAR_X = __mul24(blockIdx.x, blockDim.x) + threadIdx.x; \
|
||||
const uint VAR_Y = __mul24(blockIdx.y, blockDim.y) + threadIdx.y; \
|
||||
if (VAR_X >= (WIDTH) || VAR_Y >= (HEIGHT)) \
|
||||
return;
|
||||
#define END_PARALLEL()
|
||||
#define END_PARALLEL_NORET()
|
||||
#define END_PARALLEL_2D_NORET()
|
||||
#define END_PARALLEL_2D()
|
||||
#define RETURN_PARALLEL() return
|
||||
#define CHECKLAUNCH() THCudaCheck(cudaGetLastError());
|
||||
#define ISONDEVICE true
|
||||
#define SYNCDEVICE() HANDLECUDA(cudaDeviceSynchronize())
|
||||
#define START_TIME(TN) \
|
||||
cudaEvent_t __time_start_##TN, __time_stop_##TN; \
|
||||
cudaEventCreate(&__time_start_##TN); \
|
||||
cudaEventCreate(&__time_stop_##TN); \
|
||||
cudaEventRecord(__time_start_##TN);
|
||||
#define STOP_TIME(TN) cudaEventRecord(__time_stop_##TN);
|
||||
#define GET_TIME(TN, TOPTR) \
|
||||
cudaEventSynchronize(__time_stop_##TN); \
|
||||
cudaEventElapsedTime((TOPTR), __time_start_##TN, __time_stop_##TN);
|
||||
#define START_TIME_CU(TN) START_TIME(CN)
|
||||
#define STOP_TIME_CU(TN) STOP_TIME(TN)
|
||||
#define GET_TIME_CU(TN, TOPTR) GET_TIME(TN, TOPTR)
|
||||
|
||||
#endif
|
2
pytorch3d/csrc/pulsar/cuda/renderer.backward.gpu.cu
Normal file
2
pytorch3d/csrc/pulsar/cuda/renderer.backward.gpu.cu
Normal file
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.backward.instantiate.h"
|
2
pytorch3d/csrc/pulsar/cuda/renderer.backward_dbg.gpu.cu
Normal file
2
pytorch3d/csrc/pulsar/cuda/renderer.backward_dbg.gpu.cu
Normal file
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.backward_dbg.instantiate.h"
|
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.calc_gradients.instantiate.h"
|
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.calc_signature.instantiate.h"
|
2
pytorch3d/csrc/pulsar/cuda/renderer.construct.gpu.cu
Normal file
2
pytorch3d/csrc/pulsar/cuda/renderer.construct.gpu.cu
Normal file
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.construct.instantiate.h"
|
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.create_selector.instantiate.h"
|
2
pytorch3d/csrc/pulsar/cuda/renderer.destruct.gpu.cu
Normal file
2
pytorch3d/csrc/pulsar/cuda/renderer.destruct.gpu.cu
Normal file
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.destruct.instantiate.h"
|
2
pytorch3d/csrc/pulsar/cuda/renderer.fill_bg.gpu.cu
Normal file
2
pytorch3d/csrc/pulsar/cuda/renderer.fill_bg.gpu.cu
Normal file
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.fill_bg.instantiate.h"
|
2
pytorch3d/csrc/pulsar/cuda/renderer.forward.gpu.cu
Normal file
2
pytorch3d/csrc/pulsar/cuda/renderer.forward.gpu.cu
Normal file
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.forward.instantiate.h"
|
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.norm_cam_gradients.instantiate.h"
|
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.norm_sphere_gradients.instantiate.h"
|
2
pytorch3d/csrc/pulsar/cuda/renderer.render.gpu.cu
Normal file
2
pytorch3d/csrc/pulsar/cuda/renderer.render.gpu.cu
Normal file
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.render.instantiate.h"
|
85
pytorch3d/csrc/pulsar/global.h
Normal file
85
pytorch3d/csrc/pulsar/global.h
Normal file
@ -0,0 +1,85 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_GLOBAL_H
|
||||
#define PULSAR_GLOBAL_H
|
||||
|
||||
#include "./constants.h"
|
||||
#ifndef WIN32
|
||||
#include <csignal>
|
||||
#endif
|
||||
|
||||
#if defined(_WIN64) || defined(_WIN32)
|
||||
#define uint unsigned int
|
||||
#define ushort unsigned short
|
||||
#endif
|
||||
|
||||
#include "./logging.h" // <- include before torch/extension.h
|
||||
|
||||
#define MAX_GRAD_SPHERES 128
|
||||
|
||||
#ifdef __CUDACC__
|
||||
#define INLINE __forceinline__
|
||||
#define HOST __host__
|
||||
#define DEVICE __device__
|
||||
#define GLOBAL __global__
|
||||
#define RESTRICT __restrict__
|
||||
#define DEBUGBREAK()
|
||||
#pragma diag_suppress = attribute_not_allowed
|
||||
#pragma diag_suppress = 1866
|
||||
#pragma diag_suppress = 2941
|
||||
#pragma diag_suppress = 2951
|
||||
#pragma diag_suppress = 2967
|
||||
#else // __CUDACC__
|
||||
#define INLINE inline
|
||||
#define HOST
|
||||
#define DEVICE
|
||||
#define GLOBAL
|
||||
#define RESTRICT
|
||||
#define DEBUGBREAK() std::raise(SIGINT)
|
||||
// Don't care about pytorch warnings; they shouldn't clutter our warnings.
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Weverything"
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/extension.h>
|
||||
#pragma clang diagnostic pop
|
||||
namespace py = pybind11;
|
||||
inline float3 make_float3(const float& x, const float& y, const float& z) {
|
||||
float3 res;
|
||||
res.x = x;
|
||||
res.y = y;
|
||||
res.z = z;
|
||||
return res;
|
||||
}
|
||||
|
||||
inline bool operator==(const float3& a, const float3& b) {
|
||||
return a.x == b.x && a.y == b.y && a.z == b.z;
|
||||
}
|
||||
#endif // __CUDACC__
|
||||
#define IHD INLINE HOST DEVICE
|
||||
|
||||
// An assertion command that can be used on host and device.
|
||||
#ifdef PULSAR_ASSERTIONS
|
||||
#ifdef __CUDACC__
|
||||
#define PASSERT(VAL) \
|
||||
if (!(VAL)) { \
|
||||
printf( \
|
||||
"Pulsar assertion failed in %s, line %d: %s.\n", \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
#VAL); \
|
||||
}
|
||||
#else
|
||||
#define PASSERT(VAL) \
|
||||
if (!(VAL)) { \
|
||||
printf( \
|
||||
"Pulsar assertion failed in %s, line %d: %s.\n", \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
#VAL); \
|
||||
std::raise(SIGINT); \
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
#define PASSERT(VAL)
|
||||
#endif
|
||||
|
||||
#endif
|
5
pytorch3d/csrc/pulsar/host/README.md
Normal file
5
pytorch3d/csrc/pulsar/host/README.md
Normal file
@ -0,0 +1,5 @@
|
||||
# Device-specific host compilation units
|
||||
|
||||
This folder contains `.cpp` files to create compilation units
|
||||
for device specific functions. See `../include/README.md` for
|
||||
more information.
|
383
pytorch3d/csrc/pulsar/host/commands.h
Normal file
383
pytorch3d/csrc/pulsar/host/commands.h
Normal file
@ -0,0 +1,383 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_COMMANDS_H_
|
||||
#define PULSAR_NATIVE_COMMANDS_H_
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#include <intrin.h>
|
||||
#define __builtin_popcount (int)__popcnt
|
||||
#endif
|
||||
|
||||
// Definitions for CPU commands.
|
||||
// #include <execution>
|
||||
// #include <numeric>
|
||||
|
||||
namespace cg {
|
||||
struct coalesced_group {
|
||||
INLINE uint thread_rank() const {
|
||||
return 0u;
|
||||
}
|
||||
INLINE uint size() const {
|
||||
return 1u;
|
||||
}
|
||||
INLINE uint ballot(uint val) const {
|
||||
return static_cast<uint>(val > 0);
|
||||
}
|
||||
};
|
||||
|
||||
struct thread_block {
|
||||
INLINE uint thread_rank() const {
|
||||
return 0u;
|
||||
}
|
||||
INLINE uint size() const {
|
||||
return 1u;
|
||||
}
|
||||
INLINE void sync() const {}
|
||||
};
|
||||
|
||||
INLINE coalesced_group coalesced_threads() {
|
||||
coalesced_group ret;
|
||||
return ret;
|
||||
}
|
||||
|
||||
INLINE thread_block this_thread_block() {
|
||||
thread_block ret;
|
||||
return ret;
|
||||
}
|
||||
} // namespace cg
|
||||
#define SHFL_SYNC(a, b, c) (b)
|
||||
template <typename T>
|
||||
T WARP_CUMSUM(
|
||||
const cg::coalesced_group& group,
|
||||
const uint& mask,
|
||||
const T& base) {
|
||||
return base;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DEVICE T
|
||||
WARP_MAX(const cg::coalesced_group& group, const uint& mask, const T& base) {
|
||||
return base;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DEVICE T
|
||||
WARP_SUM(const cg::coalesced_group& group, const uint& mask, const T& base) {
|
||||
return base;
|
||||
}
|
||||
|
||||
INLINE DEVICE float3 WARP_SUM_FLOAT3(
|
||||
const cg::coalesced_group& group,
|
||||
const uint& mask,
|
||||
const float3& base) {
|
||||
return base;
|
||||
}
|
||||
|
||||
#define ACTIVEMASK() (1u << 31)
|
||||
#define ALIGN(VAL)
|
||||
#define SYNC()
|
||||
#define THREADFENCE_B()
|
||||
#define BALLOT(mask, val) (val != 0)
|
||||
#define SHARED
|
||||
// Floating point.
|
||||
#define FMAX(a, b) std::fmax((a), (b))
|
||||
#define FMIN(a, b) std::fmin((a), (b))
|
||||
INLINE float atomicMax(float* address, float val) {
|
||||
*address = std::max(*address, val);
|
||||
return *address;
|
||||
}
|
||||
INLINE float atomicMin(float* address, float val) {
|
||||
*address = std::min(*address, val);
|
||||
return *address;
|
||||
}
|
||||
#define FMUL(a, b) ((a) * (b))
|
||||
#define FDIV(a, b) ((a) / (b))
|
||||
#define FSUB(a, b) ((a) - (b))
|
||||
#define FABSLEQAS(a, b, c) \
|
||||
((a) <= (b) ? FSUB((b), (a)) <= (c) : FSUB((a), (b)) < (c))
|
||||
#define FADD(a, b) ((a) + (b))
|
||||
#define FSQRT(a) sqrtf(a)
|
||||
#define FEXP(a) fasterexp(a)
|
||||
#define FLN(a) fasterlog(a)
|
||||
#define FPOW(a, b) powf((a), (b))
|
||||
#define FROUND(x) roundf(x)
|
||||
#define FCEIL(a) ceilf(a)
|
||||
#define FFLOOR(a) floorf(a)
|
||||
#define FSATURATE(x) std::max(0.f, std::min(1.f, x))
|
||||
#define FABS(a) abs(a)
|
||||
#define FMA(x, y, z) ((x) * (y) + (z))
|
||||
#define I2F(a) static_cast<float>(a)
|
||||
#define FRCP(x) (1.f / (x))
|
||||
#define IASF(x, loc) memcpy(&(loc), &(x), sizeof(x))
|
||||
#define FASI(x, loc) memcpy(&(loc), &(x), sizeof(x))
|
||||
#define DMAX(a, b) std::max((a), (b))
|
||||
#define DMIN(a, b) std::min((a), (b))
|
||||
#define DSATURATE(a) DMIN(1., DMAX(0., (a)))
|
||||
#define DSQRT(a) sqrt(a)
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
// uint.
|
||||
#define CLZ(VAL) _clz(VAL)
|
||||
template <typename T>
|
||||
INLINE T ATOMICADD(T* address, T val) {
|
||||
T old = *address;
|
||||
*address += val;
|
||||
return old;
|
||||
}
|
||||
template <typename T>
|
||||
INLINE void ATOMICADD_F3(T* address, T val) {
|
||||
ATOMICADD(&(address->x), val.x);
|
||||
ATOMICADD(&(address->y), val.y);
|
||||
ATOMICADD(&(address->z), val.z);
|
||||
}
|
||||
#define ATOMICADD_B(a, b) ATOMICADD((a), (b))
|
||||
#define POPC(a) __builtin_popcount(a)
|
||||
|
||||
// int.
|
||||
#define IMIN(a, b) std::min((a), (b))
|
||||
#define IMAX(a, b) std::max((a), (b))
|
||||
#define IABS(a) abs(a)
|
||||
|
||||
// Checks.
|
||||
#define CHECKOK THCheck
|
||||
#define ARGCHECK THArgCheck
|
||||
|
||||
// Math.
|
||||
#define NORM3DF(x, y, z) sqrtf(x* x + y * y + z * z)
|
||||
#define RNORM3DF(x, y, z) (1.f / sqrtf(x * x + y * y + z * z))
|
||||
|
||||
// High level.
|
||||
#define PREFETCH(PTR)
|
||||
#define GET_SORT_WS_SIZE(RES_PTR, KEY_TYPE, VAL_TYPE, NUM_OBJECTS) \
|
||||
*(RES_PTR) = 0;
|
||||
#define GET_REDUCE_WS_SIZE(RES_PTR, TYPE, REDUCE_OP, NUM_OBJECTS) \
|
||||
*(RES_PTR) = 0;
|
||||
#define GET_SELECT_WS_SIZE( \
|
||||
RES_PTR, TYPE_SELECTOR, TYPE_SELECTION, NUM_OBJECTS) \
|
||||
*(RES_PTR) = 0;
|
||||
#define GET_SUM_WS_SIZE(RES_PTR, TYPE_SUM, NUM_OBJECTS) *(RES_PTR) = 0;
|
||||
#define GET_MM_WS_SIZE(RES_PTR, TYPE, NUM_OBJECTS) *(RES_PTR) = 0;
|
||||
|
||||
#define SORT_DESCENDING( \
|
||||
TMPN1, SORT_PTR, SORTED_PTR, VAL_PTR, VAL_SORTED_PTR, NUM_OBJECTS) \
|
||||
std::vector<size_t> TMPN1(NUM_OBJECTS); \
|
||||
std::iota(TMPN1.begin(), TMPN1.end(), 0); \
|
||||
const auto TMPN1##_val_ptr = (SORT_PTR); \
|
||||
std::sort( \
|
||||
TMPN1.begin(), TMPN1.end(), [&TMPN1##_val_ptr](size_t i1, size_t i2) { \
|
||||
return TMPN1##_val_ptr[i1] > TMPN1##_val_ptr[i2]; \
|
||||
}); \
|
||||
for (int i = 0; i < (NUM_OBJECTS); ++i) { \
|
||||
(SORTED_PTR)[i] = (SORT_PTR)[TMPN1[i]]; \
|
||||
} \
|
||||
for (int i = 0; i < (NUM_OBJECTS); ++i) { \
|
||||
(VAL_SORTED_PTR)[i] = (VAL_PTR)[TMPN1[i]]; \
|
||||
}
|
||||
|
||||
#define SORT_ASCENDING( \
|
||||
SORT_PTR, SORTED_PTR, VAL_PTR, VAL_SORTED_PTR, NUM_OBJECTS, STREAM) \
|
||||
{ \
|
||||
std::vector<size_t> TMPN1(NUM_OBJECTS); \
|
||||
std::iota(TMPN1.begin(), TMPN1.end(), 0); \
|
||||
const auto TMPN1_val_ptr = (SORT_PTR); \
|
||||
std::sort( \
|
||||
TMPN1.begin(), \
|
||||
TMPN1.end(), \
|
||||
[&TMPN1_val_ptr](size_t i1, size_t i2) -> bool { \
|
||||
return TMPN1_val_ptr[i1] < TMPN1_val_ptr[i2]; \
|
||||
}); \
|
||||
for (int i = 0; i < (NUM_OBJECTS); ++i) { \
|
||||
(SORTED_PTR)[i] = (SORT_PTR)[TMPN1[i]]; \
|
||||
} \
|
||||
for (int i = 0; i < (NUM_OBJECTS); ++i) { \
|
||||
(VAL_SORTED_PTR)[i] = (VAL_PTR)[TMPN1[i]]; \
|
||||
} \
|
||||
}
|
||||
|
||||
#define SORT_DESCENDING_WS( \
|
||||
TMPN1, \
|
||||
SORT_PTR, \
|
||||
SORTED_PTR, \
|
||||
VAL_PTR, \
|
||||
VAL_SORTED_PTR, \
|
||||
NUM_OBJECTS, \
|
||||
WORSPACE_PTR, \
|
||||
WORKSPACE_SIZE) \
|
||||
SORT_DESCENDING( \
|
||||
TMPN1, SORT_PTR, SORTED_PTR, VAL_PTR, VAL_SORTED_PTR, NUM_OBJECTS)
|
||||
|
||||
#define SORT_ASCENDING_WS( \
|
||||
SORT_PTR, \
|
||||
SORTED_PTR, \
|
||||
VAL_PTR, \
|
||||
VAL_SORTED_PTR, \
|
||||
NUM_OBJECTS, \
|
||||
WORSPACE_PTR, \
|
||||
WORKSPACE_SIZE, \
|
||||
STREAM) \
|
||||
SORT_ASCENDING( \
|
||||
SORT_PTR, SORTED_PTR, VAL_PTR, VAL_SORTED_PTR, NUM_OBJECTS, STREAM)
|
||||
|
||||
#define REDUCE(REDUCE_PTR, RESULT_PTR, NUM_ITEMS, REDUCE_OP, REDUCE_INIT) \
|
||||
{ \
|
||||
*(RESULT_PTR) = (REDUCE_INIT); \
|
||||
for (int i = 0; i < (NUM_ITEMS); ++i) { \
|
||||
*(RESULT_PTR) = REDUCE_OP(*(RESULT_PTR), (REDUCE_PTR)[i]); \
|
||||
} \
|
||||
}
|
||||
#define REDUCE_WS( \
|
||||
REDUCE_PTR, \
|
||||
RESULT_PTR, \
|
||||
NUM_ITEMS, \
|
||||
REDUCE_OP, \
|
||||
REDUCE_INIT, \
|
||||
WORKSPACE_PTR, \
|
||||
WORKSPACE_SIZE, \
|
||||
STREAM) \
|
||||
REDUCE(REDUCE_PTR, RESULT_PTR, NUM_ITEMS, REDUCE_OP, REDUCE_INIT)
|
||||
|
||||
#define SELECT_FLAGS_WS( \
|
||||
FLAGS_PTR, \
|
||||
ITEM_PTR, \
|
||||
OUT_PTR, \
|
||||
NUM_SELECTED_PTR, \
|
||||
NUM_ITEMS, \
|
||||
WORKSPACE_PTR, \
|
||||
WORSPACE_BYTES, \
|
||||
STREAM) \
|
||||
{ \
|
||||
*NUM_SELECTED_PTR = 0; \
|
||||
ptrdiff_t write_pos = 0; \
|
||||
for (int i = 0; i < NUM_ITEMS; ++i) { \
|
||||
if (FLAGS_PTR[i]) { \
|
||||
OUT_PTR[write_pos++] = ITEM_PTR[i]; \
|
||||
*NUM_SELECTED_PTR += 1; \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SUM_WS(
|
||||
T* SUM_PTR,
|
||||
T* OUT_PTR,
|
||||
size_t NUM_OBJECTS,
|
||||
char* WORKSPACE_PTR,
|
||||
size_t WORKSPACE_BYTES,
|
||||
cudaStream_t STREAM) {
|
||||
*(OUT_PTR) = T();
|
||||
for (int i = 0; i < (NUM_OBJECTS); ++i) {
|
||||
*(OUT_PTR) = *(OUT_PTR) + (SUM_PTR)[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MIN_WS(
|
||||
T* MIN_PTR,
|
||||
T* OUT_PTR,
|
||||
size_t NUM_OBJECTS,
|
||||
char* WORKSPACE_PTR,
|
||||
size_t WORKSPACE_BYTES,
|
||||
cudaStream_t STREAM) {
|
||||
*(OUT_PTR) = T();
|
||||
for (int i = 0; i < (NUM_OBJECTS); ++i) {
|
||||
*(OUT_PTR) = std::min<T>(*(OUT_PTR), (MIN_PTR)[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void MAX_WS(
|
||||
T* MAX_PTR,
|
||||
T* OUT_PTR,
|
||||
size_t NUM_OBJECTS,
|
||||
char* WORKSPACE_PTR,
|
||||
size_t WORKSPACE_BYTES,
|
||||
cudaStream_t STREAM) {
|
||||
*(OUT_PTR) = T();
|
||||
for (int i = 0; i < (NUM_OBJECTS); ++i) {
|
||||
*(OUT_PTR) = std::max<T>(*(OUT_PTR), (MAX_PTR)[i]);
|
||||
}
|
||||
}
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
#define COPY_HOST_DEV(PTR_D, PTR_H, TYPE, SIZE) \
|
||||
std::memcpy((PTR_D), (PTR_H), sizeof(TYPE) * (SIZE))
|
||||
//
|
||||
#define COPY_DEV_HOST(PTR_H, PTR_D, TYPE, SIZE) \
|
||||
std::memcpy((PTR_H), (PTR_D), sizeof(TYPE) * (SIZE))
|
||||
//
|
||||
#define COPY_DEV_DEV(PTR_T, PTR_S, TYPE, SIZE) \
|
||||
std::memcpy((PTR_T), (PTR_S), sizeof(TYPE) * SIZE)
|
||||
//
|
||||
|
||||
#define MALLOC(VAR, TYPE, SIZE) MALLOC_HOST(VAR, TYPE, SIZE)
|
||||
#define FREE(PTR) FREE_HOST(PTR)
|
||||
#define MEMSET(VAR, VAL, TYPE, SIZE, STREAM) \
|
||||
memset((VAR), (VAL), sizeof(TYPE) * (SIZE))
|
||||
//
|
||||
|
||||
#define LAUNCH_MAX_PARALLEL_1D(FUNC, N, STREAM, ...) FUNC(__VA_ARGS__);
|
||||
#define LAUNCH_PARALLEL_1D(FUNC, N, TN, STREAM, ...) FUNC(__VA_ARGS__);
|
||||
#define LAUNCH_MAX_PARALLEL_2D(FUNC, NX, NY, STREAM, ...) FUNC(__VA_ARGS__);
|
||||
#define LAUNCH_PARALLEL_2D(FUNC, NX, NY, TX, TY, STREAM, ...) FUNC(__VA_ARGS__);
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
//
|
||||
#define GET_PARALLEL_IDX_1D(VARNAME, N) \
|
||||
for (uint VARNAME = 0; VARNAME < (N); ++VARNAME) {
|
||||
#define GET_PARALLEL_IDS_2D(VAR_X, VAR_Y, WIDTH, HEIGHT) \
|
||||
int2 blockDim; \
|
||||
blockDim.x = 1; \
|
||||
blockDim.y = 1; \
|
||||
uint __parallel_2d_width = WIDTH; \
|
||||
uint __parallel_2d_height = HEIGHT; \
|
||||
for (uint VAR_Y = 0; VAR_Y < __parallel_2d_height; ++(VAR_Y)) { \
|
||||
for (uint VAR_X = 0; VAR_X < __parallel_2d_width; ++(VAR_X)) {
|
||||
//
|
||||
//
|
||||
//
|
||||
#define END_PARALLEL() \
|
||||
end_parallel:; \
|
||||
}
|
||||
#define END_PARALLEL_NORET() }
|
||||
#define END_PARALLEL_2D() \
|
||||
end_parallel:; \
|
||||
} \
|
||||
}
|
||||
#define END_PARALLEL_2D_NORET() \
|
||||
} \
|
||||
}
|
||||
#define RETURN_PARALLEL() goto end_parallel;
|
||||
#define CHECKLAUNCH()
|
||||
#define ISONDEVICE false
|
||||
#define SYNCDEVICE()
|
||||
#define START_TIME(TN) \
|
||||
auto __time_start_##TN = std::chrono::steady_clock::now();
|
||||
#define STOP_TIME(TN) auto __time_stop_##TN = std::chrono::steady_clock::now();
|
||||
#define GET_TIME(TN, TOPTR) \
|
||||
*TOPTR = std::chrono::duration_cast<std::chrono::milliseconds>( \
|
||||
__time_stop_##TN - __time_start_##TN) \
|
||||
.count()
|
||||
#define START_TIME_CU(TN) \
|
||||
cudaEvent_t __time_start_##TN, __time_stop_##TN; \
|
||||
cudaEventCreate(&__time_start_##TN); \
|
||||
cudaEventCreate(&__time_stop_##TN); \
|
||||
cudaEventRecord(__time_start_##TN);
|
||||
#define STOP_TIME_CU(TN) cudaEventRecord(__time_stop_##TN);
|
||||
#define GET_TIME_CU(TN, TOPTR) \
|
||||
cudaEventSynchronize(__time_stop_##TN); \
|
||||
cudaEventElapsedTime((TOPTR), __time_start_##TN, __time_stop_##TN);
|
||||
|
||||
#endif
|
2
pytorch3d/csrc/pulsar/host/renderer.backward.cpu.cpp
Normal file
2
pytorch3d/csrc/pulsar/host/renderer.backward.cpu.cpp
Normal file
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.backward.instantiate.h"
|
2
pytorch3d/csrc/pulsar/host/renderer.backward_dbg.cpu.cpp
Normal file
2
pytorch3d/csrc/pulsar/host/renderer.backward_dbg.cpu.cpp
Normal file
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.backward_dbg.instantiate.h"
|
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.calc_gradients.instantiate.h"
|
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.calc_signature.instantiate.h"
|
2
pytorch3d/csrc/pulsar/host/renderer.construct.cpu.cpp
Normal file
2
pytorch3d/csrc/pulsar/host/renderer.construct.cpu.cpp
Normal file
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.construct.instantiate.h"
|
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.create_selector.instantiate.h"
|
2
pytorch3d/csrc/pulsar/host/renderer.destruct.cpu.cpp
Normal file
2
pytorch3d/csrc/pulsar/host/renderer.destruct.cpu.cpp
Normal file
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.destruct.instantiate.h"
|
2
pytorch3d/csrc/pulsar/host/renderer.fill_bg.cpu.cpp
Normal file
2
pytorch3d/csrc/pulsar/host/renderer.fill_bg.cpu.cpp
Normal file
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.fill_bg.instantiate.h"
|
2
pytorch3d/csrc/pulsar/host/renderer.forward.cpu.cpp
Normal file
2
pytorch3d/csrc/pulsar/host/renderer.forward.cpu.cpp
Normal file
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.forward.instantiate.h"
|
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.norm_cam_gradients.instantiate.h"
|
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.norm_sphere_gradients.instantiate.h"
|
2
pytorch3d/csrc/pulsar/host/renderer.render.cpu.cpp
Normal file
2
pytorch3d/csrc/pulsar/host/renderer.render.cpu.cpp
Normal file
@ -0,0 +1,2 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "../include/renderer.render.instantiate.h"
|
16
pytorch3d/csrc/pulsar/include/README.md
Normal file
16
pytorch3d/csrc/pulsar/include/README.md
Normal file
@ -0,0 +1,16 @@
|
||||
# The `include` folder
|
||||
|
||||
This folder contains header files with implementations of several useful
|
||||
algorithms. These implementations are usually done in files called `x.device.h`
|
||||
and use macros that route every device specific command to the right
|
||||
implementation (see `commands.h`).
|
||||
|
||||
If you're using a device specific implementation, include `x.device.h`.
|
||||
This gives you the high-speed, device specific implementation that lets
|
||||
you work with all the details of the datastructure. All function calls are
|
||||
inlined. If you need to work with the high-level interface and be able to
|
||||
dynamically pick a device, only include `x.h`. The functions there are
|
||||
templated with a boolean `DEV` flag and are instantiated in device specific
|
||||
compilation units. You will not be able to use any other functions, but can
|
||||
use `func<true>(params)` to work on a CUDA device, or `func<false>(params)`
|
||||
to work on the host.
|
18
pytorch3d/csrc/pulsar/include/camera.device.h
Normal file
18
pytorch3d/csrc/pulsar/include/camera.device.h
Normal file
@ -0,0 +1,18 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_CAMERA_DEVICE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_CAMERA_DEVICE_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "./camera.h"
|
||||
#include "./commands.h"
|
||||
|
||||
namespace pulsar {
|
||||
IHD CamGradInfo::CamGradInfo() {
|
||||
cam_pos = make_float3(0.f, 0.f, 0.f);
|
||||
pixel_0_0_center = make_float3(0.f, 0.f, 0.f);
|
||||
pixel_dir_x = make_float3(0.f, 0.f, 0.f);
|
||||
pixel_dir_y = make_float3(0.f, 0.f, 0.f);
|
||||
}
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
72
pytorch3d/csrc/pulsar/include/camera.h
Normal file
72
pytorch3d/csrc/pulsar/include/camera.h
Normal file
@ -0,0 +1,72 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_CAMERA_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_CAMERA_H_
|
||||
|
||||
#include "../global.h"
|
||||
|
||||
namespace pulsar {
|
||||
/**
|
||||
* Everything that's needed to raycast with our camera model.
|
||||
*/
|
||||
struct CamInfo {
|
||||
float3 eye; /** Position in world coordinates. */
|
||||
float3 pixel_0_0_center; /** LUC center of pixel position in world
|
||||
coordinates. */
|
||||
float3 pixel_dir_x; /** Direction for increasing x for one pixel to the next,
|
||||
* in world coordinates. */
|
||||
float3 pixel_dir_y; /** Direction for increasing y for one pixel to the next,
|
||||
* in world coordinates. */
|
||||
float3 sensor_dir_z; /** Normalized direction vector from eye through the
|
||||
* sensor in z direction (optical axis). */
|
||||
float half_pixel_size; /** Half size of a pixel, in world coordinates. This
|
||||
* must be consistent with pixel_dir_x and pixel_dir_y!
|
||||
*/
|
||||
float focal_length; /** The focal length, if applicable. */
|
||||
uint aperture_width; /** Full image width in px, possibly not fully used
|
||||
* in case of a shifted principal point. */
|
||||
uint aperture_height; /** Full image height in px, possibly not fully used
|
||||
* in case of a shifted principal point. */
|
||||
uint film_width; /** Resulting image width. */
|
||||
uint film_height; /** Resulting image height. */
|
||||
/** The top left coordinates (inclusive) of the film in the full aperture. */
|
||||
uint film_border_left, film_border_top;
|
||||
int32_t principal_point_offset_x; /** Horizontal principal point offset. */
|
||||
int32_t principal_point_offset_y; /** Vertical principal point offset. */
|
||||
float min_dist; /** Minimum distance for a ball to be rendered. */
|
||||
float max_dist; /** Maximum distance for a ball to be rendered. */
|
||||
float norm_fac; /** 1 / (max_dist - min_dist), pre-computed. */
|
||||
/** The depth where to place the background, in normalized coordinates where
|
||||
* 0. is the backmost depth and 1. the frontmost. */
|
||||
float background_normalization_depth;
|
||||
/** The number of image content channels to use. Usually three. */
|
||||
uint n_channels;
|
||||
/** Whether to use an orthogonal instead of a perspective projection. */
|
||||
bool orthogonal_projection;
|
||||
/** Whether to use a right-handed system (inverts the z axis). */
|
||||
bool right_handed;
|
||||
};
|
||||
|
||||
inline bool operator==(const CamInfo& a, const CamInfo& b) {
|
||||
return a.film_width == b.film_width && a.film_height == b.film_height &&
|
||||
a.background_normalization_depth == b.background_normalization_depth &&
|
||||
a.n_channels == b.n_channels &&
|
||||
a.orthogonal_projection == b.orthogonal_projection &&
|
||||
a.right_handed == b.right_handed;
|
||||
};
|
||||
|
||||
struct CamGradInfo {
|
||||
HOST DEVICE CamGradInfo();
|
||||
float3 cam_pos;
|
||||
float3 pixel_0_0_center;
|
||||
float3 pixel_dir_x;
|
||||
float3 pixel_dir_y;
|
||||
};
|
||||
|
||||
// TODO: remove once https://github.com/NVlabs/cub/issues/172 is resolved.
|
||||
struct IntWrapper {
|
||||
int val;
|
||||
};
|
||||
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
131
pytorch3d/csrc/pulsar/include/closest_sphere_tracker.device.h
Normal file
131
pytorch3d/csrc/pulsar/include/closest_sphere_tracker.device.h
Normal file
@ -0,0 +1,131 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_CLOSEST_SPHERE_TRACKER_DEVICE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_CLOSEST_SPHERE_TRACKER_DEVICE_H_
|
||||
|
||||
#include "../global.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
/**
|
||||
* A facility to track the closest spheres to the camera.
|
||||
*
|
||||
* Their max number is defined by MAX_GRAD_SPHERES (this is defined in
|
||||
* `pulsar/native/global.h`). This is done to keep the performance as high as
|
||||
* possible because this struct needs to do updates continuously on the GPU.
|
||||
*/
|
||||
struct ClosestSphereTracker {
|
||||
public:
|
||||
IHD ClosestSphereTracker(const int& n_track) : n_hits(0), n_track(n_track) {
|
||||
PASSERT(n_track < MAX_GRAD_SPHERES);
|
||||
// Initialize the sphere IDs to -1 and the weights to 0.
|
||||
for (int i = 0; i < n_track; ++i) {
|
||||
this->most_important_sphere_ids[i] = -1;
|
||||
this->closest_sphere_intersection_depths[i] = MAX_FLOAT;
|
||||
}
|
||||
};
|
||||
|
||||
IHD void track(
|
||||
const uint& sphere_idx,
|
||||
const float& intersection_depth,
|
||||
const uint& coord_x,
|
||||
const uint& coord_y) {
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_TRACKER_PIX,
|
||||
"tracker|tracking sphere %u (depth: %f).\n",
|
||||
sphere_idx,
|
||||
intersection_depth);
|
||||
for (int i = IMIN(this->n_hits, n_track) - 1; i >= -1; --i) {
|
||||
if (i < 0 ||
|
||||
this->closest_sphere_intersection_depths[i] < intersection_depth) {
|
||||
// Write position is i+1.
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_TRACKER_PIX,
|
||||
"tracker|determined writing position: %d.\n",
|
||||
i + 1);
|
||||
if (i + 1 < n_track) {
|
||||
// Shift every other sphere back.
|
||||
for (int j = n_track - 1; j > i + 1; --j) {
|
||||
this->closest_sphere_intersection_depths[j] =
|
||||
this->closest_sphere_intersection_depths[j - 1];
|
||||
this->most_important_sphere_ids[j] =
|
||||
this->most_important_sphere_ids[j - 1];
|
||||
}
|
||||
this->closest_sphere_intersection_depths[i + 1] = intersection_depth;
|
||||
this->most_important_sphere_ids[i + 1] = sphere_idx;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
#if PULSAR_LOG_TRACKER_PIX
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_TRACKER_PIX,
|
||||
"tracker|sphere list after adding sphere %u:\n",
|
||||
sphere_idx);
|
||||
for (int i = 0; i < n_track; ++i) {
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_TRACKER_PIX,
|
||||
"tracker|sphere %d: %d (depth: %f).\n",
|
||||
i,
|
||||
this->most_important_sphere_ids[i],
|
||||
this->closest_sphere_intersection_depths[i]);
|
||||
}
|
||||
#endif // PULSAR_LOG_TRACKER_PIX
|
||||
this->n_hits += 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the number of hits registered.
|
||||
*/
|
||||
IHD int get_n_hits() const {
|
||||
return this->n_hits;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the idx closest sphere ID.
|
||||
*
|
||||
* For example, get_closest_sphere_id(0) gives the overall closest
|
||||
* sphere id.
|
||||
*
|
||||
* This method is implemented for highly optimized scenarios and will *not*
|
||||
* perform an index check at runtime if assertions are disabled. idx must be
|
||||
* >=0 and < IMIN(n_hits, n_track) for a valid result, if it is >=
|
||||
* n_hits it will return -1.
|
||||
*/
|
||||
IHD int get_closest_sphere_id(const int& idx) {
|
||||
PASSERT(idx >= 0 && idx < n_track);
|
||||
return this->most_important_sphere_ids[idx];
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the idx closest sphere normalized_depth.
|
||||
*
|
||||
* For example, get_closest_sphere_depth(0) gives the overall closest
|
||||
* sphere depth (normalized).
|
||||
*
|
||||
* This method is implemented for highly optimized scenarios and will *not*
|
||||
* perform an index check at runtime if assertions are disabled. idx must be
|
||||
* >=0 and < IMIN(n_hits, n_track) for a valid result, if it is >=
|
||||
* n_hits it will return 1. + FEPS.
|
||||
*/
|
||||
IHD float get_closest_sphere_depth(const int& idx) {
|
||||
PASSERT(idx >= 0 && idx < n_track);
|
||||
return this->closest_sphere_intersection_depths[idx];
|
||||
}
|
||||
|
||||
private:
|
||||
/** The number of registered hits so far. */
|
||||
int n_hits;
|
||||
/** The number of intersections to track. Must be <MAX_GRAD_SPHERES. */
|
||||
int n_track;
|
||||
/** The sphere ids of the n_track spheres with the highest color
|
||||
* contribution. */
|
||||
int most_important_sphere_ids[MAX_GRAD_SPHERES];
|
||||
/** The normalized depths of the closest n_track spheres. */
|
||||
float closest_sphere_intersection_depths[MAX_GRAD_SPHERES];
|
||||
};
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
30
pytorch3d/csrc/pulsar/include/commands.h
Normal file
30
pytorch3d/csrc/pulsar/include/commands.h
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_COMMANDS_ROUTING_H_
|
||||
#define PULSAR_NATIVE_COMMANDS_ROUTING_H_
|
||||
|
||||
#include "../global.h"
|
||||
|
||||
// Commands available everywhere.
|
||||
#define MALLOC_HOST(VAR, TYPE, SIZE) \
|
||||
VAR = static_cast<TYPE*>(malloc(sizeof(TYPE) * (SIZE)))
|
||||
#define FREE_HOST(PTR) free(PTR)
|
||||
|
||||
/* Include command definitions depending on CPU or GPU use. */
|
||||
|
||||
#ifdef __CUDACC__
|
||||
// TODO: find out which compiler we're using here and use the suppression.
|
||||
// #pragma push
|
||||
// #pragma diag_suppress = 68
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <THC/THC.h>
|
||||
// #pragma pop
|
||||
#include "../cuda/commands.h"
|
||||
#else
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Weverything"
|
||||
#include <TH/TH.h>
|
||||
#pragma clang diagnostic pop
|
||||
#include "../host/commands.h"
|
||||
#endif
|
||||
|
||||
#endif
|
87
pytorch3d/csrc/pulsar/include/fastermath.h
Normal file
87
pytorch3d/csrc/pulsar/include/fastermath.h
Normal file
@ -0,0 +1,87 @@
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_FASTERMATH_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_FASTERMATH_H_
|
||||
|
||||
/*=====================================================================*
|
||||
* Copyright (C) 2011 Paul Mineiro *
|
||||
* All rights reserved. *
|
||||
* *
|
||||
* Redistribution and use in source and binary forms, with *
|
||||
* or without modification, are permitted provided that the *
|
||||
* following conditions are met: *
|
||||
* *
|
||||
* * Redistributions of source code must retain the *
|
||||
* above copyright notice, this list of conditions and *
|
||||
* the following disclaimer. *
|
||||
* *
|
||||
* * Redistributions in binary form must reproduce the *
|
||||
* above copyright notice, this list of conditions and *
|
||||
* the following disclaimer in the documentation and/or *
|
||||
* other materials provided with the distribution. *
|
||||
* *
|
||||
* * Neither the name of Paul Mineiro nor the names *
|
||||
* of other contributors may be used to endorse or promote *
|
||||
* products derived from this software without specific *
|
||||
* prior written permission. *
|
||||
* *
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND *
|
||||
* CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, *
|
||||
* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES *
|
||||
* OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE *
|
||||
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER *
|
||||
* OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, *
|
||||
* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES *
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE *
|
||||
* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR *
|
||||
* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF *
|
||||
* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT *
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY *
|
||||
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE *
|
||||
* POSSIBILITY OF SUCH DAMAGE. *
|
||||
* *
|
||||
* Contact: Paul Mineiro <paul@mineiro.com> *
|
||||
*=====================================================================*/
|
||||
|
||||
#include <stdint.h>
|
||||
#include "./commands.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
#define cast_uint32_t static_cast<uint32_t>
|
||||
#else
|
||||
#define cast_uint32_t (uint32_t)
|
||||
#endif
|
||||
|
||||
IHD float fasterlog2(float x) {
|
||||
union {
|
||||
float f;
|
||||
uint32_t i;
|
||||
} vx = {x};
|
||||
float y = vx.i;
|
||||
y *= 1.1920928955078125e-7f;
|
||||
return y - 126.94269504f;
|
||||
}
|
||||
|
||||
IHD float fasterlog(float x) {
|
||||
// return 0.69314718f * fasterlog2 (x);
|
||||
union {
|
||||
float f;
|
||||
uint32_t i;
|
||||
} vx = {x};
|
||||
float y = vx.i;
|
||||
y *= 8.2629582881927490e-8f;
|
||||
return y - 87.989971088f;
|
||||
}
|
||||
|
||||
IHD float fasterpow2(float p) {
|
||||
float clipp = (p < -126) ? -126.0f : p;
|
||||
union {
|
||||
uint32_t i;
|
||||
float f;
|
||||
} v = {cast_uint32_t((1 << 23) * (clipp + 126.94269504f))};
|
||||
return v.f;
|
||||
}
|
||||
|
||||
IHD float fasterexp(float p) {
|
||||
return fasterpow2(1.442695040f * p);
|
||||
}
|
||||
|
||||
#endif
|
150
pytorch3d/csrc/pulsar/include/math.h
Normal file
150
pytorch3d/csrc/pulsar/include/math.h
Normal file
@ -0,0 +1,150 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_IMPL_MATH_H_
|
||||
#define PULSAR_NATIVE_IMPL_MATH_H_
|
||||
|
||||
#include "./camera.h"
|
||||
#include "./commands.h"
|
||||
#include "./fastermath.h"
|
||||
|
||||
/**
|
||||
* Get the direction of val.
|
||||
*
|
||||
* Returns +1 if val is positive, -1 if val is zero or negative.
|
||||
*/
|
||||
IHD int sign_dir(const int& val) {
|
||||
return -(static_cast<int>((val <= 0)) << 1) + 1;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the direction of val.
|
||||
*
|
||||
* Returns +1 if val is positive, -1 if val is zero or negative.
|
||||
*/
|
||||
IHD float sign_dir(const float& val) {
|
||||
return static_cast<float>(1 - (static_cast<int>((val <= 0)) << 1));
|
||||
};
|
||||
|
||||
/**
|
||||
* Integer ceil division.
|
||||
*/
|
||||
IHD uint iDivCeil(uint a, uint b) {
|
||||
return (a % b != 0) ? (a / b + 1) : (a / b);
|
||||
}
|
||||
|
||||
IHD float3 outer_product_sum(const float3& a) {
|
||||
return make_float3(
|
||||
a.x * a.x + a.x * a.y + a.x * a.z,
|
||||
a.x * a.y + a.y * a.y + a.y * a.z,
|
||||
a.x * a.z + a.y * a.z + a.z * a.z);
|
||||
}
|
||||
|
||||
// TODO: put intrinsics here.
|
||||
IHD float3 operator+(const float3& a, const float3& b) {
|
||||
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
||||
}
|
||||
|
||||
IHD void operator+=(float3& a, const float3& b) {
|
||||
a.x += b.x;
|
||||
a.y += b.y;
|
||||
a.z += b.z;
|
||||
}
|
||||
|
||||
IHD void operator-=(float3& a, const float3& b) {
|
||||
a.x -= b.x;
|
||||
a.y -= b.y;
|
||||
a.z -= b.z;
|
||||
}
|
||||
|
||||
IHD void operator/=(float3& a, const float& b) {
|
||||
a.x /= b;
|
||||
a.y /= b;
|
||||
a.z /= b;
|
||||
}
|
||||
|
||||
IHD void operator*=(float3& a, const float& b) {
|
||||
a.x *= b;
|
||||
a.y *= b;
|
||||
a.z *= b;
|
||||
}
|
||||
|
||||
IHD float3 operator/(const float3& a, const float& b) {
|
||||
return make_float3(a.x / b, a.y / b, a.z / b);
|
||||
}
|
||||
|
||||
IHD float3 operator-(const float3& a, const float3& b) {
|
||||
return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
|
||||
}
|
||||
|
||||
IHD float3 operator*(const float3& a, const float& b) {
|
||||
return make_float3(a.x * b, a.y * b, a.z * b);
|
||||
}
|
||||
|
||||
IHD float3 operator*(const float3& a, const float3& b) {
|
||||
return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
|
||||
}
|
||||
|
||||
IHD float3 operator*(const float& a, const float3& b) {
|
||||
return b * a;
|
||||
}
|
||||
|
||||
IHD float length(const float3& v) {
|
||||
// TODO: benchmark what's faster.
|
||||
return NORM3DF(v.x, v.y, v.z);
|
||||
// return __fsqrt_rn(v.x * v.x + v.y * v.y + v.z * v.z);
|
||||
}
|
||||
|
||||
/**
|
||||
* Left-hand multiplication of the constructed rotation matrix with the vector.
|
||||
*/
|
||||
IHD float3 rotate(
|
||||
const float3& v,
|
||||
const float3& dir_x,
|
||||
const float3& dir_y,
|
||||
const float3& dir_z) {
|
||||
return make_float3(
|
||||
dir_x.x * v.x + dir_x.y * v.y + dir_x.z * v.z,
|
||||
dir_y.x * v.x + dir_y.y * v.y + dir_y.z * v.z,
|
||||
dir_z.x * v.x + dir_z.y * v.y + dir_z.z * v.z);
|
||||
}
|
||||
|
||||
IHD float3 normalize(const float3& v) {
|
||||
return v * RNORM3DF(v.x, v.y, v.z);
|
||||
}
|
||||
|
||||
INLINE DEVICE float dot(const float3& a, const float3& b) {
|
||||
return FADD(FADD(FMUL(a.x, b.x), FMUL(a.y, b.y)), FMUL(a.z, b.z));
|
||||
}
|
||||
|
||||
INLINE DEVICE float3 cross(const float3& a, const float3& b) {
|
||||
// TODO: faster
|
||||
return make_float3(
|
||||
a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x);
|
||||
}
|
||||
|
||||
namespace pulsar {
|
||||
IHD CamGradInfo operator+(const CamGradInfo& a, const CamGradInfo& b) {
|
||||
CamGradInfo res;
|
||||
res.cam_pos = a.cam_pos + b.cam_pos;
|
||||
res.pixel_0_0_center = a.pixel_0_0_center + b.pixel_0_0_center;
|
||||
res.pixel_dir_x = a.pixel_dir_x + b.pixel_dir_x;
|
||||
res.pixel_dir_y = a.pixel_dir_y + b.pixel_dir_y;
|
||||
return res;
|
||||
}
|
||||
|
||||
IHD CamGradInfo operator*(const CamGradInfo& a, const float& b) {
|
||||
CamGradInfo res;
|
||||
res.cam_pos = a.cam_pos * b;
|
||||
res.pixel_0_0_center = a.pixel_0_0_center * b;
|
||||
res.pixel_dir_x = a.pixel_dir_x * b;
|
||||
res.pixel_dir_y = a.pixel_dir_y * b;
|
||||
return res;
|
||||
}
|
||||
|
||||
IHD IntWrapper operator+(const IntWrapper& a, const IntWrapper& b) {
|
||||
IntWrapper res;
|
||||
res.val = a.val + b.val;
|
||||
return res;
|
||||
}
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
182
pytorch3d/csrc/pulsar/include/renderer.backward.device.h
Normal file
182
pytorch3d/csrc/pulsar/include/renderer.backward.device.h
Normal file
@ -0,0 +1,182 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_RENDERER_BACKWARD_DEVICE_H_
|
||||
#define PULSAR_NATIVE_RENDERER_BACKWARD_DEVICE_H_
|
||||
|
||||
#include "./camera.device.h"
|
||||
#include "./math.h"
|
||||
#include "./renderer.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template <bool DEV>
|
||||
void backward(
|
||||
Renderer* self,
|
||||
const float* grad_im,
|
||||
const float* image,
|
||||
const float* forw_info,
|
||||
const float* vert_pos,
|
||||
const float* vert_col,
|
||||
const float* vert_rad,
|
||||
const CamInfo& cam,
|
||||
const float& gamma,
|
||||
float percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const float* vert_opy_d,
|
||||
const size_t& num_balls,
|
||||
const uint& mode,
|
||||
const bool& dif_pos,
|
||||
const bool& dif_col,
|
||||
const bool& dif_rad,
|
||||
const bool& dif_cam,
|
||||
const bool& dif_opy,
|
||||
cudaStream_t stream) {
|
||||
ARGCHECK(gamma > 0.f && gamma <= 1.f, 6, "gamma must be in [0., 1.]");
|
||||
ARGCHECK(
|
||||
percent_allowed_difference >= 0.f && percent_allowed_difference <= 1.f,
|
||||
7,
|
||||
"percent_allowed_difference must be in [0., 1.]");
|
||||
ARGCHECK(max_n_hits >= 1u, 8, "max_n_hits must be >= 1");
|
||||
ARGCHECK(
|
||||
num_balls > 0 && num_balls <= self->max_num_balls,
|
||||
9,
|
||||
"num_balls must be >0 and less than max num balls!");
|
||||
ARGCHECK(
|
||||
cam.film_width == self->cam.film_width &&
|
||||
cam.film_height == self->cam.film_height,
|
||||
5,
|
||||
"cam film size must agree");
|
||||
ARGCHECK(mode <= 1, 10, "mode must be <= 1!");
|
||||
if (percent_allowed_difference < EPS) {
|
||||
LOG(WARNING) << "percent_allowed_difference < " << FEPS << "! Clamping to "
|
||||
<< FEPS << ".";
|
||||
percent_allowed_difference = FEPS;
|
||||
}
|
||||
if (percent_allowed_difference > 1.f - FEPS) {
|
||||
LOG(WARNING) << "percent_allowed_difference > " << (1.f - FEPS)
|
||||
<< "! Clamping to " << (1.f - FEPS) << ".";
|
||||
percent_allowed_difference = 1.f - FEPS;
|
||||
}
|
||||
LOG_IF(INFO, PULSAR_LOG_RENDER) << "Rendering backward pass...";
|
||||
// Update camera.
|
||||
self->cam.eye = cam.eye;
|
||||
self->cam.pixel_0_0_center = cam.pixel_0_0_center - cam.eye;
|
||||
self->cam.pixel_dir_x = cam.pixel_dir_x;
|
||||
self->cam.pixel_dir_y = cam.pixel_dir_y;
|
||||
self->cam.sensor_dir_z = cam.sensor_dir_z;
|
||||
self->cam.half_pixel_size = cam.half_pixel_size;
|
||||
self->cam.focal_length = cam.focal_length;
|
||||
self->cam.aperture_width = cam.aperture_width;
|
||||
self->cam.aperture_height = cam.aperture_height;
|
||||
self->cam.min_dist = cam.min_dist;
|
||||
self->cam.max_dist = cam.max_dist;
|
||||
self->cam.norm_fac = cam.norm_fac;
|
||||
self->cam.principal_point_offset_x = cam.principal_point_offset_x;
|
||||
self->cam.principal_point_offset_y = cam.principal_point_offset_y;
|
||||
self->cam.film_border_left = cam.film_border_left;
|
||||
self->cam.film_border_top = cam.film_border_top;
|
||||
#ifdef PULSAR_TIMINGS_ENABLED
|
||||
START_TIME(calc_signature);
|
||||
#endif
|
||||
LAUNCH_MAX_PARALLEL_1D(
|
||||
calc_signature<DEV>,
|
||||
num_balls,
|
||||
stream,
|
||||
*self,
|
||||
reinterpret_cast<const float3*>(vert_pos),
|
||||
vert_col,
|
||||
vert_rad,
|
||||
num_balls);
|
||||
CHECKLAUNCH();
|
||||
#ifdef PULSAR_TIMINGS_ENABLED
|
||||
STOP_TIME(calc_signature);
|
||||
START_TIME(calc_gradients);
|
||||
#endif
|
||||
MEMSET(self->grad_pos_d, 0, float3, num_balls, stream);
|
||||
MEMSET(self->grad_col_d, 0, float, num_balls * self->cam.n_channels, stream);
|
||||
MEMSET(self->grad_rad_d, 0, float, num_balls, stream);
|
||||
MEMSET(self->grad_cam_d, 0, float, 12, stream);
|
||||
MEMSET(self->grad_cam_buf_d, 0, CamGradInfo, num_balls, stream);
|
||||
MEMSET(self->grad_opy_d, 0, float, num_balls, stream);
|
||||
MEMSET(self->ids_sorted_d, 0, int, num_balls, stream);
|
||||
LAUNCH_PARALLEL_2D(
|
||||
calc_gradients<DEV>,
|
||||
self->cam.film_width,
|
||||
self->cam.film_height,
|
||||
GRAD_BLOCK_SIZE,
|
||||
GRAD_BLOCK_SIZE,
|
||||
stream,
|
||||
self->cam,
|
||||
grad_im,
|
||||
gamma,
|
||||
reinterpret_cast<const float3*>(vert_pos),
|
||||
vert_col,
|
||||
vert_rad,
|
||||
vert_opy_d,
|
||||
num_balls,
|
||||
image,
|
||||
forw_info,
|
||||
self->di_d,
|
||||
self->ii_d,
|
||||
dif_pos,
|
||||
dif_col,
|
||||
dif_rad,
|
||||
dif_cam,
|
||||
dif_opy,
|
||||
self->grad_rad_d,
|
||||
self->grad_col_d,
|
||||
self->grad_pos_d,
|
||||
self->grad_cam_buf_d,
|
||||
self->grad_opy_d,
|
||||
self->ids_sorted_d,
|
||||
self->n_track);
|
||||
CHECKLAUNCH();
|
||||
#ifdef PULSAR_TIMINGS_ENABLED
|
||||
STOP_TIME(calc_gradients);
|
||||
START_TIME(normalize);
|
||||
#endif
|
||||
LAUNCH_MAX_PARALLEL_1D(
|
||||
norm_sphere_gradients<DEV>, num_balls, stream, *self, num_balls);
|
||||
CHECKLAUNCH();
|
||||
if (dif_cam) {
|
||||
SUM_WS(
|
||||
self->grad_cam_buf_d,
|
||||
reinterpret_cast<CamGradInfo*>(self->grad_cam_d),
|
||||
static_cast<int>(num_balls),
|
||||
self->workspace_d,
|
||||
self->workspace_size,
|
||||
stream);
|
||||
CHECKLAUNCH();
|
||||
SUM_WS(
|
||||
(IntWrapper*)(self->ids_sorted_d),
|
||||
(IntWrapper*)(self->n_grad_contributions_d),
|
||||
static_cast<int>(num_balls),
|
||||
self->workspace_d,
|
||||
self->workspace_size,
|
||||
stream);
|
||||
CHECKLAUNCH();
|
||||
LAUNCH_MAX_PARALLEL_1D(
|
||||
norm_cam_gradients<DEV>, static_cast<int64_t>(1), stream, *self);
|
||||
CHECKLAUNCH();
|
||||
}
|
||||
#ifdef PULSAR_TIMINGS_ENABLED
|
||||
STOP_TIME(normalize);
|
||||
float time_ms;
|
||||
// This blocks the result and prevents batch-processing from parallelizing.
|
||||
GET_TIME(calc_signature, &time_ms);
|
||||
std::cout << "Time for signature calculation: " << time_ms << " ms"
|
||||
<< std::endl;
|
||||
GET_TIME(calc_gradients, &time_ms);
|
||||
std::cout << "Time for gradient calculation: " << time_ms << " ms"
|
||||
<< std::endl;
|
||||
GET_TIME(normalize, &time_ms);
|
||||
std::cout << "Time for aggregation and normalization: " << time_ms << " ms"
|
||||
<< std::endl;
|
||||
#endif
|
||||
LOG_IF(INFO, PULSAR_LOG_RENDER) << "Backward pass complete.";
|
||||
}
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
@ -0,0 +1,30 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "./renderer.backward.device.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template void backward<ISONDEVICE>(
|
||||
Renderer* self,
|
||||
const float* grad_im,
|
||||
const float* image,
|
||||
const float* forw_info,
|
||||
const float* vert_pos,
|
||||
const float* vert_col,
|
||||
const float* vert_rad,
|
||||
const CamInfo& cam,
|
||||
const float& gamma,
|
||||
float percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const float* vert_opy,
|
||||
const size_t& num_balls,
|
||||
const uint& mode,
|
||||
const bool& dif_pos,
|
||||
const bool& dif_col,
|
||||
const bool& dif_rad,
|
||||
const bool& dif_cam,
|
||||
const bool& dif_opy,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
150
pytorch3d/csrc/pulsar/include/renderer.backward_dbg.device.h
Normal file
150
pytorch3d/csrc/pulsar/include/renderer.backward_dbg.device.h
Normal file
@ -0,0 +1,150 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_RENDERER_BACKWARD_DBG_DEVICE_H_
|
||||
#define PULSAR_NATIVE_RENDERER_BACKWARD_DBG_DEVICE_H_
|
||||
|
||||
#include "./camera.device.h"
|
||||
#include "./math.h"
|
||||
#include "./renderer.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template <bool DEV>
|
||||
void backward_dbg(
|
||||
Renderer* self,
|
||||
const float* grad_im,
|
||||
const float* image,
|
||||
const float* forw_info,
|
||||
const float* vert_pos,
|
||||
const float* vert_col,
|
||||
const float* vert_rad,
|
||||
const CamInfo& cam,
|
||||
const float& gamma,
|
||||
float percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const float* vert_opy_d,
|
||||
const size_t& num_balls,
|
||||
const uint& mode,
|
||||
const bool& dif_pos,
|
||||
const bool& dif_col,
|
||||
const bool& dif_rad,
|
||||
const bool& dif_cam,
|
||||
const bool& dif_opy,
|
||||
const uint& pos_x,
|
||||
const uint& pos_y,
|
||||
cudaStream_t stream) {
|
||||
ARGCHECK(gamma > 0.f && gamma <= 1.f, 6, "gamma must be in [0., 1.]");
|
||||
ARGCHECK(
|
||||
percent_allowed_difference >= 0.f && percent_allowed_difference <= 1.f,
|
||||
7,
|
||||
"percent_allowed_difference must be in [0., 1.]");
|
||||
ARGCHECK(max_n_hits >= 1u, 8, "max_n_hits must be >= 1");
|
||||
ARGCHECK(
|
||||
num_balls > 0 && num_balls <= self->max_num_balls,
|
||||
9,
|
||||
"num_balls must be >0 and less than max num balls!");
|
||||
ARGCHECK(
|
||||
cam.film_width == self->cam.film_width &&
|
||||
cam.film_height == self->cam.film_height,
|
||||
5,
|
||||
"cam film size must agree");
|
||||
ARGCHECK(mode <= 1, 10, "mode must be <= 1!");
|
||||
if (percent_allowed_difference < EPS) {
|
||||
LOG(WARNING) << "percent_allowed_difference < " << FEPS << "! Clamping to "
|
||||
<< FEPS << ".";
|
||||
percent_allowed_difference = FEPS;
|
||||
}
|
||||
ARGCHECK(
|
||||
pos_x < cam.film_width && pos_y < cam.film_height,
|
||||
15,
|
||||
"pos_x must be < width and pos_y < height.");
|
||||
if (percent_allowed_difference > 1.f - FEPS) {
|
||||
LOG(WARNING) << "percent_allowed_difference > " << (1.f - FEPS)
|
||||
<< "! Clamping to " << (1.f - FEPS) << ".";
|
||||
percent_allowed_difference = 1.f - FEPS;
|
||||
}
|
||||
LOG_IF(INFO, PULSAR_LOG_RENDER)
|
||||
<< "Rendering debug backward pass for x: " << pos_x << ", y: " << pos_y;
|
||||
// Update camera.
|
||||
self->cam.eye = cam.eye;
|
||||
self->cam.pixel_0_0_center = cam.pixel_0_0_center - cam.eye;
|
||||
self->cam.pixel_dir_x = cam.pixel_dir_x;
|
||||
self->cam.pixel_dir_y = cam.pixel_dir_y;
|
||||
self->cam.sensor_dir_z = cam.sensor_dir_z;
|
||||
self->cam.half_pixel_size = cam.half_pixel_size;
|
||||
self->cam.focal_length = cam.focal_length;
|
||||
self->cam.aperture_width = cam.aperture_width;
|
||||
self->cam.aperture_height = cam.aperture_height;
|
||||
self->cam.min_dist = cam.min_dist;
|
||||
self->cam.max_dist = cam.max_dist;
|
||||
self->cam.norm_fac = cam.norm_fac;
|
||||
self->cam.principal_point_offset_x = cam.principal_point_offset_x;
|
||||
self->cam.principal_point_offset_y = cam.principal_point_offset_y;
|
||||
self->cam.film_border_left = cam.film_border_left;
|
||||
self->cam.film_border_top = cam.film_border_top;
|
||||
LAUNCH_MAX_PARALLEL_1D(
|
||||
calc_signature<DEV>,
|
||||
num_balls,
|
||||
stream,
|
||||
*self,
|
||||
reinterpret_cast<const float3*>(vert_pos),
|
||||
vert_col,
|
||||
vert_rad,
|
||||
num_balls);
|
||||
CHECKLAUNCH();
|
||||
MEMSET(self->grad_pos_d, 0, float3, num_balls, stream);
|
||||
MEMSET(self->grad_col_d, 0, float, num_balls * self->cam.n_channels, stream);
|
||||
MEMSET(self->grad_rad_d, 0, float, num_balls, stream);
|
||||
MEMSET(self->grad_cam_d, 0, float, 12, stream);
|
||||
MEMSET(self->grad_cam_buf_d, 0, CamGradInfo, num_balls, stream);
|
||||
MEMSET(self->grad_opy_d, 0, float, num_balls, stream);
|
||||
MEMSET(self->ids_sorted_d, 0, int, num_balls, stream);
|
||||
LAUNCH_MAX_PARALLEL_2D(
|
||||
calc_gradients<DEV>,
|
||||
(int64_t)1,
|
||||
(int64_t)1,
|
||||
stream,
|
||||
self->cam,
|
||||
grad_im,
|
||||
gamma,
|
||||
reinterpret_cast<const float3*>(vert_pos),
|
||||
vert_col,
|
||||
vert_rad,
|
||||
vert_opy_d,
|
||||
num_balls,
|
||||
image,
|
||||
forw_info,
|
||||
self->di_d,
|
||||
self->ii_d,
|
||||
dif_pos,
|
||||
dif_col,
|
||||
dif_rad,
|
||||
dif_cam,
|
||||
dif_opy,
|
||||
self->grad_rad_d,
|
||||
self->grad_col_d,
|
||||
self->grad_pos_d,
|
||||
self->grad_cam_buf_d,
|
||||
self->grad_opy_d,
|
||||
self->ids_sorted_d,
|
||||
self->n_track,
|
||||
pos_x,
|
||||
pos_y);
|
||||
CHECKLAUNCH();
|
||||
// We're not doing sphere gradient normalization here.
|
||||
SUM_WS(
|
||||
self->grad_cam_buf_d,
|
||||
reinterpret_cast<CamGradInfo*>(self->grad_cam_d),
|
||||
static_cast<int>(1),
|
||||
self->workspace_d,
|
||||
self->workspace_size,
|
||||
stream);
|
||||
CHECKLAUNCH();
|
||||
// We're not doing camera gradient normalization here.
|
||||
LOG_IF(INFO, PULSAR_LOG_RENDER) << "Debug backward pass complete.";
|
||||
}
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
@ -0,0 +1,32 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "./renderer.backward_dbg.device.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template void backward_dbg<ISONDEVICE>(
|
||||
Renderer* self,
|
||||
const float* grad_im,
|
||||
const float* image,
|
||||
const float* forw_info,
|
||||
const float* vert_pos,
|
||||
const float* vert_col,
|
||||
const float* vert_rad,
|
||||
const CamInfo& cam,
|
||||
const float& gamma,
|
||||
float percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const float* vert_opy,
|
||||
const size_t& num_balls,
|
||||
const uint& mode,
|
||||
const bool& dif_pos,
|
||||
const bool& dif_col,
|
||||
const bool& dif_rad,
|
||||
const bool& dif_cam,
|
||||
const bool& dif_opy,
|
||||
const uint& pos_x,
|
||||
const uint& pos_y,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
191
pytorch3d/csrc/pulsar/include/renderer.calc_gradients.device.h
Normal file
191
pytorch3d/csrc/pulsar/include/renderer.calc_gradients.device.h
Normal file
@ -0,0 +1,191 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CALC_GRADIENTS_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_CALC_GRADIENTS_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "./commands.h"
|
||||
#include "./renderer.h"
|
||||
|
||||
#include "./renderer.draw.device.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template <bool DEV>
|
||||
GLOBAL void calc_gradients(
|
||||
const CamInfo cam, /** Camera in world coordinates. */
|
||||
float const* const RESTRICT grad_im, /** The gradient image. */
|
||||
const float
|
||||
gamma, /** The transparency parameter used in the forward pass. */
|
||||
float3 const* const RESTRICT vert_poss, /** Vertex position vector. */
|
||||
float const* const RESTRICT vert_cols, /** Vertex color vector. */
|
||||
float const* const RESTRICT vert_rads, /** Vertex radius vector. */
|
||||
float const* const RESTRICT opacity, /** Vertex opacity. */
|
||||
const uint num_balls, /** Number of balls. */
|
||||
float const* const RESTRICT result_d, /** Result image. */
|
||||
float const* const RESTRICT forw_info_d, /** Forward pass info. */
|
||||
DrawInfo const* const RESTRICT di_d, /** Draw information. */
|
||||
IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */
|
||||
// Mode switches.
|
||||
const bool calc_grad_pos,
|
||||
const bool calc_grad_col,
|
||||
const bool calc_grad_rad,
|
||||
const bool calc_grad_cam,
|
||||
const bool calc_grad_opy,
|
||||
// Out variables.
|
||||
float* const RESTRICT grad_rad_d, /** Radius gradients. */
|
||||
float* const RESTRICT grad_col_d, /** Color gradients. */
|
||||
float3* const RESTRICT grad_pos_d, /** Position gradients. */
|
||||
CamGradInfo* const RESTRICT grad_cam_buf_d, /** Camera gradient buffer. */
|
||||
float* const RESTRICT grad_opy_d, /** Opacity gradient buffer. */
|
||||
int* const RESTRICT
|
||||
grad_contributed_d, /** Gradient contribution counter. */
|
||||
// Infrastructure.
|
||||
const int n_track,
|
||||
const uint offs_x,
|
||||
const uint offs_y /** Debug offsets. */
|
||||
) {
|
||||
uint limit_x = cam.film_width, limit_y = cam.film_height;
|
||||
if (offs_x != 0) {
|
||||
// We're in debug mode.
|
||||
limit_x = 1;
|
||||
limit_y = 1;
|
||||
}
|
||||
GET_PARALLEL_IDS_2D(coord_x_base, coord_y_base, limit_x, limit_y);
|
||||
// coord_x_base and coord_y_base are in the film coordinate system.
|
||||
// We now need to translate to the aperture coordinate system. If
|
||||
// the principal point was shifted left/up nothing has to be
|
||||
// subtracted - only shift needs to be added in case it has been
|
||||
// shifted down/right.
|
||||
const uint film_coord_x = coord_x_base + offs_x;
|
||||
const uint ap_coord_x = film_coord_x +
|
||||
2 * static_cast<uint>(std::max(0, cam.principal_point_offset_x));
|
||||
const uint film_coord_y = coord_y_base + offs_y;
|
||||
const uint ap_coord_y = film_coord_y +
|
||||
2 * static_cast<uint>(std::max(0, cam.principal_point_offset_y));
|
||||
const float3 ray_dir = /** Ray cast through the pixel, normalized. */
|
||||
cam.pixel_0_0_center + ap_coord_x * cam.pixel_dir_x +
|
||||
ap_coord_y * cam.pixel_dir_y;
|
||||
const float norm_ray_dir = length(ray_dir);
|
||||
// ray_dir_norm *must* be calculated here in the same way as in the draw
|
||||
// function to have the same values withno other numerical instabilities
|
||||
// (for example, ray_dir * FRCP(norm_ray_dir) does not work)!
|
||||
float3 ray_dir_norm; /** Ray cast through the pixel, normalized. */
|
||||
float2 projected_ray; /** Ray intersection with the sensor. */
|
||||
if (cam.orthogonal_projection) {
|
||||
ray_dir_norm = cam.sensor_dir_z;
|
||||
projected_ray.x = static_cast<float>(ap_coord_x);
|
||||
projected_ray.y = static_cast<float>(ap_coord_y);
|
||||
} else {
|
||||
ray_dir_norm = normalize(
|
||||
cam.pixel_0_0_center + ap_coord_x * cam.pixel_dir_x +
|
||||
ap_coord_y * cam.pixel_dir_y);
|
||||
// This is a reasonable assumption for normal focal lengths and image sizes.
|
||||
PASSERT(FABS(ray_dir_norm.z) > FEPS);
|
||||
projected_ray.x = ray_dir_norm.x / ray_dir_norm.z * cam.focal_length;
|
||||
projected_ray.y = ray_dir_norm.y / ray_dir_norm.z * cam.focal_length;
|
||||
}
|
||||
float* result = const_cast<float*>(
|
||||
result_d + film_coord_y * cam.film_width * cam.n_channels +
|
||||
film_coord_x * cam.n_channels);
|
||||
const float* grad_im_l = grad_im +
|
||||
film_coord_y * cam.film_width * cam.n_channels +
|
||||
film_coord_x * cam.n_channels;
|
||||
// For writing...
|
||||
float3 grad_pos;
|
||||
float grad_rad, grad_opy;
|
||||
CamGradInfo grad_cam_local = CamGradInfo();
|
||||
// Set up shared infrastructure.
|
||||
const int fwi_loc = film_coord_y * cam.film_width * (3 + 2 * n_track) +
|
||||
film_coord_x * (3 + 2 * n_track);
|
||||
float sm_m = forw_info_d[fwi_loc];
|
||||
float sm_d = forw_info_d[fwi_loc + 1];
|
||||
PULSAR_LOG_DEV_APIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad|sm_m: %f, sm_d: %f, result: "
|
||||
"%f, %f, %f; grad_im: %f, %f, %f.\n",
|
||||
sm_m,
|
||||
sm_d,
|
||||
result[0],
|
||||
result[1],
|
||||
result[2],
|
||||
grad_im_l[0],
|
||||
grad_im_l[1],
|
||||
grad_im_l[2]);
|
||||
// Start processing.
|
||||
for (int grad_idx = 0; grad_idx < n_track; ++grad_idx) {
|
||||
int sphere_idx;
|
||||
FASI(forw_info_d[fwi_loc + 3 + 2 * grad_idx], sphere_idx);
|
||||
PASSERT(
|
||||
sphere_idx == -1 ||
|
||||
sphere_idx >= 0 && static_cast<uint>(sphere_idx) < num_balls);
|
||||
if (sphere_idx >= 0) {
|
||||
// TODO: make more efficient.
|
||||
grad_pos = make_float3(0.f, 0.f, 0.f);
|
||||
grad_rad = 0.f;
|
||||
grad_cam_local = CamGradInfo();
|
||||
const DrawInfo di = di_d[sphere_idx];
|
||||
grad_opy = 0.f;
|
||||
draw(
|
||||
di,
|
||||
opacity == NULL ? 1.f : opacity[sphere_idx],
|
||||
cam,
|
||||
gamma,
|
||||
ray_dir_norm,
|
||||
projected_ray,
|
||||
// Mode switches.
|
||||
false, // draw only
|
||||
calc_grad_pos,
|
||||
calc_grad_col,
|
||||
calc_grad_rad,
|
||||
calc_grad_cam,
|
||||
calc_grad_opy,
|
||||
// Position info.
|
||||
ap_coord_x,
|
||||
ap_coord_y,
|
||||
sphere_idx,
|
||||
// Optional in.
|
||||
&ii_d[sphere_idx],
|
||||
&ray_dir,
|
||||
&norm_ray_dir,
|
||||
grad_im_l,
|
||||
NULL,
|
||||
// In/out
|
||||
&sm_d,
|
||||
&sm_m,
|
||||
result,
|
||||
// Optional out.
|
||||
NULL,
|
||||
NULL,
|
||||
&grad_pos,
|
||||
grad_col_d + sphere_idx * cam.n_channels,
|
||||
&grad_rad,
|
||||
&grad_cam_local,
|
||||
&grad_opy);
|
||||
ATOMICADD(&(grad_rad_d[sphere_idx]), grad_rad);
|
||||
// Color has been added directly.
|
||||
ATOMICADD_F3(&(grad_pos_d[sphere_idx]), grad_pos);
|
||||
ATOMICADD_F3(
|
||||
&(grad_cam_buf_d[sphere_idx].cam_pos), grad_cam_local.cam_pos);
|
||||
if (!cam.orthogonal_projection) {
|
||||
ATOMICADD_F3(
|
||||
&(grad_cam_buf_d[sphere_idx].pixel_0_0_center),
|
||||
grad_cam_local.pixel_0_0_center);
|
||||
}
|
||||
ATOMICADD_F3(
|
||||
&(grad_cam_buf_d[sphere_idx].pixel_dir_x),
|
||||
grad_cam_local.pixel_dir_x);
|
||||
ATOMICADD_F3(
|
||||
&(grad_cam_buf_d[sphere_idx].pixel_dir_y),
|
||||
grad_cam_local.pixel_dir_y);
|
||||
ATOMICADD(&(grad_opy_d[sphere_idx]), grad_opy);
|
||||
ATOMICADD(&(grad_contributed_d[sphere_idx]), 1);
|
||||
}
|
||||
}
|
||||
END_PARALLEL_2D_NORET();
|
||||
};
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
@ -0,0 +1,41 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "./renderer.calc_gradients.device.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template GLOBAL void calc_gradients<ISONDEVICE>(
|
||||
const CamInfo cam, /** Camera in world coordinates. */
|
||||
float const* const RESTRICT grad_im, /** The gradient image. */
|
||||
const float
|
||||
gamma, /** The transparency parameter used in the forward pass. */
|
||||
float3 const* const RESTRICT vert_poss, /** Vertex position vector. */
|
||||
float const* const RESTRICT vert_cols, /** Vertex color vector. */
|
||||
float const* const RESTRICT vert_rads, /** Vertex radius vector. */
|
||||
float const* const RESTRICT opacity, /** Vertex opacity. */
|
||||
const uint num_balls, /** Number of balls. */
|
||||
float const* const RESTRICT result_d, /** Result image. */
|
||||
float const* const RESTRICT forw_info_d, /** Forward pass info. */
|
||||
DrawInfo const* const RESTRICT di_d, /** Draw information. */
|
||||
IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */
|
||||
// Mode switches.
|
||||
const bool calc_grad_pos,
|
||||
const bool calc_grad_col,
|
||||
const bool calc_grad_rad,
|
||||
const bool calc_grad_cam,
|
||||
const bool calc_grad_opy,
|
||||
// Out variables.
|
||||
float* const RESTRICT grad_rad_d, /** Radius gradients. */
|
||||
float* const RESTRICT grad_col_d, /** Color gradients. */
|
||||
float3* const RESTRICT grad_pos_d, /** Position gradients. */
|
||||
CamGradInfo* const RESTRICT grad_cam_buf_d, /** Camera gradient buffer. */
|
||||
float* const RESTRICT grad_opy_d, /** Opacity gradient buffer. */
|
||||
int* const RESTRICT
|
||||
grad_contributed_d, /** Gradient contribution counter. */
|
||||
// Infrastructure.
|
||||
const int n_track,
|
||||
const uint offs_x,
|
||||
const uint offs_y);
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
194
pytorch3d/csrc/pulsar/include/renderer.calc_signature.device.h
Normal file
194
pytorch3d/csrc/pulsar/include/renderer.calc_signature.device.h
Normal file
@ -0,0 +1,194 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CALC_SIGNATURE_DEVICE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_CALC_SIGNATURE_DEVICE_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "./camera.device.h"
|
||||
#include "./commands.h"
|
||||
#include "./math.h"
|
||||
#include "./renderer.get_screen_area.device.h"
|
||||
#include "./renderer.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template <bool DEV>
|
||||
GLOBAL void calc_signature(
|
||||
Renderer renderer,
|
||||
float3 const* const RESTRICT vert_poss,
|
||||
float const* const RESTRICT vert_cols,
|
||||
float const* const RESTRICT vert_rads,
|
||||
const uint num_balls) {
|
||||
/* We're not using RESTRICT here for the pointers within `renderer`. Just one
|
||||
value is being read from each of the pointers, so the effect would be
|
||||
negligible or non-existent. */
|
||||
GET_PARALLEL_IDX_1D(idx, num_balls);
|
||||
// Create aliases.
|
||||
// For reading...
|
||||
const float3& vert_pos = vert_poss[idx]; /** Vertex position. */
|
||||
const float* vert_col =
|
||||
vert_cols + idx * renderer.cam.n_channels; /** Vertex color. */
|
||||
const float& vert_rad = vert_rads[idx]; /** Vertex radius. */
|
||||
const CamInfo& cam = renderer.cam; /** Camera in world coordinates. */
|
||||
// For writing...
|
||||
/** Ball ID (either original index of the ball or -1 if not visible). */
|
||||
int& id_out = renderer.ids_d[idx];
|
||||
/** Intersection helper structure for the ball. */
|
||||
IntersectInfo& intersect_helper_out = renderer.ii_d[idx];
|
||||
/** Draw helper structure for this ball. */
|
||||
DrawInfo& draw_helper_out = renderer.di_d[idx];
|
||||
/** Minimum possible intersection depth for this ball. */
|
||||
float& closest_possible_intersect_out = renderer.min_depth_d[idx];
|
||||
PULSAR_LOG_DEV(
|
||||
PULSAR_LOG_CALC_SIGNATURE,
|
||||
"signature %d|vert_pos: %.9f, %.9f, %.9f, vert_col (first three): "
|
||||
"%.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
vert_pos.x,
|
||||
vert_pos.y,
|
||||
vert_pos.z,
|
||||
vert_col[0],
|
||||
vert_col[1],
|
||||
vert_col[2]);
|
||||
// Set flags to invalid for a potential early return.
|
||||
id_out = -1; // Invalid ID.
|
||||
closest_possible_intersect_out =
|
||||
MAX_FLOAT; // These spheres are sorted to the very end.
|
||||
intersect_helper_out.max.x = MAX_USHORT; // No intersection possible.
|
||||
intersect_helper_out.min.x = MAX_USHORT;
|
||||
intersect_helper_out.max.y = MAX_USHORT;
|
||||
intersect_helper_out.min.y = MAX_USHORT;
|
||||
// Start processing.
|
||||
/** Ball center in the camera coordinate system. */
|
||||
const float3 ball_center_cam = vert_pos - cam.eye;
|
||||
/** Distance to the ball center in the camera coordinate system. */
|
||||
const float t_center = length(ball_center_cam);
|
||||
/** Closest possible intersection with this ball from the camera. */
|
||||
float closest_possible_intersect;
|
||||
if (cam.orthogonal_projection) {
|
||||
const float3 ball_center_cam_rot = rotate(
|
||||
ball_center_cam,
|
||||
cam.pixel_dir_x / length(cam.pixel_dir_x),
|
||||
cam.pixel_dir_y / length(cam.pixel_dir_y),
|
||||
cam.sensor_dir_z);
|
||||
closest_possible_intersect = ball_center_cam_rot.z - vert_rad;
|
||||
} else {
|
||||
closest_possible_intersect = t_center - vert_rad;
|
||||
}
|
||||
PULSAR_LOG_DEV(
|
||||
PULSAR_LOG_CALC_SIGNATURE,
|
||||
"signature %d|t_center: %f. vert_rad: %f. "
|
||||
"closest_possible_intersect: %f.\n",
|
||||
idx,
|
||||
t_center,
|
||||
vert_rad,
|
||||
closest_possible_intersect);
|
||||
/**
|
||||
* Corner points of the enclosing projected rectangle of the ball.
|
||||
* They are first calculated in the camera coordinate system, then
|
||||
* converted to the pixel coordinate system.
|
||||
*/
|
||||
float x_1, x_2, y_1, y_2;
|
||||
bool hits_screen_plane;
|
||||
float3 ray_center_norm = ball_center_cam / t_center;
|
||||
PASSERT(vert_rad >= 0.f);
|
||||
if (closest_possible_intersect < cam.min_dist ||
|
||||
closest_possible_intersect > cam.max_dist) {
|
||||
PULSAR_LOG_DEV(
|
||||
PULSAR_LOG_CALC_SIGNATURE,
|
||||
"signature %d|ignoring sphere out of min/max bounds: %.9f, "
|
||||
"min: %.9f, max: %.9f.\n",
|
||||
idx,
|
||||
closest_possible_intersect,
|
||||
cam.min_dist,
|
||||
cam.max_dist);
|
||||
RETURN_PARALLEL();
|
||||
}
|
||||
// Find the relevant region on the screen plane.
|
||||
hits_screen_plane = get_screen_area(
|
||||
ball_center_cam,
|
||||
ray_center_norm,
|
||||
vert_rad,
|
||||
cam,
|
||||
idx,
|
||||
&x_1,
|
||||
&x_2,
|
||||
&y_1,
|
||||
&y_2);
|
||||
if (!hits_screen_plane)
|
||||
RETURN_PARALLEL();
|
||||
PULSAR_LOG_DEV(
|
||||
PULSAR_LOG_CALC_SIGNATURE,
|
||||
"signature %d|in pixels: x_1: %f, x_2: %f, y_1: %f, y_2: %f.\n",
|
||||
idx,
|
||||
x_1,
|
||||
x_2,
|
||||
y_1,
|
||||
y_2);
|
||||
// Check whether the pixel coordinates are on screen.
|
||||
if (FMAX(x_1, x_2) <= static_cast<float>(cam.film_border_left) ||
|
||||
FMIN(x_1, x_2) >=
|
||||
static_cast<float>(cam.film_border_left + cam.film_width) - 0.5f ||
|
||||
FMAX(y_1, y_2) <= static_cast<float>(cam.film_border_top) ||
|
||||
FMIN(y_1, y_2) >
|
||||
static_cast<float>(cam.film_border_top + cam.film_height) - 0.5f)
|
||||
RETURN_PARALLEL();
|
||||
// Write results.
|
||||
id_out = idx;
|
||||
intersect_helper_out.min.x = static_cast<ushort>(
|
||||
FMAX(FMIN(x_1, x_2), static_cast<float>(cam.film_border_left)));
|
||||
intersect_helper_out.min.y = static_cast<ushort>(
|
||||
FMAX(FMIN(y_1, y_2), static_cast<float>(cam.film_border_top)));
|
||||
// In the following calculations, the max that needs to be stored is
|
||||
// exclusive.
|
||||
// That means that the calculated value needs to be `ceil`ed and incremented
|
||||
// to find the correct value.
|
||||
intersect_helper_out.max.x = static_cast<ushort>(FMIN(
|
||||
FCEIL(FMAX(x_1, x_2)) + 1,
|
||||
static_cast<float>(cam.film_border_left + cam.film_width)));
|
||||
intersect_helper_out.max.y = static_cast<ushort>(FMIN(
|
||||
FCEIL(FMAX(y_1, y_2)) + 1,
|
||||
static_cast<float>(cam.film_border_top + cam.film_height)));
|
||||
PULSAR_LOG_DEV(
|
||||
PULSAR_LOG_CALC_SIGNATURE,
|
||||
"signature %d|limits after refining: x_1: %u, x_2: %u, "
|
||||
"y_1: %u, y_2: %u.\n",
|
||||
idx,
|
||||
intersect_helper_out.min.x,
|
||||
intersect_helper_out.max.x,
|
||||
intersect_helper_out.min.y,
|
||||
intersect_helper_out.max.y);
|
||||
if (intersect_helper_out.min.x == MAX_USHORT) {
|
||||
id_out = -1;
|
||||
RETURN_PARALLEL();
|
||||
}
|
||||
PULSAR_LOG_DEV(
|
||||
PULSAR_LOG_CALC_SIGNATURE,
|
||||
"signature %d|writing info. closest_possible_intersect: %.9f. "
|
||||
"ray_center_norm: %.9f, %.9f, %.9f. t_center: %.9f. radius: %.9f.\n",
|
||||
idx,
|
||||
closest_possible_intersect,
|
||||
ray_center_norm.x,
|
||||
ray_center_norm.y,
|
||||
ray_center_norm.z,
|
||||
t_center,
|
||||
vert_rad);
|
||||
closest_possible_intersect_out = closest_possible_intersect;
|
||||
draw_helper_out.ray_center_norm = ray_center_norm;
|
||||
draw_helper_out.t_center = t_center;
|
||||
draw_helper_out.radius = vert_rad;
|
||||
if (cam.n_channels <= 3) {
|
||||
draw_helper_out.first_color = vert_col[0];
|
||||
for (uint c_id = 1; c_id < cam.n_channels; ++c_id) {
|
||||
draw_helper_out.color_union.color[c_id - 1] = vert_col[c_id];
|
||||
}
|
||||
} else {
|
||||
draw_helper_out.color_union.ptr = const_cast<float*>(vert_col);
|
||||
}
|
||||
END_PARALLEL();
|
||||
};
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
@ -0,0 +1,18 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CALC_SIGNATURE_INSTANTIATE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_CALC_SIGNATURE_INSTANTIATE_H_
|
||||
|
||||
#include "./renderer.calc_signature.device.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
template GLOBAL void calc_signature<ISONDEVICE>(
|
||||
Renderer renderer,
|
||||
float3 const* const RESTRICT vert_poss,
|
||||
float const* const RESTRICT vert_cols,
|
||||
float const* const RESTRICT vert_rads,
|
||||
const uint num_balls);
|
||||
}
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
104
pytorch3d/csrc/pulsar/include/renderer.construct.device.h
Normal file
104
pytorch3d/csrc/pulsar/include/renderer.construct.device.h
Normal file
@ -0,0 +1,104 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CONSTRUCT_DEVICE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_CONSTRUCT_DEVICE_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "./camera.device.h"
|
||||
#include "./commands.h"
|
||||
#include "./math.h"
|
||||
#include "./renderer.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template <bool DEV>
|
||||
HOST void construct(
|
||||
Renderer* self,
|
||||
const size_t& max_num_balls,
|
||||
const int& width,
|
||||
const int& height,
|
||||
const bool& orthogonal_projection,
|
||||
const bool& right_handed_system,
|
||||
const float& background_normalization_depth,
|
||||
const uint& n_channels,
|
||||
const uint& n_track) {
|
||||
ARGCHECK(
|
||||
(max_num_balls > 0 && max_num_balls < MAX_INT),
|
||||
2,
|
||||
("the maximum number of balls must be >0 and <" +
|
||||
std::to_string(MAX_INT) + ". Is " + std::to_string(max_num_balls) + ".")
|
||||
.c_str());
|
||||
ARGCHECK(width > 1, 3, "the image width must be > 1");
|
||||
ARGCHECK(height > 1, 4, "the image height must be > 1");
|
||||
ARGCHECK(
|
||||
background_normalization_depth > 0.f &&
|
||||
background_normalization_depth < 1.f,
|
||||
6,
|
||||
"background_normalization_depth must be in ]0., 1.[.");
|
||||
ARGCHECK(n_channels > 0, 7, "n_channels must be >0!");
|
||||
ARGCHECK(
|
||||
n_track > 0 && n_track <= MAX_GRAD_SPHERES,
|
||||
8,
|
||||
("n_track must be >0 and <" + std::to_string(MAX_GRAD_SPHERES) + ". Is " +
|
||||
std::to_string(n_track) + ".")
|
||||
.c_str());
|
||||
self->cam.film_width = width;
|
||||
self->cam.film_height = height;
|
||||
self->max_num_balls = max_num_balls;
|
||||
MALLOC(self->result_d, float, width* height* n_channels);
|
||||
self->cam.orthogonal_projection = orthogonal_projection;
|
||||
self->cam.right_handed = right_handed_system;
|
||||
self->cam.background_normalization_depth = background_normalization_depth;
|
||||
self->cam.n_channels = n_channels;
|
||||
MALLOC(self->min_depth_d, float, max_num_balls);
|
||||
MALLOC(self->min_depth_sorted_d, float, max_num_balls);
|
||||
MALLOC(self->ii_d, IntersectInfo, max_num_balls);
|
||||
MALLOC(self->ii_sorted_d, IntersectInfo, max_num_balls);
|
||||
MALLOC(self->ids_d, int, max_num_balls);
|
||||
MALLOC(self->ids_sorted_d, int, max_num_balls);
|
||||
size_t sort_id_size = 0;
|
||||
GET_SORT_WS_SIZE(&sort_id_size, float, int, max_num_balls);
|
||||
CHECKLAUNCH();
|
||||
size_t sort_ii_size = 0;
|
||||
GET_SORT_WS_SIZE(&sort_ii_size, float, IntersectInfo, max_num_balls);
|
||||
CHECKLAUNCH();
|
||||
size_t sort_di_size = 0;
|
||||
GET_SORT_WS_SIZE(&sort_di_size, float, DrawInfo, max_num_balls);
|
||||
CHECKLAUNCH();
|
||||
size_t select_ii_size = 0;
|
||||
GET_SELECT_WS_SIZE(&select_ii_size, char, IntersectInfo, max_num_balls);
|
||||
size_t select_di_size = 0;
|
||||
GET_SELECT_WS_SIZE(&select_di_size, char, DrawInfo, max_num_balls);
|
||||
size_t sum_size = 0;
|
||||
GET_SUM_WS_SIZE(&sum_size, CamGradInfo, max_num_balls);
|
||||
size_t sum_cont_size = 0;
|
||||
GET_SUM_WS_SIZE(&sum_cont_size, int, max_num_balls);
|
||||
size_t reduce_size = 0;
|
||||
GET_REDUCE_WS_SIZE(
|
||||
&reduce_size, IntersectInfo, IntersectInfoMinMax(), max_num_balls);
|
||||
self->workspace_size = IMAX(
|
||||
IMAX(IMAX(sort_id_size, sort_ii_size), sort_di_size),
|
||||
IMAX(
|
||||
IMAX(select_di_size, select_ii_size),
|
||||
IMAX(IMAX(sum_size, sum_cont_size), reduce_size)));
|
||||
MALLOC(self->workspace_d, char, self->workspace_size);
|
||||
MALLOC(self->di_d, DrawInfo, max_num_balls);
|
||||
MALLOC(self->di_sorted_d, DrawInfo, max_num_balls);
|
||||
MALLOC(self->region_flags_d, char, max_num_balls);
|
||||
MALLOC(self->num_selected_d, size_t, 1);
|
||||
MALLOC(self->forw_info_d, float, width* height*(3 + 2 * n_track));
|
||||
MALLOC(self->min_max_pixels_d, IntersectInfo, 1);
|
||||
MALLOC(self->grad_pos_d, float3, max_num_balls);
|
||||
MALLOC(self->grad_col_d, float, max_num_balls* n_channels);
|
||||
MALLOC(self->grad_rad_d, float, max_num_balls);
|
||||
MALLOC(self->grad_cam_d, float, 12);
|
||||
MALLOC(self->grad_cam_buf_d, CamGradInfo, max_num_balls);
|
||||
MALLOC(self->grad_opy_d, float, max_num_balls);
|
||||
MALLOC(self->n_grad_contributions_d, int, 1);
|
||||
self->n_track = static_cast<int>(n_track);
|
||||
}
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
@ -0,0 +1,22 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CONSTRUCT_INSTANTIATE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_CONSTRUCT_INSTANTIATE_H_
|
||||
|
||||
#include "./renderer.construct.device.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
template void construct<ISONDEVICE>(
|
||||
Renderer* self,
|
||||
const size_t& max_num_balls,
|
||||
const int& width,
|
||||
const int& height,
|
||||
const bool& orthogonal_projection,
|
||||
const bool& right_handed_system,
|
||||
const float& background_normalization_depth,
|
||||
const uint& n_channels,
|
||||
const uint& n_track);
|
||||
}
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
@ -0,0 +1,34 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CREATE_SELECTOR_DEVICE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_CREATE_SELECTOR_DEVICE_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "./commands.h"
|
||||
#include "./renderer.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template <bool DEV>
|
||||
GLOBAL void create_selector(
|
||||
IntersectInfo const* const RESTRICT ii_sorted_d,
|
||||
const uint num_balls,
|
||||
const int min_x,
|
||||
const int max_x,
|
||||
const int min_y,
|
||||
const int max_y,
|
||||
/* Out variables. */
|
||||
char* RESTRICT region_flags_d) {
|
||||
GET_PARALLEL_IDX_1D(idx, num_balls);
|
||||
bool hit = (static_cast<int>(ii_sorted_d[idx].min.x) <= max_x) &&
|
||||
(static_cast<int>(ii_sorted_d[idx].max.x) > min_x) &&
|
||||
(static_cast<int>(ii_sorted_d[idx].min.y) <= max_y) &&
|
||||
(static_cast<int>(ii_sorted_d[idx].max.y) > min_y);
|
||||
region_flags_d[idx] = hit;
|
||||
END_PARALLEL_NORET();
|
||||
}
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
@ -0,0 +1,23 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CREATE_SELECTOR_INSTANTIATE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_CREATE_SELECTOR_INSTANTIATE_H_
|
||||
|
||||
#include "./renderer.create_selector.device.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template GLOBAL void create_selector<ISONDEVICE>(
|
||||
IntersectInfo const* const RESTRICT ii_sorted_d,
|
||||
const uint num_balls,
|
||||
const int min_x,
|
||||
const int max_x,
|
||||
const int min_y,
|
||||
const int max_y,
|
||||
/* Out variables. */
|
||||
char* RESTRICT region_flags_d);
|
||||
|
||||
}
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
82
pytorch3d/csrc/pulsar/include/renderer.destruct.device.h
Normal file
82
pytorch3d/csrc/pulsar/include/renderer.destruct.device.h
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_DESTRUCT_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_DESTRUCT_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "./commands.h"
|
||||
#include "./renderer.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template <bool DEV>
|
||||
HOST void destruct(Renderer* self) {
|
||||
if (self->result_d != NULL)
|
||||
FREE(self->result_d);
|
||||
self->result_d = NULL;
|
||||
if (self->min_depth_d != NULL)
|
||||
FREE(self->min_depth_d);
|
||||
self->min_depth_d = NULL;
|
||||
if (self->min_depth_sorted_d != NULL)
|
||||
FREE(self->min_depth_sorted_d);
|
||||
self->min_depth_sorted_d = NULL;
|
||||
if (self->ii_d != NULL)
|
||||
FREE(self->ii_d);
|
||||
self->ii_d = NULL;
|
||||
if (self->ii_sorted_d != NULL)
|
||||
FREE(self->ii_sorted_d);
|
||||
self->ii_sorted_d = NULL;
|
||||
if (self->ids_d != NULL)
|
||||
FREE(self->ids_d);
|
||||
self->ids_d = NULL;
|
||||
if (self->ids_sorted_d != NULL)
|
||||
FREE(self->ids_sorted_d);
|
||||
self->ids_sorted_d = NULL;
|
||||
if (self->workspace_d != NULL)
|
||||
FREE(self->workspace_d);
|
||||
self->workspace_d = NULL;
|
||||
if (self->di_d != NULL)
|
||||
FREE(self->di_d);
|
||||
self->di_d = NULL;
|
||||
if (self->di_sorted_d != NULL)
|
||||
FREE(self->di_sorted_d);
|
||||
self->di_sorted_d = NULL;
|
||||
if (self->region_flags_d != NULL)
|
||||
FREE(self->region_flags_d);
|
||||
self->region_flags_d = NULL;
|
||||
if (self->num_selected_d != NULL)
|
||||
FREE(self->num_selected_d);
|
||||
self->num_selected_d = NULL;
|
||||
if (self->forw_info_d != NULL)
|
||||
FREE(self->forw_info_d);
|
||||
self->forw_info_d = NULL;
|
||||
if (self->min_max_pixels_d != NULL)
|
||||
FREE(self->min_max_pixels_d);
|
||||
self->min_max_pixels_d = NULL;
|
||||
if (self->grad_pos_d != NULL)
|
||||
FREE(self->grad_pos_d);
|
||||
self->grad_pos_d = NULL;
|
||||
if (self->grad_col_d != NULL)
|
||||
FREE(self->grad_col_d);
|
||||
self->grad_col_d = NULL;
|
||||
if (self->grad_rad_d != NULL)
|
||||
FREE(self->grad_rad_d);
|
||||
self->grad_rad_d = NULL;
|
||||
if (self->grad_cam_d != NULL)
|
||||
FREE(self->grad_cam_d);
|
||||
self->grad_cam_d = NULL;
|
||||
if (self->grad_cam_buf_d != NULL)
|
||||
FREE(self->grad_cam_buf_d);
|
||||
self->grad_cam_buf_d = NULL;
|
||||
if (self->grad_opy_d != NULL)
|
||||
FREE(self->grad_opy_d);
|
||||
self->grad_opy_d = NULL;
|
||||
if (self->n_grad_contributions_d != NULL)
|
||||
FREE(self->n_grad_contributions_d);
|
||||
self->n_grad_contributions_d = NULL;
|
||||
}
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
@ -0,0 +1,13 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_DESTRUCT_INSTANTIATE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_DESTRUCT_INSTANTIATE_H_
|
||||
|
||||
#include "./renderer.destruct.device.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
template void destruct<ISONDEVICE>(Renderer* self);
|
||||
}
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
839
pytorch3d/csrc/pulsar/include/renderer.draw.device.h
Normal file
839
pytorch3d/csrc/pulsar/include/renderer.draw.device.h
Normal file
@ -0,0 +1,839 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_CALC_SIGNATURE_DEVICE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_CALC_SIGNATURE_DEVICE_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "./camera.device.h"
|
||||
#include "./commands.h"
|
||||
#include "./math.h"
|
||||
#include "./renderer.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
/**
|
||||
* Draw a ball into the `result`.
|
||||
*
|
||||
* Returns whether a hit was noticed. See README for an explanation of sphere
|
||||
* points and variable notation.
|
||||
*/
|
||||
INLINE DEVICE bool draw(
|
||||
/* In variables. */
|
||||
const DrawInfo& draw_info, /** The draw information for this ball. */
|
||||
const float& opacity, /** The sphere opacity. */
|
||||
const CamInfo&
|
||||
cam, /** Camera information. Doesn't have to be normalized. */
|
||||
const float& gamma, /** 'Transparency' indicator (see paper for details). */
|
||||
const float3& ray_dir_norm, /** The direction of the ray, normalized. */
|
||||
const float2& projected_ray, /** The intersection of the ray with the image
|
||||
in pixel space. */
|
||||
/** Mode switches. */
|
||||
const bool& draw_only, /** Whether we are in draw vs. grad mode. */
|
||||
const bool& calc_grad_pos, /** Calculate position gradients. */
|
||||
const bool& calc_grad_col, /** Calculate color gradients. */
|
||||
const bool& calc_grad_rad, /** Calculate radius gradients. */
|
||||
const bool& calc_grad_cam, /** Calculate camera gradients. */
|
||||
const bool& calc_grad_opy, /** Calculate opacity gradients. */
|
||||
/** Position info. */
|
||||
const uint& coord_x, /** The pixel position x to draw at. */
|
||||
const uint& coord_y, /** The pixel position y to draw at. */
|
||||
const uint& idx, /** The id of the sphere to process. */
|
||||
/* Optional in variables. */
|
||||
IntersectInfo const* const RESTRICT
|
||||
intersect_info, /** The intersect information for this ball. */
|
||||
float3 const* const RESTRICT ray_dir, /** The ray direction (not normalized)
|
||||
to draw at. Only used for grad computation. */
|
||||
float const* const RESTRICT norm_ray_dir, /** The length of the direction
|
||||
vector. Only used for grad computation. */
|
||||
float const* const RESTRICT grad_pix, /** The gradient for this pixel. Only
|
||||
used for grad computation. */
|
||||
float const* const RESTRICT
|
||||
ln_pad_over_1minuspad, /** Allowed percentage indicator. */
|
||||
/* In or out variables, depending on mode. */
|
||||
float* const RESTRICT sm_d, /** Normalization denominator. */
|
||||
float* const RESTRICT
|
||||
sm_m, /** Maximum of normalization weight factors observed. */
|
||||
float* const RESTRICT
|
||||
result, /** Result pixel color. Must be zeros initially. */
|
||||
/* Optional out variables. */
|
||||
float* const RESTRICT depth_threshold, /** The depth threshold to use. Only
|
||||
used for rendering. */
|
||||
float* const RESTRICT intersection_depth_norm_out, /** The intersection
|
||||
depth. Only set when rendering. */
|
||||
float3* const RESTRICT grad_pos, /** Gradient w.r.t. position. */
|
||||
float* const RESTRICT grad_col, /** Gradient w.r.t. color. */
|
||||
float* const RESTRICT grad_rad, /** Gradient w.r.t. radius. */
|
||||
CamGradInfo* const RESTRICT grad_cam, /** Gradient w.r.t. camera. */
|
||||
float* const RESTRICT grad_opy /** Gradient w.r.t. opacity. */
|
||||
) {
|
||||
// TODO: variable reuse?
|
||||
PASSERT(
|
||||
isfinite(draw_info.ray_center_norm.x) &&
|
||||
isfinite(draw_info.ray_center_norm.y) &&
|
||||
isfinite(draw_info.ray_center_norm.z));
|
||||
PASSERT(isfinite(draw_info.t_center) && draw_info.t_center >= 0.f);
|
||||
PASSERT(
|
||||
isfinite(draw_info.radius) && draw_info.radius >= 0.f &&
|
||||
draw_info.radius <= draw_info.t_center);
|
||||
PASSERT(isfinite(ray_dir_norm.x));
|
||||
PASSERT(isfinite(ray_dir_norm.y));
|
||||
PASSERT(isfinite(ray_dir_norm.z));
|
||||
PASSERT(isfinite(*sm_d));
|
||||
PASSERT(
|
||||
cam.orthogonal_projection && cam.focal_length == 0.f ||
|
||||
cam.focal_length > 0.f);
|
||||
PASSERT(gamma <= 1.f && gamma >= 1e-5f);
|
||||
/** The ball center in the camera coordinate system. */
|
||||
float3 center = draw_info.ray_center_norm * draw_info.t_center;
|
||||
/** The vector from the reference point to the ball center. */
|
||||
float3 raydiff;
|
||||
if (cam.orthogonal_projection) {
|
||||
center = rotate(
|
||||
center,
|
||||
cam.pixel_dir_x / length(cam.pixel_dir_x),
|
||||
cam.pixel_dir_y / length(cam.pixel_dir_y),
|
||||
cam.sensor_dir_z);
|
||||
raydiff =
|
||||
make_float3( // TODO: make offset consistent with `get_screen_area`.
|
||||
center.x -
|
||||
(projected_ray.x -
|
||||
static_cast<float>(cam.aperture_width) * .5f) *
|
||||
(2.f * cam.half_pixel_size),
|
||||
center.y -
|
||||
(projected_ray.y -
|
||||
static_cast<float>(cam.aperture_height) * .5f) *
|
||||
(2.f * cam.half_pixel_size),
|
||||
0.f);
|
||||
} else {
|
||||
/** The reference point on the ray; the point in the same distance
|
||||
* from the camera as the ball center, but along the ray.
|
||||
*/
|
||||
const float3 rayref = ray_dir_norm * draw_info.t_center;
|
||||
raydiff = center - rayref;
|
||||
}
|
||||
/** The closeness of the reference point to ball center in world coords.
|
||||
*
|
||||
* In [0., radius].
|
||||
*/
|
||||
const float closeness_world = length(raydiff);
|
||||
/** The reciprocal radius. */
|
||||
const float radius_rcp = FRCP(draw_info.radius);
|
||||
/** The closeness factor normalized with the ball radius.
|
||||
*
|
||||
* In [0., 1.].
|
||||
*/
|
||||
float closeness = FSATURATE(FMA(-closeness_world, radius_rcp, 1.f));
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_DRAW_PIX,
|
||||
"drawprep %u|center: %.9f, %.9f, %.9f. raydiff: %.9f, "
|
||||
"%.9f, %.9f. closeness_world: %.9f. closeness: %.9f\n",
|
||||
idx,
|
||||
center.x,
|
||||
center.y,
|
||||
center.z,
|
||||
raydiff.x,
|
||||
raydiff.y,
|
||||
raydiff.z,
|
||||
closeness_world,
|
||||
closeness);
|
||||
/** Whether this is the 'center pixel' for this ball, the pixel that
|
||||
* is closest to its projected center. This information is used to
|
||||
* make sure to draw 'tiny' spheres with less than one pixel in
|
||||
* projected size.
|
||||
*/
|
||||
bool ray_through_center_pixel;
|
||||
float projected_radius, projected_x, projected_y;
|
||||
if (cam.orthogonal_projection) {
|
||||
projected_x = center.x / (2.f * cam.half_pixel_size) +
|
||||
(static_cast<float>(cam.aperture_width) - 1.f) / 2.f;
|
||||
projected_y = center.y / (2.f * cam.half_pixel_size) +
|
||||
(static_cast<float>(cam.aperture_height) - 1.f) / 2.f;
|
||||
projected_radius = draw_info.radius / (2.f * cam.half_pixel_size);
|
||||
ray_through_center_pixel =
|
||||
(FABS(FSUB(projected_x, projected_ray.x)) < 0.5f + FEPS &&
|
||||
FABS(FSUB(projected_y, projected_ray.y)) < 0.5f + FEPS);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_DRAW_PIX,
|
||||
"drawprep %u|closeness_world: %.9f. closeness: %.9f. "
|
||||
"projected (x, y): %.9f, %.9f. projected_ray (x, y): "
|
||||
"%.9f, %.9f. ray_through_center_pixel: %d.\n",
|
||||
idx,
|
||||
closeness_world,
|
||||
closeness,
|
||||
projected_x,
|
||||
projected_y,
|
||||
projected_ray.x,
|
||||
projected_ray.y,
|
||||
ray_through_center_pixel);
|
||||
} else {
|
||||
// Misusing this variable for half pixel size projected to the depth
|
||||
// at which the sphere resides. Leave some slack for numerical
|
||||
// inaccuracy (factor 1.5).
|
||||
projected_x = FMUL(cam.half_pixel_size * 1.5, draw_info.t_center) *
|
||||
FRCP(cam.focal_length);
|
||||
projected_radius = FMUL(draw_info.radius, cam.focal_length) *
|
||||
FRCP(draw_info.t_center) / (2.f * cam.half_pixel_size);
|
||||
ray_through_center_pixel = projected_x > closeness_world;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_DRAW_PIX,
|
||||
"drawprep %u|closeness_world: %.9f. closeness: %.9f. "
|
||||
"projected half pixel size: %.9f. "
|
||||
"ray_through_center_pixel: %d.\n",
|
||||
idx,
|
||||
closeness_world,
|
||||
closeness,
|
||||
projected_x,
|
||||
ray_through_center_pixel);
|
||||
}
|
||||
if (draw_only && draw_info.radius < closeness_world &&
|
||||
!ray_through_center_pixel) {
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_DRAW_PIX,
|
||||
"drawprep %u|Abandoning since no hit has been detected.\n",
|
||||
idx);
|
||||
return false;
|
||||
} else {
|
||||
// This is always a hit since we are following the forward execution pass.
|
||||
// p2 is the closest intersection point with the sphere.
|
||||
}
|
||||
if (ray_through_center_pixel && projected_radius < 1.f) {
|
||||
// Make a tiny sphere visible.
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_DRAW_PIX,
|
||||
"drawprep %u|Setting closeness to 1 (projected radius: %.9f).\n",
|
||||
idx,
|
||||
projected_radius);
|
||||
closeness = 1.;
|
||||
}
|
||||
PASSERT(closeness >= 0.f && closeness <= 1.f);
|
||||
/** Distance between the camera (`o`) and `p1`, the closest point to the
|
||||
* ball center along the casted ray.
|
||||
*
|
||||
* In [t_center - radius, t_center].
|
||||
*/
|
||||
float o__p1_;
|
||||
/** The distance from ball center to p1.
|
||||
*
|
||||
* In [0., sqrt(t_center ^ 2 - (t_center - radius) ^ 2)].
|
||||
*/
|
||||
float c__p1_;
|
||||
if (cam.orthogonal_projection) {
|
||||
o__p1_ = FABS(center.z);
|
||||
c__p1_ = length(raydiff);
|
||||
} else {
|
||||
o__p1_ = dot(center, ray_dir_norm);
|
||||
/**
|
||||
* This is being calculated as sqrt(t_center^2 - o__p1_^2) =
|
||||
* sqrt((t_center + o__p1_) * (t_center - o__p1_)) to avoid
|
||||
* catastrophic cancellation in floating point representations.
|
||||
*/
|
||||
c__p1_ = FSQRT(
|
||||
(draw_info.t_center + o__p1_) * FMAX(draw_info.t_center - o__p1_, 0.f));
|
||||
// PASSERT(o__p1_ >= draw_info.t_center - draw_info.radius);
|
||||
// Numerical errors lead to too large values.
|
||||
o__p1_ = FMIN(o__p1_, draw_info.t_center);
|
||||
// PASSERT(o__p1_ <= draw_info.t_center);
|
||||
}
|
||||
/** The distance from the closest point to the sphere center (p1)
|
||||
* to the closest intersection point (p2).
|
||||
*
|
||||
* In [0., radius].
|
||||
*/
|
||||
const float p1__p2_ =
|
||||
FSQRT((draw_info.radius + c__p1_) * FMAX(draw_info.radius - c__p1_, 0.f));
|
||||
PASSERT(p1__p2_ >= 0.f && p1__p2_ <= draw_info.radius);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_DRAW_PIX,
|
||||
"drawprep %u|o__p1_: %.9f, c__p1_: %.9f, p1__p2_: %.9f.\n",
|
||||
idx,
|
||||
o__p1_,
|
||||
c__p1_,
|
||||
p1__p2_);
|
||||
/** The intersection depth of the ray with this ball.
|
||||
*
|
||||
* In [t_center - radius, t_center].
|
||||
*/
|
||||
const float intersection_depth = (o__p1_ - p1__p2_);
|
||||
PASSERT(
|
||||
cam.orthogonal_projection &&
|
||||
(intersection_depth >= center.z - draw_info.radius &&
|
||||
intersection_depth <= center.z) ||
|
||||
intersection_depth >= draw_info.t_center - draw_info.radius &&
|
||||
intersection_depth <= draw_info.t_center);
|
||||
/** Normalized distance of the closest intersection point; in [0., 1.]. */
|
||||
const float norm_dist =
|
||||
FMUL(FSUB(intersection_depth, cam.min_dist), cam.norm_fac);
|
||||
PASSERT(norm_dist >= 0.f && norm_dist <= 1.f);
|
||||
/** Scaled, normalized distance in [1., 0.] (closest, farthest). */
|
||||
const float norm_dist_scaled = FSUB(1.f, norm_dist) / gamma * opacity;
|
||||
PASSERT(norm_dist_scaled >= 0.f && norm_dist_scaled <= 1.f / gamma);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_DRAW_PIX,
|
||||
"drawprep %u|intersection_depth: %.9f, norm_dist: %.9f, "
|
||||
"norm_dist_scaled: %.9f.\n",
|
||||
idx,
|
||||
intersection_depth,
|
||||
norm_dist,
|
||||
norm_dist_scaled);
|
||||
float const* const col_ptr =
|
||||
cam.n_channels > 3 ? draw_info.color_union.ptr : &draw_info.first_color;
|
||||
// The implementation for the numerically stable weighted softmax is based
|
||||
// on https://arxiv.org/pdf/1805.02867.pdf .
|
||||
if (draw_only) {
|
||||
/** The old maximum observed value. */
|
||||
const float sm_m_old = *sm_m;
|
||||
*sm_m = FMAX(*sm_m, norm_dist_scaled);
|
||||
const float coeff_exp = FEXP(norm_dist_scaled - *sm_m);
|
||||
PASSERT(isfinite(coeff_exp));
|
||||
/** The color coefficient for the ball color; in [0., 1.]. */
|
||||
const float coeff = closeness * coeff_exp * opacity;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_DRAW_PIX,
|
||||
"draw %u|coeff: %.9f. closeness: %.9f. coeff_exp: %.9f. "
|
||||
"opacity: %.9f.\n",
|
||||
idx,
|
||||
coeff,
|
||||
closeness,
|
||||
coeff_exp,
|
||||
opacity);
|
||||
// Rendering.
|
||||
if (sm_m_old == *sm_m) {
|
||||
// Use the fact that exp(0) = 1 to avoid the exp calculation for
|
||||
// the case that the maximum remains the same (which it should
|
||||
// most of the time).
|
||||
*sm_d = FADD(*sm_d, coeff);
|
||||
for (uint c_id = 0; c_id < cam.n_channels; ++c_id) {
|
||||
PASSERT(isfinite(result[c_id]));
|
||||
result[c_id] = FMA(coeff, col_ptr[c_id], result[c_id]);
|
||||
}
|
||||
} else {
|
||||
const float exp_correction = FEXP(sm_m_old - *sm_m);
|
||||
*sm_d = FMA(*sm_d, exp_correction, coeff);
|
||||
for (uint c_id = 0; c_id < cam.n_channels; ++c_id) {
|
||||
PASSERT(isfinite(result[c_id]));
|
||||
result[c_id] =
|
||||
FMA(coeff, col_ptr[c_id], FMUL(result[c_id], exp_correction));
|
||||
}
|
||||
}
|
||||
PASSERT(isfinite(*sm_d));
|
||||
*intersection_depth_norm_out = intersection_depth;
|
||||
// Update the depth threshold.
|
||||
*depth_threshold =
|
||||
1.f - (FLN(*sm_d + FEPS) + *ln_pad_over_1minuspad + *sm_m) * gamma;
|
||||
*depth_threshold =
|
||||
FMA(*depth_threshold, FSUB(cam.max_dist, cam.min_dist), cam.min_dist);
|
||||
} else {
|
||||
// Gradient computation.
|
||||
const float coeff_exp = FEXP(norm_dist_scaled - *sm_m);
|
||||
const float gamma_rcp = FRCP(gamma);
|
||||
const float radius_sq = FMUL(draw_info.radius, draw_info.radius);
|
||||
const float coeff = FMAX(
|
||||
FMIN(closeness * coeff_exp * opacity, *sm_d - FEPS),
|
||||
0.f); // in [0., sm_d - FEPS].
|
||||
PASSERT(coeff >= 0.f && coeff <= *sm_d);
|
||||
const float otherw = *sm_d - coeff; // in [FEPS, sm_d].
|
||||
const float p1__p2_safe = FMAX(p1__p2_, FEPS); // in [eps, t_center].
|
||||
const float cam_range = FSUB(cam.max_dist, cam.min_dist); // in ]0, inf[
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|pos: %.9f, %.9f, %.9f. pixeldirx: %.9f, %.9f, %.9f. "
|
||||
"pixeldiry: %.9f, %.9f, %.9f. pixel00center: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
draw_info.ray_center_norm.x * draw_info.t_center,
|
||||
draw_info.ray_center_norm.y * draw_info.t_center,
|
||||
draw_info.ray_center_norm.z * draw_info.t_center,
|
||||
cam.pixel_dir_x.x,
|
||||
cam.pixel_dir_x.y,
|
||||
cam.pixel_dir_x.z,
|
||||
cam.pixel_dir_y.x,
|
||||
cam.pixel_dir_y.y,
|
||||
cam.pixel_dir_y.z,
|
||||
cam.pixel_0_0_center.x,
|
||||
cam.pixel_0_0_center.y,
|
||||
cam.pixel_0_0_center.z);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|ray_dir: %.9f, %.9f, %.9f. "
|
||||
"ray_dir_norm: %.9f, %.9f, %.9f. "
|
||||
"draw_info.ray_center_norm: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
ray_dir->x,
|
||||
ray_dir->y,
|
||||
ray_dir->z,
|
||||
ray_dir_norm.x,
|
||||
ray_dir_norm.y,
|
||||
ray_dir_norm.z,
|
||||
draw_info.ray_center_norm.x,
|
||||
draw_info.ray_center_norm.y,
|
||||
draw_info.ray_center_norm.z);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|coeff_exp: %.9f. "
|
||||
"norm_dist_scaled: %.9f. cam.norm_fac: %f.\n",
|
||||
idx,
|
||||
coeff_exp,
|
||||
norm_dist_scaled,
|
||||
cam.norm_fac);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|p1__p2_: %.9f. p1__p2_safe: %.9f.\n",
|
||||
idx,
|
||||
p1__p2_,
|
||||
p1__p2_safe);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|o__p1_: %.9f. c__p1_: %.9f.\n",
|
||||
idx,
|
||||
o__p1_,
|
||||
c__p1_);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|intersection_depth: %f. norm_dist: %f. "
|
||||
"coeff: %.9f. closeness: %f. coeff_exp: %f. opacity: "
|
||||
"%f. color: %f, %f, %f.\n",
|
||||
idx,
|
||||
intersection_depth,
|
||||
norm_dist,
|
||||
coeff,
|
||||
closeness,
|
||||
coeff_exp,
|
||||
opacity,
|
||||
draw_info.first_color,
|
||||
draw_info.color_union.color[0],
|
||||
draw_info.color_union.color[1]);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|t_center: %.9f. "
|
||||
"radius: %.9f. max_dist: %f. min_dist: %f. gamma: %f.\n",
|
||||
idx,
|
||||
draw_info.t_center,
|
||||
draw_info.radius,
|
||||
cam.max_dist,
|
||||
cam.min_dist,
|
||||
gamma);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|sm_d: %f. sm_m: %f. grad_pix (first three): %f, %f, %f.\n",
|
||||
idx,
|
||||
*sm_d,
|
||||
*sm_m,
|
||||
grad_pix[0],
|
||||
grad_pix[1],
|
||||
grad_pix[2]);
|
||||
PULSAR_LOG_DEV_PIX(PULSAR_LOG_GRAD, "grad %u|otherw: %f.\n", idx, otherw);
|
||||
if (calc_grad_col) {
|
||||
const float sm_d_norm = FRCP(FMAX(*sm_d, FEPS));
|
||||
// First do the multiplication of coeff (in [0., sm_d]) and 1/sm_d. The
|
||||
// result is a factor in [0., 1.] to be multiplied with the incoming
|
||||
// gradient.
|
||||
for (uint c_id = 0; c_id < cam.n_channels; ++c_id) {
|
||||
ATOMICADD(grad_col + c_id, grad_pix[c_id] * FMUL(coeff, sm_d_norm));
|
||||
}
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdcol.x: %f. dresDdcol.x: %f.\n",
|
||||
idx,
|
||||
FMUL(coeff, sm_d_norm) * grad_pix[0],
|
||||
coeff * sm_d_norm);
|
||||
}
|
||||
// We disable the computation for too small spheres.
|
||||
// The comparison is made this way to avoid subtraction of unsigned types.
|
||||
if (calc_grad_cam || calc_grad_pos || calc_grad_rad || calc_grad_opy) {
|
||||
//! First find dimDdcoeff.
|
||||
const float n0 =
|
||||
otherw * FRCP(FMAX(*sm_d * *sm_d, FEPS)); // in [0., 1. / sm_d].
|
||||
PASSERT(isfinite(n0) && n0 >= 0. && n0 <= 1. / *sm_d + 1e2f * FEPS);
|
||||
// We'll aggergate dimDdcoeff over all the 'color' channels.
|
||||
float dimDdcoeff = 0.f;
|
||||
const float otherw_safe_rcp = FRCP(FMAX(otherw, FEPS));
|
||||
float othercol;
|
||||
for (uint c_id = 0; c_id < cam.n_channels; ++c_id) {
|
||||
othercol =
|
||||
(result[c_id] * *sm_d - col_ptr[c_id] * coeff) * otherw_safe_rcp;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|othercol[%u]: %.9f.\n",
|
||||
idx,
|
||||
c_id,
|
||||
othercol);
|
||||
dimDdcoeff +=
|
||||
FMUL(FMUL(grad_pix[c_id], FSUB(col_ptr[c_id], othercol)), n0);
|
||||
}
|
||||
PASSERT(isfinite(dimDdcoeff));
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdcoeff: %.9f, n0: %f.\n",
|
||||
idx,
|
||||
dimDdcoeff,
|
||||
n0);
|
||||
if (calc_grad_opy) {
|
||||
//! dimDdopacity.
|
||||
*grad_opy += dimDdcoeff * coeff_exp * closeness *
|
||||
(1.f + opacity * (1.f - norm_dist) * gamma_rcp);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dcoeffDdopacity: %.9f, dimDdopacity: %.9f.\n",
|
||||
idx,
|
||||
coeff_exp * closeness,
|
||||
dimDdcoeff * coeff_exp * closeness);
|
||||
}
|
||||
if (intersect_info->max.x >= intersect_info->min.x + 3 &&
|
||||
intersect_info->max.y >= intersect_info->min.y + 3) {
|
||||
//! Now find dcoeffDdintersection_depth and dcoeffDdcloseness.
|
||||
const float dcoeffDdintersection_depth =
|
||||
-closeness * coeff_exp * opacity * opacity / (gamma * cam_range);
|
||||
const float dcoeffDdcloseness = coeff_exp * opacity;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dcoeffDdintersection_depth: %.9f. "
|
||||
"dimDdintersection_depth: %.9f. "
|
||||
"dcoeffDdcloseness: %.9f. dimDdcloseness: %.9f.\n",
|
||||
idx,
|
||||
dcoeffDdintersection_depth,
|
||||
dimDdcoeff * dcoeffDdintersection_depth,
|
||||
dcoeffDdcloseness,
|
||||
dimDdcoeff * dcoeffDdcloseness);
|
||||
//! Here, the execution paths for orthogonal and pinyhole camera split.
|
||||
if (cam.orthogonal_projection) {
|
||||
if (calc_grad_rad) {
|
||||
//! Find dcoeffDdrad.
|
||||
float dcoeffDdrad =
|
||||
dcoeffDdcloseness * (closeness_world / radius_sq) -
|
||||
dcoeffDdintersection_depth * draw_info.radius / p1__p2_safe;
|
||||
PASSERT(isfinite(dcoeffDdrad));
|
||||
*grad_rad += FMUL(dimDdcoeff, dcoeffDdrad);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdrad: %.9f. dcoeffDdrad: %.9f.\n",
|
||||
idx,
|
||||
FMUL(dimDdcoeff, dcoeffDdrad),
|
||||
dcoeffDdrad);
|
||||
}
|
||||
if (calc_grad_pos || calc_grad_cam) {
|
||||
float3 dimDdcenter = raydiff /
|
||||
p1__p2_safe; /* making it dintersection_depthDdcenter. */
|
||||
dimDdcenter.z = sign_dir(center.z);
|
||||
PASSERT(FABS(center.z) >= cam.min_dist && cam.min_dist >= FEPS);
|
||||
dimDdcenter *= dcoeffDdintersection_depth; // dcoeffDdcenter
|
||||
dimDdcenter -= dcoeffDdcloseness * /* dclosenessDdcenter. */
|
||||
raydiff * FRCP(FMAX(length(raydiff) * draw_info.radius, FEPS));
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dcoeffDdcenter: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
dimDdcenter.x,
|
||||
dimDdcenter.y,
|
||||
dimDdcenter.z);
|
||||
// Now dcoeffDdcenter is stored in dimDdcenter.
|
||||
dimDdcenter *= dimDdcoeff;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdcenter: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
dimDdcenter.x,
|
||||
dimDdcenter.y,
|
||||
dimDdcenter.z);
|
||||
// Prepare for posglob and cam pos.
|
||||
const float pixel_size = length(cam.pixel_dir_x);
|
||||
// pixel_size is the same as length(pixeldiry)!
|
||||
const float pixel_size_rcp = FRCP(pixel_size);
|
||||
float3 dcenterDdposglob =
|
||||
(cam.pixel_dir_x + cam.pixel_dir_y) * pixel_size_rcp +
|
||||
cam.sensor_dir_z;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dcenterDdposglob: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
dcenterDdposglob.x,
|
||||
dcenterDdposglob.y,
|
||||
dcenterDdposglob.z);
|
||||
if (calc_grad_pos) {
|
||||
//! dcenterDdposglob.
|
||||
*grad_pos += dimDdcenter * dcenterDdposglob;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdpos: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
dimDdcenter.x * dcenterDdposglob.x,
|
||||
dimDdcenter.y * dcenterDdposglob.y,
|
||||
dimDdcenter.z * dcenterDdposglob.z);
|
||||
}
|
||||
if (calc_grad_cam) {
|
||||
//! Camera.
|
||||
grad_cam->cam_pos -= dimDdcenter * dcenterDdposglob;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdeye: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
-dimDdcenter.x * dcenterDdposglob.x,
|
||||
-dimDdcenter.y * dcenterDdposglob.y,
|
||||
-dimDdcenter.z * dcenterDdposglob.z);
|
||||
// coord_world
|
||||
/*
|
||||
float3 dclosenessDdcoord_world =
|
||||
raydiff * FRCP(FMAX(draw_info.radius * length(raydiff), FEPS));
|
||||
float3 dintersection_depthDdcoord_world = -2.f * raydiff;
|
||||
*/
|
||||
float3 dimDdcoord_world = /* dcoeffDdcoord_world */
|
||||
dcoeffDdcloseness * raydiff *
|
||||
FRCP(FMAX(draw_info.radius * length(raydiff), FEPS)) -
|
||||
dcoeffDdintersection_depth * raydiff / p1__p2_safe;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dcoeffDdcoord_world: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
dimDdcoord_world.x,
|
||||
dimDdcoord_world.y,
|
||||
dimDdcoord_world.z);
|
||||
dimDdcoord_world *= dimDdcoeff;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdcoord_world: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
dimDdcoord_world.x,
|
||||
dimDdcoord_world.y,
|
||||
dimDdcoord_world.z);
|
||||
// The third component of dimDdcoord_world is 0!
|
||||
PASSERT(dimDdcoord_world.z == 0.f);
|
||||
float3 coord_world = center - raydiff;
|
||||
coord_world.z = 0.f;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|coord_world: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
coord_world.x,
|
||||
coord_world.y,
|
||||
coord_world.z);
|
||||
// Do this component-wise to save unnecessary matmul steps.
|
||||
grad_cam->pixel_dir_x += dimDdcoord_world.x * cam.pixel_dir_x *
|
||||
coord_world.x * pixel_size_rcp * pixel_size_rcp;
|
||||
grad_cam->pixel_dir_x += dimDdcoord_world.y * cam.pixel_dir_x *
|
||||
coord_world.y * pixel_size_rcp * pixel_size_rcp;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdpixel_dir_x|coord_world: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
grad_cam->pixel_dir_x.x,
|
||||
grad_cam->pixel_dir_x.y,
|
||||
grad_cam->pixel_dir_x.z);
|
||||
// dcenterkDdpixel_dir_k.
|
||||
float3 center_in_pixels = draw_info.ray_center_norm *
|
||||
draw_info.t_center * pixel_size_rcp;
|
||||
grad_cam->pixel_dir_x += dimDdcenter.x *
|
||||
(center_in_pixels -
|
||||
outer_product_sum(cam.pixel_dir_x) * center_in_pixels *
|
||||
pixel_size_rcp * pixel_size_rcp);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dcenter0dpixel_dir_x: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
(center_in_pixels -
|
||||
outer_product_sum(cam.pixel_dir_x) * center_in_pixels *
|
||||
pixel_size_rcp * pixel_size_rcp)
|
||||
.x,
|
||||
(center_in_pixels -
|
||||
outer_product_sum(cam.pixel_dir_x) * center_in_pixels *
|
||||
pixel_size_rcp * pixel_size_rcp)
|
||||
.y,
|
||||
(center_in_pixels -
|
||||
outer_product_sum(cam.pixel_dir_x) * center_in_pixels *
|
||||
pixel_size_rcp * pixel_size_rcp)
|
||||
.z);
|
||||
grad_cam->pixel_dir_y += dimDdcenter.y *
|
||||
(center_in_pixels -
|
||||
outer_product_sum(cam.pixel_dir_y) * center_in_pixels *
|
||||
pixel_size_rcp * pixel_size_rcp);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dcenter1dpixel_dir_y: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
(center_in_pixels -
|
||||
outer_product_sum(cam.pixel_dir_y) * center_in_pixels *
|
||||
pixel_size_rcp * pixel_size_rcp)
|
||||
.x,
|
||||
(center_in_pixels -
|
||||
outer_product_sum(cam.pixel_dir_y) * center_in_pixels *
|
||||
pixel_size_rcp * pixel_size_rcp)
|
||||
.y,
|
||||
(center_in_pixels -
|
||||
outer_product_sum(cam.pixel_dir_y) * center_in_pixels *
|
||||
pixel_size_rcp * pixel_size_rcp)
|
||||
.z);
|
||||
// dcenterzDdpixel_dir_k.
|
||||
float sensordirz_norm_rcp = FRCP(
|
||||
FMAX(length(cross(cam.pixel_dir_y, cam.pixel_dir_x)), FEPS));
|
||||
grad_cam->pixel_dir_x += dimDdcenter.z *
|
||||
(dot(center, cam.sensor_dir_z) *
|
||||
cross(cam.pixel_dir_y, cam.sensor_dir_z) -
|
||||
cross(cam.pixel_dir_y, center)) *
|
||||
sensordirz_norm_rcp;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dcenterzDdpixel_dir_x: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
((dot(center, cam.sensor_dir_z) *
|
||||
cross(cam.pixel_dir_y, cam.sensor_dir_z) -
|
||||
cross(cam.pixel_dir_y, center)) *
|
||||
sensordirz_norm_rcp)
|
||||
.x,
|
||||
((dot(center, cam.sensor_dir_z) *
|
||||
cross(cam.pixel_dir_y, cam.sensor_dir_z) -
|
||||
cross(cam.pixel_dir_y, center)) *
|
||||
sensordirz_norm_rcp)
|
||||
.y,
|
||||
((dot(center, cam.sensor_dir_z) *
|
||||
cross(cam.pixel_dir_y, cam.sensor_dir_z) -
|
||||
cross(cam.pixel_dir_y, center)) *
|
||||
sensordirz_norm_rcp)
|
||||
.z);
|
||||
grad_cam->pixel_dir_y += dimDdcenter.z *
|
||||
(dot(center, cam.sensor_dir_z) *
|
||||
cross(cam.pixel_dir_x, cam.sensor_dir_z) -
|
||||
cross(cam.pixel_dir_x, center)) *
|
||||
sensordirz_norm_rcp;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dcenterzDdpixel_dir_y: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
((dot(center, cam.sensor_dir_z) *
|
||||
cross(cam.pixel_dir_x, cam.sensor_dir_z) -
|
||||
cross(cam.pixel_dir_x, center)) *
|
||||
sensordirz_norm_rcp)
|
||||
.x,
|
||||
((dot(center, cam.sensor_dir_z) *
|
||||
cross(cam.pixel_dir_x, cam.sensor_dir_z) -
|
||||
cross(cam.pixel_dir_x, center)) *
|
||||
sensordirz_norm_rcp)
|
||||
.y,
|
||||
((dot(center, cam.sensor_dir_z) *
|
||||
cross(cam.pixel_dir_x, cam.sensor_dir_z) -
|
||||
cross(cam.pixel_dir_x, center)) *
|
||||
sensordirz_norm_rcp)
|
||||
.z);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdpixel_dir_x: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
grad_cam->pixel_dir_x.x,
|
||||
grad_cam->pixel_dir_x.y,
|
||||
grad_cam->pixel_dir_x.z);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdpixel_dir_y: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
grad_cam->pixel_dir_y.x,
|
||||
grad_cam->pixel_dir_y.y,
|
||||
grad_cam->pixel_dir_y.z);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (calc_grad_rad) {
|
||||
//! Find dcoeffDdrad.
|
||||
float dcoeffDdrad =
|
||||
dcoeffDdcloseness * (closeness_world / radius_sq) -
|
||||
dcoeffDdintersection_depth * draw_info.radius / p1__p2_safe;
|
||||
PASSERT(isfinite(dcoeffDdrad));
|
||||
*grad_rad += FMUL(dimDdcoeff, dcoeffDdrad);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdrad: %.9f. dcoeffDdrad: %.9f.\n",
|
||||
idx,
|
||||
FMUL(dimDdcoeff, dcoeffDdrad),
|
||||
dcoeffDdrad);
|
||||
}
|
||||
if (calc_grad_pos || calc_grad_cam) {
|
||||
const float3 tmp1 = center - ray_dir_norm * o__p1_;
|
||||
const float3 tmp1n = tmp1 / p1__p2_safe;
|
||||
const float ray_dir_normDotRaydiff = dot(ray_dir_norm, raydiff);
|
||||
const float3 dcoeffDdray = dcoeffDdintersection_depth *
|
||||
(tmp1 - o__p1_ * tmp1n) / *norm_ray_dir +
|
||||
dcoeffDdcloseness *
|
||||
(ray_dir_norm * -ray_dir_normDotRaydiff + raydiff) /
|
||||
(closeness_world * draw_info.radius) *
|
||||
(draw_info.t_center / *norm_ray_dir);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dcoeffDdray: %.9f, %.9f, %.9f. dimDdray: "
|
||||
"%.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
dcoeffDdray.x,
|
||||
dcoeffDdray.y,
|
||||
dcoeffDdray.z,
|
||||
dimDdcoeff * dcoeffDdray.x,
|
||||
dimDdcoeff * dcoeffDdray.y,
|
||||
dimDdcoeff * dcoeffDdray.z);
|
||||
const float3 dcoeffDdcenter =
|
||||
dcoeffDdintersection_depth * (ray_dir_norm + tmp1n) +
|
||||
dcoeffDdcloseness *
|
||||
(draw_info.ray_center_norm * ray_dir_normDotRaydiff -
|
||||
raydiff) /
|
||||
(closeness_world * draw_info.radius);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dcoeffDdcenter: %.9f, %.9f, %.9f. "
|
||||
"dimDdcenter: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
dcoeffDdcenter.x,
|
||||
dcoeffDdcenter.y,
|
||||
dcoeffDdcenter.z,
|
||||
dimDdcoeff * dcoeffDdcenter.x,
|
||||
dimDdcoeff * dcoeffDdcenter.y,
|
||||
dimDdcoeff * dcoeffDdcenter.z);
|
||||
if (calc_grad_pos) {
|
||||
*grad_pos += dimDdcoeff * dcoeffDdcenter;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdposglob: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
dimDdcoeff * dcoeffDdcenter.x,
|
||||
dimDdcoeff * dcoeffDdcenter.y,
|
||||
dimDdcoeff * dcoeffDdcenter.z);
|
||||
}
|
||||
if (calc_grad_cam) {
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdeye: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
-dimDdcoeff * (dcoeffDdcenter.x + dcoeffDdray.x),
|
||||
-dimDdcoeff * (dcoeffDdcenter.y + dcoeffDdray.y),
|
||||
-dimDdcoeff * (dcoeffDdcenter.z + dcoeffDdray.z));
|
||||
grad_cam->cam_pos += -dimDdcoeff * (dcoeffDdcenter + dcoeffDdray);
|
||||
grad_cam->pixel_0_0_center += dimDdcoeff * dcoeffDdray;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdpixel00centerglob: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
dimDdcoeff * dcoeffDdray.x,
|
||||
dimDdcoeff * dcoeffDdray.y,
|
||||
dimDdcoeff * dcoeffDdray.z);
|
||||
grad_cam->pixel_dir_x +=
|
||||
(dimDdcoeff * static_cast<float>(coord_x)) * dcoeffDdray;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdpixel_dir_x: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
(dimDdcoeff * static_cast<float>(coord_x)) * dcoeffDdray.x,
|
||||
(dimDdcoeff * static_cast<float>(coord_x)) * dcoeffDdray.y,
|
||||
(dimDdcoeff * static_cast<float>(coord_x)) * dcoeffDdray.z);
|
||||
grad_cam->pixel_dir_y +=
|
||||
(dimDdcoeff * static_cast<float>(coord_y)) * dcoeffDdray;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_GRAD,
|
||||
"grad %u|dimDdpixel_dir_y: %.9f, %.9f, %.9f.\n",
|
||||
idx,
|
||||
(dimDdcoeff * static_cast<float>(coord_y)) * dcoeffDdray.x,
|
||||
(dimDdcoeff * static_cast<float>(coord_y)) * dcoeffDdray.y,
|
||||
(dimDdcoeff * static_cast<float>(coord_y)) * dcoeffDdray.z);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
55
pytorch3d/csrc/pulsar/include/renderer.fill_bg.device.h
Normal file
55
pytorch3d/csrc/pulsar/include/renderer.fill_bg.device.h
Normal file
@ -0,0 +1,55 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_FILL_BG_DEVICE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_FILL_BG_DEVICE_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "./camera.h"
|
||||
#include "./commands.h"
|
||||
#include "./renderer.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template <bool DEV>
|
||||
GLOBAL void fill_bg(
|
||||
Renderer renderer,
|
||||
const CamInfo cam,
|
||||
float const* const bg_col_d,
|
||||
const float gamma,
|
||||
const uint mode) {
|
||||
GET_PARALLEL_IDS_2D(coord_x, coord_y, cam.film_width, cam.film_height);
|
||||
int write_loc = coord_y * cam.film_width * (3 + 2 * renderer.n_track) +
|
||||
coord_x * (3 + 2 * renderer.n_track);
|
||||
if (renderer.forw_info_d[write_loc + 1] // sm_d
|
||||
== 0.f) {
|
||||
// This location has not been processed yet.
|
||||
// Write first the forw_info:
|
||||
// sm_m
|
||||
renderer.forw_info_d[write_loc] =
|
||||
cam.background_normalization_depth / gamma;
|
||||
// sm_d
|
||||
renderer.forw_info_d[write_loc + 1] = 1.f;
|
||||
// max_closest_possible_intersection_hit
|
||||
renderer.forw_info_d[write_loc + 2] = -1.f;
|
||||
// sphere IDs and intersection depths.
|
||||
for (int i = 0; i < renderer.n_track; ++i) {
|
||||
int sphere_id = -1;
|
||||
IASF(sphere_id, renderer.forw_info_d[write_loc + 3 + i * 2]);
|
||||
renderer.forw_info_d[write_loc + 3 + i * 2 + 1] = -1.f;
|
||||
}
|
||||
if (mode == 0) {
|
||||
// Image background.
|
||||
for (int i = 0; i < cam.n_channels; ++i) {
|
||||
renderer.result_d
|
||||
[coord_y * cam.film_width * cam.n_channels +
|
||||
coord_x * cam.n_channels + i] = bg_col_d[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
END_PARALLEL_2D_NORET();
|
||||
};
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
15
pytorch3d/csrc/pulsar/include/renderer.fill_bg.instantiate.h
Normal file
15
pytorch3d/csrc/pulsar/include/renderer.fill_bg.instantiate.h
Normal file
@ -0,0 +1,15 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "./renderer.fill_bg.device.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template GLOBAL void fill_bg<ISONDEVICE>(
|
||||
Renderer renderer,
|
||||
const CamInfo norm,
|
||||
float const* const bg_col_d,
|
||||
const float gamma,
|
||||
const uint mode);
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
293
pytorch3d/csrc/pulsar/include/renderer.forward.device.h
Normal file
293
pytorch3d/csrc/pulsar/include/renderer.forward.device.h
Normal file
@ -0,0 +1,293 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_FORWARD_DEVICE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_FORWARD_DEVICE_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "./camera.device.h"
|
||||
#include "./commands.h"
|
||||
#include "./math.h"
|
||||
#include "./renderer.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template <bool DEV>
|
||||
void forward(
|
||||
Renderer* self,
|
||||
const float* vert_pos,
|
||||
const float* vert_col,
|
||||
const float* vert_rad,
|
||||
const CamInfo& cam,
|
||||
const float& gamma,
|
||||
float percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const float* bg_col_d,
|
||||
const float* opacity_d,
|
||||
const size_t& num_balls,
|
||||
const uint& mode,
|
||||
cudaStream_t stream) {
|
||||
ARGCHECK(gamma > 0.f && gamma <= 1.f, 6, "gamma must be in [0., 1.]");
|
||||
ARGCHECK(
|
||||
percent_allowed_difference >= 0.f && percent_allowed_difference <= 1.f,
|
||||
7,
|
||||
"percent_allowed_difference must be in [0., 1.]");
|
||||
ARGCHECK(max_n_hits >= 1u, 8, "max_n_hits must be >= 1");
|
||||
ARGCHECK(
|
||||
num_balls > 0 && num_balls <= self->max_num_balls,
|
||||
9,
|
||||
("num_balls must be >0 and <= max num balls! (" +
|
||||
std::to_string(num_balls) + " vs. " +
|
||||
std::to_string(self->max_num_balls) + ")")
|
||||
.c_str());
|
||||
ARGCHECK(
|
||||
cam.film_width == self->cam.film_width &&
|
||||
cam.film_height == self->cam.film_height,
|
||||
5,
|
||||
"cam result width and height must agree");
|
||||
ARGCHECK(mode <= 1, 10, "mode must be <= 1!");
|
||||
if (percent_allowed_difference > 1.f - FEPS) {
|
||||
LOG(WARNING) << "percent_allowed_difference > " << (1.f - FEPS)
|
||||
<< "! Clamping to " << (1.f - FEPS) << ".";
|
||||
percent_allowed_difference = 1.f - FEPS;
|
||||
}
|
||||
LOG_IF(INFO, PULSAR_LOG_RENDER) << "Rendering forward pass...";
|
||||
// Update camera and transform into a new virtual camera system with
|
||||
// centered principal point and subsection rendering.
|
||||
self->cam.eye = cam.eye;
|
||||
self->cam.pixel_0_0_center = cam.pixel_0_0_center - cam.eye;
|
||||
self->cam.pixel_dir_x = cam.pixel_dir_x;
|
||||
self->cam.pixel_dir_y = cam.pixel_dir_y;
|
||||
self->cam.sensor_dir_z = cam.sensor_dir_z;
|
||||
self->cam.half_pixel_size = cam.half_pixel_size;
|
||||
self->cam.focal_length = cam.focal_length;
|
||||
self->cam.aperture_width = cam.aperture_width;
|
||||
self->cam.aperture_height = cam.aperture_height;
|
||||
self->cam.min_dist = cam.min_dist;
|
||||
self->cam.max_dist = cam.max_dist;
|
||||
self->cam.norm_fac = cam.norm_fac;
|
||||
self->cam.principal_point_offset_x = cam.principal_point_offset_x;
|
||||
self->cam.principal_point_offset_y = cam.principal_point_offset_y;
|
||||
self->cam.film_border_left = cam.film_border_left;
|
||||
self->cam.film_border_top = cam.film_border_top;
|
||||
#ifdef PULSAR_TIMINGS_ENABLED
|
||||
START_TIME(calc_signature);
|
||||
#endif
|
||||
LAUNCH_MAX_PARALLEL_1D(
|
||||
calc_signature<DEV>,
|
||||
num_balls,
|
||||
stream,
|
||||
*self,
|
||||
reinterpret_cast<const float3*>(vert_pos),
|
||||
vert_col,
|
||||
vert_rad,
|
||||
num_balls);
|
||||
CHECKLAUNCH();
|
||||
#ifdef PULSAR_TIMINGS_ENABLED
|
||||
STOP_TIME(calc_signature);
|
||||
START_TIME(sort);
|
||||
#endif
|
||||
SORT_ASCENDING_WS(
|
||||
self->min_depth_d,
|
||||
self->min_depth_sorted_d,
|
||||
self->ids_d,
|
||||
self->ids_sorted_d,
|
||||
num_balls,
|
||||
self->workspace_d,
|
||||
self->workspace_size,
|
||||
stream);
|
||||
SORT_ASCENDING_WS(
|
||||
self->min_depth_d,
|
||||
self->min_depth_sorted_d,
|
||||
self->ii_d,
|
||||
self->ii_sorted_d,
|
||||
num_balls,
|
||||
self->workspace_d,
|
||||
self->workspace_size,
|
||||
stream);
|
||||
SORT_ASCENDING_WS(
|
||||
self->min_depth_d,
|
||||
self->min_depth_sorted_d,
|
||||
self->di_d,
|
||||
self->di_sorted_d,
|
||||
num_balls,
|
||||
self->workspace_d,
|
||||
self->workspace_size,
|
||||
stream);
|
||||
CHECKLAUNCH();
|
||||
#ifdef PULSAR_TIMINGS_ENABLED
|
||||
STOP_TIME(sort);
|
||||
START_TIME(minmax);
|
||||
#endif
|
||||
IntersectInfo pixel_minmax;
|
||||
pixel_minmax.min.x = MAX_USHORT;
|
||||
pixel_minmax.min.y = MAX_USHORT;
|
||||
pixel_minmax.max.x = 0;
|
||||
pixel_minmax.max.y = 0;
|
||||
REDUCE_WS(
|
||||
self->ii_sorted_d,
|
||||
self->min_max_pixels_d,
|
||||
num_balls,
|
||||
IntersectInfoMinMax(),
|
||||
pixel_minmax,
|
||||
self->workspace_d,
|
||||
self->workspace_size,
|
||||
stream);
|
||||
COPY_DEV_HOST(&pixel_minmax, self->min_max_pixels_d, IntersectInfo, 1);
|
||||
LOG_IF(INFO, PULSAR_LOG_RENDER)
|
||||
<< "Region with pixels to render: " << pixel_minmax.min.x << ":"
|
||||
<< pixel_minmax.max.x << " (x), " << pixel_minmax.min.y << ":"
|
||||
<< pixel_minmax.max.y << " (y).";
|
||||
#ifdef PULSAR_TIMINGS_ENABLED
|
||||
STOP_TIME(minmax);
|
||||
START_TIME(render);
|
||||
#endif
|
||||
MEMSET(
|
||||
self->result_d,
|
||||
0,
|
||||
float,
|
||||
self->cam.film_width * self->cam.film_height * self->cam.n_channels,
|
||||
stream);
|
||||
MEMSET(
|
||||
self->forw_info_d,
|
||||
0,
|
||||
float,
|
||||
self->cam.film_width * self->cam.film_height * (3 + 2 * self->n_track),
|
||||
stream);
|
||||
if (pixel_minmax.max.y > pixel_minmax.min.y &&
|
||||
pixel_minmax.max.x > pixel_minmax.min.x) {
|
||||
PASSERT(
|
||||
pixel_minmax.min.x >= static_cast<ushort>(self->cam.film_border_left) &&
|
||||
pixel_minmax.min.x <
|
||||
static_cast<ushort>(
|
||||
self->cam.film_border_left + self->cam.film_width) &&
|
||||
pixel_minmax.max.x <=
|
||||
static_cast<ushort>(
|
||||
self->cam.film_border_left + self->cam.film_width) &&
|
||||
pixel_minmax.min.y >= static_cast<ushort>(self->cam.film_border_top) &&
|
||||
pixel_minmax.min.y <
|
||||
static_cast<ushort>(
|
||||
self->cam.film_border_top + self->cam.film_height) &&
|
||||
pixel_minmax.max.y <=
|
||||
static_cast<ushort>(
|
||||
self->cam.film_border_top + self->cam.film_height));
|
||||
// Cut the image in 3x3 regions.
|
||||
int y_step = RENDER_BLOCK_SIZE *
|
||||
iDivCeil(pixel_minmax.max.y - pixel_minmax.min.y,
|
||||
3u * RENDER_BLOCK_SIZE);
|
||||
int x_step = RENDER_BLOCK_SIZE *
|
||||
iDivCeil(pixel_minmax.max.x - pixel_minmax.min.x,
|
||||
3u * RENDER_BLOCK_SIZE);
|
||||
LOG_IF(INFO, PULSAR_LOG_RENDER) << "Using image slices of size " << x_step
|
||||
<< ", " << y_step << " (W, H).";
|
||||
for (int y_min = pixel_minmax.min.y; y_min < pixel_minmax.max.y;
|
||||
y_min += y_step) {
|
||||
for (int x_min = pixel_minmax.min.x; x_min < pixel_minmax.max.x;
|
||||
x_min += x_step) {
|
||||
// Create region selection.
|
||||
LAUNCH_MAX_PARALLEL_1D(
|
||||
create_selector<DEV>,
|
||||
num_balls,
|
||||
stream,
|
||||
self->ii_sorted_d,
|
||||
num_balls,
|
||||
x_min,
|
||||
x_min + x_step,
|
||||
y_min,
|
||||
y_min + y_step,
|
||||
self->region_flags_d);
|
||||
CHECKLAUNCH();
|
||||
SELECT_FLAGS_WS(
|
||||
self->region_flags_d,
|
||||
self->ii_sorted_d,
|
||||
self->ii_d,
|
||||
self->num_selected_d,
|
||||
num_balls,
|
||||
self->workspace_d,
|
||||
self->workspace_size,
|
||||
stream);
|
||||
CHECKLAUNCH();
|
||||
SELECT_FLAGS_WS(
|
||||
self->region_flags_d,
|
||||
self->di_sorted_d,
|
||||
self->di_d,
|
||||
self->num_selected_d,
|
||||
num_balls,
|
||||
self->workspace_d,
|
||||
self->workspace_size,
|
||||
stream);
|
||||
CHECKLAUNCH();
|
||||
SELECT_FLAGS_WS(
|
||||
self->region_flags_d,
|
||||
self->ids_sorted_d,
|
||||
self->ids_d,
|
||||
self->num_selected_d,
|
||||
num_balls,
|
||||
self->workspace_d,
|
||||
self->workspace_size,
|
||||
stream);
|
||||
CHECKLAUNCH();
|
||||
LAUNCH_PARALLEL_2D(
|
||||
render<DEV>,
|
||||
x_step,
|
||||
y_step,
|
||||
RENDER_BLOCK_SIZE,
|
||||
RENDER_BLOCK_SIZE,
|
||||
stream,
|
||||
self->num_selected_d,
|
||||
self->ii_d,
|
||||
self->di_d,
|
||||
self->min_depth_d,
|
||||
self->ids_d,
|
||||
opacity_d,
|
||||
self->cam,
|
||||
gamma,
|
||||
percent_allowed_difference,
|
||||
max_n_hits,
|
||||
bg_col_d,
|
||||
mode,
|
||||
x_min,
|
||||
y_min,
|
||||
x_step,
|
||||
y_step,
|
||||
self->result_d,
|
||||
self->forw_info_d,
|
||||
self->n_track);
|
||||
CHECKLAUNCH();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (mode == 0) {
|
||||
LAUNCH_MAX_PARALLEL_2D(
|
||||
fill_bg<DEV>,
|
||||
static_cast<int64_t>(self->cam.film_width),
|
||||
static_cast<int64_t>(self->cam.film_height),
|
||||
stream,
|
||||
*self,
|
||||
self->cam,
|
||||
bg_col_d,
|
||||
gamma,
|
||||
mode);
|
||||
CHECKLAUNCH();
|
||||
}
|
||||
#ifdef PULSAR_TIMINGS_ENABLED
|
||||
STOP_TIME(render);
|
||||
float time_ms;
|
||||
// This blocks the result and prevents batch-processing from parallelizing.
|
||||
GET_TIME(calc_signature, &time_ms);
|
||||
std::cout << "Time for signature calculation: " << time_ms << " ms"
|
||||
<< std::endl;
|
||||
GET_TIME(sort, &time_ms);
|
||||
std::cout << "Time for sorting: " << time_ms << " ms" << std::endl;
|
||||
GET_TIME(minmax, &time_ms);
|
||||
std::cout << "Time for minmax pixel calculation: " << time_ms << " ms"
|
||||
<< std::endl;
|
||||
GET_TIME(render, &time_ms);
|
||||
std::cout << "Time for rendering: " << time_ms << " ms" << std::endl;
|
||||
#endif
|
||||
LOG_IF(INFO, PULSAR_LOG_RENDER) << "Forward pass complete.";
|
||||
}
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
23
pytorch3d/csrc/pulsar/include/renderer.forward.instantiate.h
Normal file
23
pytorch3d/csrc/pulsar/include/renderer.forward.instantiate.h
Normal file
@ -0,0 +1,23 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "./renderer.forward.device.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template void forward<ISONDEVICE>(
|
||||
Renderer* self,
|
||||
const float* vert_pos,
|
||||
const float* vert_col,
|
||||
const float* vert_rad,
|
||||
const CamInfo& cam,
|
||||
const float& gamma,
|
||||
float percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const float* bg_col_d,
|
||||
const float* opacity_d,
|
||||
const size_t& num_balls,
|
||||
const uint& mode,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
137
pytorch3d/csrc/pulsar/include/renderer.get_screen_area.device.h
Normal file
137
pytorch3d/csrc/pulsar/include/renderer.get_screen_area.device.h
Normal file
@ -0,0 +1,137 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_GET_SCREEN_AREA_DEVICE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_GET_SCREEN_AREA_DEVICE_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "./camera.device.h"
|
||||
#include "./commands.h"
|
||||
#include "./math.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
/**
|
||||
* Find the closest enclosing screen area rectangle in pixels that encloses a
|
||||
* ball.
|
||||
*
|
||||
* The method returns the two x and the two y values of the boundaries. They
|
||||
* are not ordered yet and you need to find min and max for the left/right and
|
||||
* lower/upper boundary.
|
||||
*
|
||||
* The return values are floats and need to be rounded appropriately.
|
||||
*/
|
||||
INLINE DEVICE bool get_screen_area(
|
||||
const float3& ball_center_cam,
|
||||
const float3& ray_center_norm,
|
||||
const float& vert_rad,
|
||||
const CamInfo& cam,
|
||||
const uint& idx,
|
||||
/* Out variables. */
|
||||
float* x_1,
|
||||
float* x_2,
|
||||
float* y_1,
|
||||
float* y_2) {
|
||||
float cos_alpha = dot(cam.sensor_dir_z, ray_center_norm);
|
||||
float2 o__c_, alpha, theta;
|
||||
if (cos_alpha < EPS) {
|
||||
PULSAR_LOG_DEV(
|
||||
PULSAR_LOG_CALC_SIGNATURE,
|
||||
"signature %d|ball not visible. cos_alpha: %.9f.\n",
|
||||
idx,
|
||||
cos_alpha);
|
||||
// No intersection, ball won't be visible.
|
||||
return false;
|
||||
}
|
||||
// Multiply the direction vector with the camera rotation matrix
|
||||
// to have the optical axis being the canonical z vector (0, 0, 1).
|
||||
// TODO: optimize.
|
||||
const float3 ball_center_cam_rot = rotate(
|
||||
ball_center_cam,
|
||||
cam.pixel_dir_x / length(cam.pixel_dir_x),
|
||||
cam.pixel_dir_y / length(cam.pixel_dir_y),
|
||||
cam.sensor_dir_z);
|
||||
PULSAR_LOG_DEV(
|
||||
PULSAR_LOG_CALC_SIGNATURE,
|
||||
"signature %d|ball_center_cam_rot: %f, %f, %f.\n",
|
||||
idx,
|
||||
ball_center_cam.x,
|
||||
ball_center_cam.y,
|
||||
ball_center_cam.z);
|
||||
const float pixel_size_norm_fac = FRCP(2.f * cam.half_pixel_size);
|
||||
const float optical_offset_x =
|
||||
(static_cast<float>(cam.aperture_width) - 1.f) * .5f;
|
||||
const float optical_offset_y =
|
||||
(static_cast<float>(cam.aperture_height) - 1.f) * .5f;
|
||||
if (cam.orthogonal_projection) {
|
||||
*x_1 =
|
||||
FMA(ball_center_cam_rot.x - vert_rad,
|
||||
pixel_size_norm_fac,
|
||||
optical_offset_x);
|
||||
*x_2 =
|
||||
FMA(ball_center_cam_rot.x + vert_rad,
|
||||
pixel_size_norm_fac,
|
||||
optical_offset_x);
|
||||
*y_1 =
|
||||
FMA(ball_center_cam_rot.y - vert_rad,
|
||||
pixel_size_norm_fac,
|
||||
optical_offset_y);
|
||||
*y_2 =
|
||||
FMA(ball_center_cam_rot.y + vert_rad,
|
||||
pixel_size_norm_fac,
|
||||
optical_offset_y);
|
||||
return true;
|
||||
} else {
|
||||
o__c_.x = FMAX(
|
||||
FSQRT(
|
||||
ball_center_cam_rot.x * ball_center_cam_rot.x +
|
||||
ball_center_cam_rot.z * ball_center_cam_rot.z),
|
||||
FEPS);
|
||||
o__c_.y = FMAX(
|
||||
FSQRT(
|
||||
ball_center_cam_rot.y * ball_center_cam_rot.y +
|
||||
ball_center_cam_rot.z * ball_center_cam_rot.z),
|
||||
FEPS);
|
||||
PULSAR_LOG_DEV(
|
||||
PULSAR_LOG_CALC_SIGNATURE,
|
||||
"signature %d|o__c_: %f, %f.\n",
|
||||
idx,
|
||||
o__c_.x,
|
||||
o__c_.y);
|
||||
alpha.x = sign_dir(ball_center_cam_rot.x) *
|
||||
acos(FMIN(FMAX(ball_center_cam_rot.z / o__c_.x, -1.f), 1.f));
|
||||
alpha.y = -sign_dir(ball_center_cam_rot.y) *
|
||||
acos(FMIN(FMAX(ball_center_cam_rot.z / o__c_.y, -1.f), 1.f));
|
||||
theta.x = asin(FMIN(FMAX(vert_rad / o__c_.x, -1.f), 1.f));
|
||||
theta.y = asin(FMIN(FMAX(vert_rad / o__c_.y, -1.f), 1.f));
|
||||
PULSAR_LOG_DEV(
|
||||
PULSAR_LOG_CALC_SIGNATURE,
|
||||
"signature %d|alpha.x: %f, alpha.y: %f, theta.x: %f, theta.y: %f.\n",
|
||||
idx,
|
||||
alpha.x,
|
||||
alpha.y,
|
||||
theta.x,
|
||||
theta.y);
|
||||
*x_1 = tan(alpha.x - theta.x) * cam.focal_length;
|
||||
*x_2 = tan(alpha.x + theta.x) * cam.focal_length;
|
||||
*y_1 = tan(alpha.y - theta.y) * cam.focal_length;
|
||||
*y_2 = tan(alpha.y + theta.y) * cam.focal_length;
|
||||
PULSAR_LOG_DEV(
|
||||
PULSAR_LOG_CALC_SIGNATURE,
|
||||
"signature %d|in sensor plane: x_1: %f, x_2: %f, y_1: %f, y_2: %f.\n",
|
||||
idx,
|
||||
*x_1,
|
||||
*x_2,
|
||||
*y_1,
|
||||
*y_2);
|
||||
*x_1 = FMA(*x_1, pixel_size_norm_fac, optical_offset_x);
|
||||
*x_2 = FMA(*x_2, pixel_size_norm_fac, optical_offset_x);
|
||||
*y_1 = FMA(*y_1, -pixel_size_norm_fac, optical_offset_y);
|
||||
*y_2 = FMA(*y_2, -pixel_size_norm_fac, optical_offset_y);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
461
pytorch3d/csrc/pulsar/include/renderer.h
Normal file
461
pytorch3d/csrc/pulsar/include/renderer.h
Normal file
@ -0,0 +1,461 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_H_
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "../global.h"
|
||||
#include "./camera.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
//! Remember to order struct members from larger size to smaller size
|
||||
//! to avoid padding (for more info, see for example here:
|
||||
//! http://www.catb.org/esr/structure-packing/).
|
||||
|
||||
/**
|
||||
* This is the information that's needed to do a fast screen point
|
||||
* intersection with one of the balls.
|
||||
*
|
||||
* Aim to keep this below 8 bytes (256 bytes per cache-line / 32 threads in a
|
||||
* warp = 8 bytes per thread).
|
||||
*/
|
||||
struct IntersectInfo {
|
||||
ushort2 min; /** minimum x, y in pixel coordinates. */
|
||||
ushort2 max; /** maximum x, y in pixel coordinates. */
|
||||
};
|
||||
static_assert(
|
||||
sizeof(IntersectInfo) == 8,
|
||||
"The compiled size of `IntersectInfo` is wrong.");
|
||||
|
||||
/**
|
||||
* Reduction operation to find the limits of multiple IntersectInfo objects.
|
||||
*/
|
||||
struct IntersectInfoMinMax {
|
||||
IHD IntersectInfo
|
||||
operator()(const IntersectInfo& a, const IntersectInfo& b) const {
|
||||
// Treat the special case of an invalid intersect info object or one for
|
||||
// a ball out of bounds.
|
||||
if (b.max.x == MAX_USHORT && b.min.x == MAX_USHORT &&
|
||||
b.max.y == MAX_USHORT && b.min.y == MAX_USHORT) {
|
||||
return a;
|
||||
}
|
||||
if (a.max.x == MAX_USHORT && a.min.x == MAX_USHORT &&
|
||||
a.max.y == MAX_USHORT && a.min.y == MAX_USHORT) {
|
||||
return b;
|
||||
}
|
||||
IntersectInfo result;
|
||||
result.min.x = std::min<ushort>(a.min.x, b.min.x);
|
||||
result.min.y = std::min<ushort>(a.min.y, b.min.y);
|
||||
result.max.x = std::max<ushort>(a.max.x, b.max.x);
|
||||
result.max.y = std::max<ushort>(a.max.y, b.max.y);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* All information that's needed to draw a ball.
|
||||
*
|
||||
* It's necessary to keep this information in float (not half) format,
|
||||
* because the loss in accuracy would be too high and lead to artifacts.
|
||||
*/
|
||||
struct DrawInfo {
|
||||
float3 ray_center_norm; /** Ray to the ball center, normalized. */
|
||||
/** Ball color.
|
||||
*
|
||||
* This might be the full color in the case of n_channels <= 3. Otherwise,
|
||||
* a pointer to the original 'color' data is stored in the following union.
|
||||
*/
|
||||
float first_color;
|
||||
union {
|
||||
float color[2];
|
||||
float* ptr;
|
||||
} color_union;
|
||||
float t_center; /** Distance from the camera to the ball center. */
|
||||
float radius; /** Ball radius. */
|
||||
};
|
||||
static_assert(
|
||||
sizeof(DrawInfo) == 8 * 4,
|
||||
"The compiled size of `DrawInfo` is wrong.");
|
||||
|
||||
/**
|
||||
* An object to collect all associated data with the renderer.
|
||||
*
|
||||
* The `_d` suffixed pointers point to memory 'on-device', potentially on the
|
||||
* GPU. All other variables are expected to point to CPU memory.
|
||||
*/
|
||||
struct Renderer {
|
||||
/** Dummy initializer to make sure all pointers are set to NULL to
|
||||
* be safe for the device-specific 'construct' and 'destruct' methods.
|
||||
*/
|
||||
inline Renderer() {
|
||||
max_num_balls = 0;
|
||||
result_d = NULL;
|
||||
min_depth_d = NULL;
|
||||
min_depth_sorted_d = NULL;
|
||||
ii_d = NULL;
|
||||
ii_sorted_d = NULL;
|
||||
ids_d = NULL;
|
||||
ids_sorted_d = NULL;
|
||||
workspace_d = NULL;
|
||||
di_d = NULL;
|
||||
di_sorted_d = NULL;
|
||||
region_flags_d = NULL;
|
||||
num_selected_d = NULL;
|
||||
forw_info_d = NULL;
|
||||
grad_pos_d = NULL;
|
||||
grad_col_d = NULL;
|
||||
grad_rad_d = NULL;
|
||||
grad_cam_d = NULL;
|
||||
grad_opy_d = NULL;
|
||||
grad_cam_buf_d = NULL;
|
||||
n_grad_contributions_d = NULL;
|
||||
};
|
||||
/** The camera for this renderer. In world-coordinates. */
|
||||
CamInfo cam;
|
||||
/**
|
||||
* The maximum amount of balls the renderer can handle. Resources are
|
||||
* pre-allocated to account for this size. Less than this amount of balls
|
||||
* can be rendered, but not more.
|
||||
*/
|
||||
int max_num_balls;
|
||||
/** The result buffer. */
|
||||
float* result_d;
|
||||
/** Closest possible intersection depth per sphere w.r.t. the camera. */
|
||||
float* min_depth_d;
|
||||
/** Closest possible intersection depth per sphere, ordered ascending. */
|
||||
float* min_depth_sorted_d;
|
||||
/** The intersect infos per sphere. */
|
||||
IntersectInfo* ii_d;
|
||||
/** The intersect infos per sphere, ordered by their closest possible
|
||||
* intersection depth (asc.). */
|
||||
IntersectInfo* ii_sorted_d;
|
||||
/** Original sphere IDs. */
|
||||
int* ids_d;
|
||||
/** Original sphere IDs, ordered by their closest possible intersection depth
|
||||
* (asc.). */
|
||||
int* ids_sorted_d;
|
||||
/** Workspace for CUB routines. */
|
||||
char* workspace_d;
|
||||
/** Workspace size for CUB routines. */
|
||||
size_t workspace_size;
|
||||
/** The draw information structures for each sphere. */
|
||||
DrawInfo* di_d;
|
||||
/** The draw information structures sorted by closest possible intersection
|
||||
* depth (asc.). */
|
||||
DrawInfo* di_sorted_d;
|
||||
/** Region association buffer. */
|
||||
char* region_flags_d;
|
||||
/** Num spheres in the current region. */
|
||||
size_t* num_selected_d;
|
||||
/** Pointer to information from the forward pass. */
|
||||
float* forw_info_d;
|
||||
/** Struct containing information about the min max pixels that contain
|
||||
* rendered information in the image. */
|
||||
IntersectInfo* min_max_pixels_d;
|
||||
/** Gradients w.r.t. position. */
|
||||
float3* grad_pos_d;
|
||||
/** Gradients w.r.t. color. */
|
||||
float* grad_col_d;
|
||||
/** Gradients w.r.t. radius. */
|
||||
float* grad_rad_d;
|
||||
/** Gradients w.r.t. camera parameters. */
|
||||
float* grad_cam_d;
|
||||
/** Gradients w.r.t. opacity. */
|
||||
float* grad_opy_d;
|
||||
/** Camera gradient information by sphere.
|
||||
*
|
||||
* Here, every sphere's contribution to the camera gradients is stored. It is
|
||||
* aggregated and written to grad_cam_d in a separate step. This avoids write
|
||||
* conflicts when processing the spheres.
|
||||
*/
|
||||
CamGradInfo* grad_cam_buf_d;
|
||||
/** Total of all gradient contributions for this image. */
|
||||
int* n_grad_contributions_d;
|
||||
/** The number of spheres to track for backpropagation. */
|
||||
int n_track;
|
||||
};
|
||||
|
||||
inline bool operator==(const Renderer& a, const Renderer& b) {
|
||||
return a.cam == b.cam && a.max_num_balls == b.max_num_balls;
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct a renderer.
|
||||
*/
|
||||
template <bool DEV>
|
||||
void construct(
|
||||
Renderer* self,
|
||||
const size_t& max_num_balls,
|
||||
const int& width,
|
||||
const int& height,
|
||||
const bool& orthogonal_projection,
|
||||
const bool& right_handed_system,
|
||||
const float& background_normalization_depth,
|
||||
const uint& n_channels,
|
||||
const uint& n_track);
|
||||
|
||||
/**
|
||||
* Destruct the renderer and free the associated memory.
|
||||
*/
|
||||
template <bool DEV>
|
||||
void destruct(Renderer* self);
|
||||
|
||||
/**
|
||||
* Create a selection of points inside a rectangle.
|
||||
*
|
||||
* This write boolen values into `region_flags_d', which can
|
||||
* for example be used by a CUB function to extract the selection.
|
||||
*/
|
||||
template <bool DEV>
|
||||
GLOBAL void create_selector(
|
||||
IntersectInfo const* const RESTRICT ii_sorted_d,
|
||||
const uint num_balls,
|
||||
const int min_x,
|
||||
const int max_x,
|
||||
const int min_y,
|
||||
const int max_y,
|
||||
/* Out variables. */
|
||||
char* RESTRICT region_flags_d);
|
||||
|
||||
/**
|
||||
* Calculate a signature for a ball.
|
||||
*
|
||||
* Populate the `ids_d`, `ii_d`, `di_d` and `min_depth_d` fields of the
|
||||
* renderer. For spheres not visible in the image, sets the id field to -1,
|
||||
* min_depth_d to MAX_FLOAT and the ii_d.min.x fields to MAX_USHORT.
|
||||
*/
|
||||
template <bool DEV>
|
||||
GLOBAL void calc_signature(
|
||||
Renderer renderer,
|
||||
float3 const* const RESTRICT vert_poss,
|
||||
float const* const RESTRICT vert_cols,
|
||||
float const* const RESTRICT vert_rads,
|
||||
const uint num_balls);
|
||||
|
||||
/**
|
||||
* The block size for rendering.
|
||||
*
|
||||
* This should be as large as possible, but is limited due to the amount
|
||||
* of variables we use and the memory required per thread.
|
||||
*/
|
||||
#define RENDER_BLOCK_SIZE 16
|
||||
/**
|
||||
* The buffer size of spheres to be loaded and analyzed for relevance.
|
||||
*
|
||||
* This must be at least RENDER_BLOCK_SIZE * RENDER_BLOCK_SIZE so that
|
||||
* for every iteration through the loading loop every thread could add a
|
||||
* 'hit' to the buffer.
|
||||
*/
|
||||
#define RENDER_BUFFER_SIZE RENDER_BLOCK_SIZE* RENDER_BLOCK_SIZE * 2
|
||||
/**
|
||||
* The threshold after which the spheres that are in the render buffer
|
||||
* are rendered and the buffer is flushed.
|
||||
*
|
||||
* Must be less than RENDER_BUFFER_SIZE.
|
||||
*/
|
||||
#define RENDER_BUFFER_LOAD_THRESH 16 * 4
|
||||
|
||||
/**
|
||||
* The render function.
|
||||
*
|
||||
* Assumptions:
|
||||
* * the focal length is appropriately chosen,
|
||||
* * ray_dir_norm.z is > EPS.
|
||||
* * to be completed...
|
||||
*/
|
||||
template <bool DEV>
|
||||
GLOBAL void render(
|
||||
size_t const* const RESTRICT
|
||||
num_balls, /** Number of balls relevant for this pass. */
|
||||
IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */
|
||||
DrawInfo const* const RESTRICT di_d, /** Draw information. */
|
||||
float const* const RESTRICT min_depth_d, /** Minimum depth per sphere. */
|
||||
int const* const RESTRICT id_d, /** IDs. */
|
||||
float const* const RESTRICT op_d, /** Opacity. */
|
||||
const CamInfo cam_norm, /** Camera normalized with all vectors to be in the
|
||||
* camera coordinate system.
|
||||
*/
|
||||
const float gamma, /** Transparency parameter. **/
|
||||
const float percent_allowed_difference, /** Maximum allowed
|
||||
error in color. */
|
||||
const uint max_n_hits,
|
||||
const float* bg_col_d,
|
||||
const uint mode,
|
||||
const int x_min,
|
||||
const int y_min,
|
||||
const int x_step,
|
||||
const int y_step,
|
||||
// Out variables.
|
||||
float* const RESTRICT result_d, /** The result image. */
|
||||
float* const RESTRICT forw_info_d, /** Additional information needed for the
|
||||
grad computation. */
|
||||
// Infrastructure.
|
||||
const int n_track /** The number of spheres to track. */
|
||||
);
|
||||
|
||||
/**
|
||||
* Makes sure to paint background information.
|
||||
*
|
||||
* This is required as a separate post-processing step because certain
|
||||
* pixels may not be processed during the forward pass if there is no
|
||||
* possibility for a sphere to be present at their location.
|
||||
*/
|
||||
template <bool DEV>
|
||||
GLOBAL void fill_bg(
|
||||
Renderer renderer,
|
||||
const CamInfo norm,
|
||||
float const* const bg_col_d,
|
||||
const float gamma,
|
||||
const uint mode);
|
||||
|
||||
/**
|
||||
* Rendering forward pass.
|
||||
*
|
||||
* Takes a renderer and sphere data as inputs and creates a rendering.
|
||||
*/
|
||||
template <bool DEV>
|
||||
void forward(
|
||||
Renderer* self,
|
||||
const float* vert_pos,
|
||||
const float* vert_col,
|
||||
const float* vert_rad,
|
||||
const CamInfo& cam,
|
||||
const float& gamma,
|
||||
float percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const float* bg_col_d,
|
||||
const float* opacity_d,
|
||||
const size_t& num_balls,
|
||||
const uint& mode,
|
||||
cudaStream_t stream);
|
||||
|
||||
/**
|
||||
* Normalize the camera gradients by the number of spheres that contributed.
|
||||
*/
|
||||
template <bool DEV>
|
||||
GLOBAL void norm_cam_gradients(Renderer renderer);
|
||||
|
||||
/**
|
||||
* Normalize the sphere gradients.
|
||||
*
|
||||
* We're assuming that the samples originate from a Monte Carlo
|
||||
* sampling process and normalize by number and sphere area.
|
||||
*/
|
||||
template <bool DEV>
|
||||
GLOBAL void norm_sphere_gradients(Renderer renderer, const int num_balls);
|
||||
|
||||
#define GRAD_BLOCK_SIZE 16
|
||||
/** Calculate the gradients.
|
||||
*/
|
||||
template <bool DEV>
|
||||
GLOBAL void calc_gradients(
|
||||
const CamInfo cam, /** Camera in world coordinates. */
|
||||
float const* const RESTRICT grad_im, /** The gradient image. */
|
||||
const float
|
||||
gamma, /** The transparency parameter used in the forward pass. */
|
||||
float3 const* const RESTRICT vert_poss, /** Vertex position vector. */
|
||||
float const* const RESTRICT vert_cols, /** Vertex color vector. */
|
||||
float const* const RESTRICT vert_rads, /** Vertex radius vector. */
|
||||
float const* const RESTRICT opacity, /** Vertex opacity. */
|
||||
const uint num_balls, /** Number of balls. */
|
||||
float const* const RESTRICT result_d, /** Result image. */
|
||||
float const* const RESTRICT forw_info_d, /** Forward pass info. */
|
||||
DrawInfo const* const RESTRICT di_d, /** Draw information. */
|
||||
IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */
|
||||
// Mode switches.
|
||||
const bool calc_grad_pos,
|
||||
const bool calc_grad_col,
|
||||
const bool calc_grad_rad,
|
||||
const bool calc_grad_cam,
|
||||
const bool calc_grad_opy,
|
||||
// Out variables.
|
||||
float* const RESTRICT grad_rad_d, /** Radius gradients. */
|
||||
float* const RESTRICT grad_col_d, /** Color gradients. */
|
||||
float3* const RESTRICT grad_pos_d, /** Position gradients. */
|
||||
CamGradInfo* const RESTRICT grad_cam_buf_d, /** Camera gradient buffer. */
|
||||
float* const RESTRICT grad_opy_d, /** Opacity gradient buffer. */
|
||||
int* const RESTRICT
|
||||
grad_contributed_d, /** Gradient contribution counter. */
|
||||
// Infrastructure.
|
||||
const int n_track,
|
||||
const uint offs_x = 0,
|
||||
const uint offs_y = 0);
|
||||
|
||||
/**
|
||||
* A full backward pass.
|
||||
*
|
||||
* Creates the gradients for the given gradient_image and the spheres.
|
||||
*/
|
||||
template <bool DEV>
|
||||
void backward(
|
||||
Renderer* self,
|
||||
const float* grad_im,
|
||||
const float* image,
|
||||
const float* forw_info,
|
||||
const float* vert_pos,
|
||||
const float* vert_col,
|
||||
const float* vert_rad,
|
||||
const CamInfo& cam,
|
||||
const float& gamma,
|
||||
float percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const float* vert_opy,
|
||||
const size_t& num_balls,
|
||||
const uint& mode,
|
||||
const bool& dif_pos,
|
||||
const bool& dif_col,
|
||||
const bool& dif_rad,
|
||||
const bool& dif_cam,
|
||||
const bool& dif_opy,
|
||||
cudaStream_t stream);
|
||||
|
||||
/**
|
||||
* A debug backward pass.
|
||||
*
|
||||
* This is a function to debug the gradient calculation. It calculates the
|
||||
* gradients for exactly one pixel (set with pos_x and pos_y) without averaging.
|
||||
*
|
||||
* *Uses only the first sphere for camera gradient calculation!*
|
||||
*/
|
||||
template <bool DEV>
|
||||
void backward_dbg(
|
||||
Renderer* self,
|
||||
const float* grad_im,
|
||||
const float* image,
|
||||
const float* forw_info,
|
||||
const float* vert_pos,
|
||||
const float* vert_col,
|
||||
const float* vert_rad,
|
||||
const CamInfo& cam,
|
||||
const float& gamma,
|
||||
float percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const float* vert_opy,
|
||||
const size_t& num_balls,
|
||||
const uint& mode,
|
||||
const bool& dif_pos,
|
||||
const bool& dif_col,
|
||||
const bool& dif_rad,
|
||||
const bool& dif_cam,
|
||||
const bool& dif_opy,
|
||||
const uint& pos_x,
|
||||
const uint& pos_y,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <bool DEV>
|
||||
void nn(
|
||||
const float* ref_ptr,
|
||||
const float* tar_ptr,
|
||||
const uint& k,
|
||||
const uint& d,
|
||||
const uint& n,
|
||||
float* dist_ptr,
|
||||
int32_t* inds_ptr,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
@ -0,0 +1,28 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_NORM_CAM_GRADIENTS_DEVICE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_NORM_CAM_GRADIENTS_DEVICE_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "./camera.device.h"
|
||||
#include "./commands.h"
|
||||
#include "./math.h"
|
||||
#include "./renderer.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
/**
|
||||
* Normalize the camera gradients by the number of spheres that contributed.
|
||||
*/
|
||||
template <bool DEV>
|
||||
GLOBAL void norm_cam_gradients(Renderer renderer) {
|
||||
GET_PARALLEL_IDX_1D(idx, 1);
|
||||
CamGradInfo* cgi = reinterpret_cast<CamGradInfo*>(renderer.grad_cam_d);
|
||||
*cgi = *cgi * FRCP(static_cast<float>(*renderer.n_grad_contributions_d));
|
||||
END_PARALLEL_NORET();
|
||||
};
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
@ -0,0 +1,10 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "./renderer.norm_cam_gradients.device.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template GLOBAL void norm_cam_gradients<ISONDEVICE>(Renderer renderer);
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
@ -0,0 +1,68 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_NORM_SPHERE_GRADIENTS_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_NORM_SPHERE_GRADIENTS_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "./commands.h"
|
||||
#include "./math.h"
|
||||
#include "./renderer.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
/**
|
||||
* Normalize the sphere gradients.
|
||||
*
|
||||
* We're assuming that the samples originate from a Monte Carlo
|
||||
* sampling process and normalize by number and sphere area.
|
||||
*/
|
||||
template <bool DEV>
|
||||
GLOBAL void norm_sphere_gradients(Renderer renderer, const int num_balls) {
|
||||
GET_PARALLEL_IDX_1D(idx, num_balls);
|
||||
float norm_fac = 0.f;
|
||||
IntersectInfo ii;
|
||||
if (renderer.ids_sorted_d[idx] > 0) {
|
||||
ii = renderer.ii_d[idx];
|
||||
// Normalize the sphere gradients as averages.
|
||||
// This avoids the case that there are small spheres in a scene with still
|
||||
// un-converged colors whereas the big spheres already converged, just
|
||||
// because their integrated learning rate is 'higher'.
|
||||
norm_fac = FRCP(static_cast<float>(renderer.ids_sorted_d[idx]));
|
||||
}
|
||||
PULSAR_LOG_DEV_NODE(
|
||||
PULSAR_LOG_NORMALIZE,
|
||||
"ids_sorted_d[idx]: %d, norm_fac: %.9f.\n",
|
||||
renderer.ids_sorted_d[idx],
|
||||
norm_fac);
|
||||
renderer.grad_rad_d[idx] *= norm_fac;
|
||||
for (uint c_idx = 0; c_idx < renderer.cam.n_channels; ++c_idx) {
|
||||
renderer.grad_col_d[idx * renderer.cam.n_channels + c_idx] *= norm_fac;
|
||||
}
|
||||
renderer.grad_pos_d[idx] *= norm_fac;
|
||||
renderer.grad_opy_d[idx] *= norm_fac;
|
||||
|
||||
if (renderer.ids_sorted_d[idx] > 0) {
|
||||
// For the camera, we need to be more correct and have the gradients
|
||||
// be proportional to the area they cover in the image.
|
||||
// This leads to a formulation very much like in monte carlo integration:
|
||||
norm_fac = FRCP(static_cast<float>(renderer.ids_sorted_d[idx])) *
|
||||
(static_cast<float>(ii.max.x) - static_cast<float>(ii.min.x)) *
|
||||
(static_cast<float>(ii.max.y) - static_cast<float>(ii.min.y)) *
|
||||
1e-3f; // for better numerics.
|
||||
}
|
||||
renderer.grad_cam_buf_d[idx].cam_pos *= norm_fac;
|
||||
renderer.grad_cam_buf_d[idx].pixel_0_0_center *= norm_fac;
|
||||
renderer.grad_cam_buf_d[idx].pixel_dir_x *= norm_fac;
|
||||
renderer.grad_cam_buf_d[idx].pixel_dir_y *= norm_fac;
|
||||
// The sphere only contributes to the camera gradients if it is
|
||||
// large enough in screen space.
|
||||
if (renderer.ids_sorted_d[idx] > 0 && ii.max.x >= ii.min.x + 3 &&
|
||||
ii.max.y >= ii.min.y + 3)
|
||||
renderer.ids_sorted_d[idx] = 1;
|
||||
END_PARALLEL_NORET();
|
||||
};
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
@ -0,0 +1,12 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "./renderer.norm_sphere_gradients.device.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template GLOBAL void norm_sphere_gradients<ISONDEVICE>(
|
||||
Renderer renderer,
|
||||
const int num_balls);
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
409
pytorch3d/csrc/pulsar/include/renderer.render.device.h
Normal file
409
pytorch3d/csrc/pulsar/include/renderer.render.device.h
Normal file
@ -0,0 +1,409 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_RENDER_DEVICE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_RENDER_DEVICE_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "./camera.device.h"
|
||||
#include "./commands.h"
|
||||
#include "./math.h"
|
||||
#include "./renderer.h"
|
||||
|
||||
#include "./closest_sphere_tracker.device.h"
|
||||
#include "./renderer.draw.device.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
|
||||
template <bool DEV>
|
||||
GLOBAL void render(
|
||||
size_t const* const RESTRICT
|
||||
num_balls, /** Number of balls relevant for this pass. */
|
||||
IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */
|
||||
DrawInfo const* const RESTRICT di_d, /** Draw information. */
|
||||
float const* const RESTRICT min_depth_d, /** Minimum depth per sphere. */
|
||||
int const* const RESTRICT ids_d, /** IDs. */
|
||||
float const* const RESTRICT op_d, /** Opacity. */
|
||||
const CamInfo cam_norm, /** Camera normalized with all vectors to be in the
|
||||
* camera coordinate system.
|
||||
*/
|
||||
const float gamma, /** Transparency parameter. **/
|
||||
const float percent_allowed_difference, /** Maximum allowed
|
||||
error in color. */
|
||||
const uint max_n_hits,
|
||||
const float* bg_col,
|
||||
const uint mode,
|
||||
const int x_min,
|
||||
const int y_min,
|
||||
const int x_step,
|
||||
const int y_step,
|
||||
// Out variables.
|
||||
float* const RESTRICT result_d, /** The result image. */
|
||||
float* const RESTRICT forw_info_d, /** Additional information needed for the
|
||||
grad computation. */
|
||||
const int n_track /** The number of spheres to track for backprop. */
|
||||
) {
|
||||
// Do not early stop threads in this block here. They can all contribute to
|
||||
// the scanning process, we just have to prevent from writing their result.
|
||||
GET_PARALLEL_IDS_2D(offs_x, offs_y, x_step, y_step);
|
||||
// Variable declarations and const initializations.
|
||||
const float ln_pad_over_1minuspad =
|
||||
FLN(percent_allowed_difference / (1.f - percent_allowed_difference));
|
||||
/** A facility to track the closest spheres to the camera
|
||||
(in preparation for gradient calculation). */
|
||||
ClosestSphereTracker tracker(n_track);
|
||||
const uint coord_x = x_min + offs_x; /** Ray coordinate x. */
|
||||
const uint coord_y = y_min + offs_y; /** Ray coordinate y. */
|
||||
float3 ray_dir_norm; /** Ray cast through the pixel, normalized. */
|
||||
float2 projected_ray; /** Ray intersection with the sensor. */
|
||||
if (cam_norm.orthogonal_projection) {
|
||||
ray_dir_norm = cam_norm.sensor_dir_z;
|
||||
projected_ray.x = static_cast<float>(coord_x);
|
||||
projected_ray.y = static_cast<float>(coord_y);
|
||||
} else {
|
||||
ray_dir_norm = normalize(
|
||||
cam_norm.pixel_0_0_center + coord_x * cam_norm.pixel_dir_x +
|
||||
coord_y * cam_norm.pixel_dir_y);
|
||||
// This is a reasonable assumption for normal focal lengths and image sizes.
|
||||
PASSERT(FABS(ray_dir_norm.z) > FEPS);
|
||||
projected_ray.x = ray_dir_norm.x / ray_dir_norm.z * cam_norm.focal_length;
|
||||
projected_ray.y = ray_dir_norm.y / ray_dir_norm.z * cam_norm.focal_length;
|
||||
}
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_RENDER_PIX,
|
||||
"render|ray_dir_norm: %.9f, %.9f, %.9f. projected_ray: %.9f, %.9f.\n",
|
||||
ray_dir_norm.x,
|
||||
ray_dir_norm.y,
|
||||
ray_dir_norm.z,
|
||||
projected_ray.x,
|
||||
projected_ray.y);
|
||||
// Set up shared infrastructure.
|
||||
/** This entire thread block. */
|
||||
cg::thread_block thread_block = cg::this_thread_block();
|
||||
/** The collaborators within a warp. */
|
||||
cg::coalesced_group thread_warp = cg::coalesced_threads();
|
||||
/** The number of loaded balls in the load buffer di_l. */
|
||||
SHARED uint n_loaded;
|
||||
/** Draw information buffer. */
|
||||
SHARED DrawInfo di_l[RENDER_BUFFER_SIZE];
|
||||
/** The original sphere id of each loaded sphere. */
|
||||
SHARED uint sphere_id_l[RENDER_BUFFER_SIZE];
|
||||
/** The number of pixels in this block that are done. */
|
||||
SHARED int n_pixels_done;
|
||||
/** Whether loading of balls is completed. */
|
||||
SHARED bool loading_done;
|
||||
/** The number of balls loaded overall (just for statistics). */
|
||||
SHARED int n_balls_loaded;
|
||||
/** The area this thread block covers. */
|
||||
SHARED IntersectInfo block_area;
|
||||
if (thread_block.thread_rank() == 0) {
|
||||
// Initialize the shared variables.
|
||||
n_loaded = 0;
|
||||
block_area.min.x = static_cast<ushort>(coord_x);
|
||||
block_area.max.x = static_cast<ushort>(IMIN(
|
||||
coord_x + blockDim.x, cam_norm.film_border_left + cam_norm.film_width));
|
||||
block_area.min.y = static_cast<ushort>(coord_y);
|
||||
block_area.max.y = static_cast<ushort>(IMIN(
|
||||
coord_y + blockDim.y, cam_norm.film_border_top + cam_norm.film_height));
|
||||
n_pixels_done = 0;
|
||||
loading_done = false;
|
||||
n_balls_loaded = 0;
|
||||
}
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_RENDER_PIX,
|
||||
"render|block_area.min: %d, %d. block_area.max: %d, %d.\n",
|
||||
block_area.min.x,
|
||||
block_area.min.y,
|
||||
block_area.max.x,
|
||||
block_area.max.y);
|
||||
// Initialization of the pixel with the background color.
|
||||
/**
|
||||
* The result of this very pixel.
|
||||
* the offset calculation might overflow if this thread is out of
|
||||
* bounds of the film. However, in this case result is not
|
||||
* accessed, so this is fine.
|
||||
*/
|
||||
float* result = result_d +
|
||||
(coord_y - cam_norm.film_border_top) * cam_norm.film_width *
|
||||
cam_norm.n_channels +
|
||||
(coord_x - cam_norm.film_border_left) * cam_norm.n_channels;
|
||||
if (coord_x >= cam_norm.film_border_left &&
|
||||
coord_x < cam_norm.film_border_left + cam_norm.film_width &&
|
||||
coord_y >= cam_norm.film_border_top &&
|
||||
coord_y < cam_norm.film_border_top + cam_norm.film_height) {
|
||||
// Initialize the result.
|
||||
if (mode == 0u) {
|
||||
for (uint c_id = 0; c_id < cam_norm.n_channels; ++c_id)
|
||||
result[c_id] = bg_col[c_id];
|
||||
} else {
|
||||
result[0] = 0.f;
|
||||
}
|
||||
}
|
||||
/** Normalization denominator. */
|
||||
float sm_d = 1.f;
|
||||
/** Normalization tracker for stable softmax. The maximum observed value. */
|
||||
float sm_m = cam_norm.background_normalization_depth / gamma;
|
||||
/** Whether this pixel has had all information needed for drawing. */
|
||||
bool done =
|
||||
(coord_x < cam_norm.film_border_left ||
|
||||
coord_x >= cam_norm.film_border_left + cam_norm.film_width ||
|
||||
coord_y < cam_norm.film_border_top ||
|
||||
coord_y >= cam_norm.film_border_top + cam_norm.film_height);
|
||||
/** The depth threshold for a new point to have at least
|
||||
* `percent_allowed_difference` influence on the result color. All points that
|
||||
* are further away than this are ignored.
|
||||
*/
|
||||
float depth_threshold = done ? -1.f : MAX_FLOAT;
|
||||
/** The closest intersection possible of a ball that was hit by this pixel
|
||||
* ray. */
|
||||
float max_closest_possible_intersection_hit = -1.f;
|
||||
bool hit; /** Whether a sphere was hit. */
|
||||
float intersection_depth; /** The intersection_depth for a sphere at this
|
||||
pixel. */
|
||||
float closest_possible_intersection; /** The closest possible intersection
|
||||
for this sphere. */
|
||||
float max_closest_possible_intersection;
|
||||
// Sync up threads so that everyone is similarly initialized.
|
||||
thread_block.sync();
|
||||
//! Coalesced loading and intersection analysis of balls.
|
||||
for (uint ball_idx = thread_block.thread_rank();
|
||||
ball_idx < iDivCeil(static_cast<uint>(*num_balls), thread_block.size()) *
|
||||
thread_block.size() &&
|
||||
!loading_done && n_pixels_done < thread_block.size();
|
||||
ball_idx += thread_block.size()) {
|
||||
if (ball_idx < static_cast<uint>(*num_balls)) { // Account for overflow.
|
||||
const IntersectInfo& ii = ii_d[ball_idx];
|
||||
hit = (ii.min.x <= block_area.max.x) && (ii.max.x > block_area.min.x) &&
|
||||
(ii.min.y <= block_area.max.y) && (ii.max.y > block_area.min.y);
|
||||
if (hit) {
|
||||
uint write_idx = ATOMICADD_B(&n_loaded, 1u);
|
||||
di_l[write_idx] = di_d[ball_idx];
|
||||
sphere_id_l[write_idx] = static_cast<uint>(ids_d[ball_idx]);
|
||||
PULSAR_LOG_DEV_PIXB(
|
||||
PULSAR_LOG_RENDER_PIX,
|
||||
"render|found intersection with sphere %u.\n",
|
||||
sphere_id_l[write_idx]);
|
||||
}
|
||||
if (ii.min.x == MAX_USHORT)
|
||||
// This is an invalid sphere (out of image). These spheres have
|
||||
// maximum depth. Since we ordered the spheres by earliest possible
|
||||
// intersection depth we re certain that there will no other sphere
|
||||
// that is relevant after this one.
|
||||
loading_done = true;
|
||||
}
|
||||
// Reset n_pixels_done.
|
||||
n_pixels_done = 0;
|
||||
thread_block.sync(); // Make sure n_loaded is updated.
|
||||
if (n_loaded > RENDER_BUFFER_LOAD_THRESH) {
|
||||
// The load buffer is full enough. Draw.
|
||||
if (thread_block.thread_rank() == 0)
|
||||
n_balls_loaded += n_loaded;
|
||||
max_closest_possible_intersection = 0.f;
|
||||
// This excludes threads outside of the image boundary. Also, it reduces
|
||||
// block artifacts.
|
||||
if (!done) {
|
||||
for (uint draw_idx = 0; draw_idx < n_loaded; ++draw_idx) {
|
||||
intersection_depth = 0.f;
|
||||
if (cam_norm.orthogonal_projection) {
|
||||
// The closest possible intersection is the distance to the camera
|
||||
// plane.
|
||||
closest_possible_intersection = min_depth_d[sphere_id_l[draw_idx]];
|
||||
} else {
|
||||
closest_possible_intersection =
|
||||
di_l[draw_idx].t_center - di_l[draw_idx].radius;
|
||||
}
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_RENDER_PIX,
|
||||
"render|drawing sphere %u (depth: %f, "
|
||||
"closest possible intersection: %f).\n",
|
||||
sphere_id_l[draw_idx],
|
||||
di_l[draw_idx].t_center,
|
||||
closest_possible_intersection);
|
||||
hit = draw(
|
||||
di_l[draw_idx], // Sphere to draw.
|
||||
op_d == NULL ? 1.f : op_d[sphere_id_l[draw_idx]], // Opacity.
|
||||
cam_norm, // Cam.
|
||||
gamma, // Gamma.
|
||||
ray_dir_norm, // Ray direction.
|
||||
projected_ray, // Ray intersection with the image.
|
||||
// Mode switches.
|
||||
true, // Draw.
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false, // No gradients.
|
||||
// Position info.
|
||||
coord_x,
|
||||
coord_y,
|
||||
sphere_id_l[draw_idx],
|
||||
// Optional in variables.
|
||||
NULL, // intersect information.
|
||||
NULL, // ray_dir.
|
||||
NULL, // norm_ray_dir.
|
||||
NULL, // grad_pix.
|
||||
&ln_pad_over_1minuspad,
|
||||
// in/out variables
|
||||
&sm_d,
|
||||
&sm_m,
|
||||
result,
|
||||
// Optional out.
|
||||
&depth_threshold,
|
||||
&intersection_depth,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL // gradients.
|
||||
);
|
||||
if (hit) {
|
||||
max_closest_possible_intersection_hit = FMAX(
|
||||
max_closest_possible_intersection_hit,
|
||||
closest_possible_intersection);
|
||||
tracker.track(
|
||||
sphere_id_l[draw_idx], intersection_depth, coord_x, coord_y);
|
||||
}
|
||||
max_closest_possible_intersection = FMAX(
|
||||
max_closest_possible_intersection, closest_possible_intersection);
|
||||
}
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_RENDER_PIX,
|
||||
"render|max_closest_possible_intersection: %f, "
|
||||
"depth_threshold: %f.\n",
|
||||
max_closest_possible_intersection,
|
||||
depth_threshold);
|
||||
}
|
||||
done = done ||
|
||||
(percent_allowed_difference > 0.f &&
|
||||
max_closest_possible_intersection > depth_threshold) ||
|
||||
tracker.get_n_hits() >= max_n_hits;
|
||||
uint warp_done = thread_warp.ballot(done);
|
||||
if (thread_warp.thread_rank() == 0)
|
||||
ATOMICADD_B(&n_pixels_done, POPC(warp_done));
|
||||
// This sync is necessary to keep n_loaded until all threads are done with
|
||||
// painting.
|
||||
thread_block.sync();
|
||||
n_loaded = 0;
|
||||
}
|
||||
thread_block.sync();
|
||||
}
|
||||
if (thread_block.thread_rank() == 0)
|
||||
n_balls_loaded += n_loaded;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_RENDER_PIX,
|
||||
"render|loaded %d balls in total.\n",
|
||||
n_balls_loaded);
|
||||
if (!done) {
|
||||
for (uint draw_idx = 0; draw_idx < n_loaded; ++draw_idx) {
|
||||
intersection_depth = 0.f;
|
||||
if (cam_norm.orthogonal_projection) {
|
||||
// The closest possible intersection is the distance to the camera
|
||||
// plane.
|
||||
closest_possible_intersection = min_depth_d[sphere_id_l[draw_idx]];
|
||||
} else {
|
||||
closest_possible_intersection =
|
||||
di_l[draw_idx].t_center - di_l[draw_idx].radius;
|
||||
}
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_RENDER_PIX,
|
||||
"render|drawing sphere %u (depth: %f, "
|
||||
"closest possible intersection: %f).\n",
|
||||
sphere_id_l[draw_idx],
|
||||
di_l[draw_idx].t_center,
|
||||
closest_possible_intersection);
|
||||
hit = draw(
|
||||
di_l[draw_idx], // Sphere to draw.
|
||||
op_d == NULL ? 1.f : op_d[sphere_id_l[draw_idx]], // Opacity.
|
||||
cam_norm, // Cam.
|
||||
gamma, // Gamma.
|
||||
ray_dir_norm, // Ray direction.
|
||||
projected_ray, // Ray intersection with the image.
|
||||
// Mode switches.
|
||||
true, // Draw.
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false, // No gradients.
|
||||
// Logging info.
|
||||
coord_x,
|
||||
coord_y,
|
||||
sphere_id_l[draw_idx],
|
||||
// Optional in variables.
|
||||
NULL, // intersect information.
|
||||
NULL, // ray_dir.
|
||||
NULL, // norm_ray_dir.
|
||||
NULL, // grad_pix.
|
||||
&ln_pad_over_1minuspad,
|
||||
// in/out variables
|
||||
&sm_d,
|
||||
&sm_m,
|
||||
result,
|
||||
// Optional out.
|
||||
&depth_threshold,
|
||||
&intersection_depth,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL,
|
||||
NULL // gradients.
|
||||
);
|
||||
if (hit) {
|
||||
max_closest_possible_intersection_hit = FMAX(
|
||||
max_closest_possible_intersection_hit,
|
||||
closest_possible_intersection);
|
||||
tracker.track(
|
||||
sphere_id_l[draw_idx], intersection_depth, coord_x, coord_y);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (coord_x < cam_norm.film_border_left ||
|
||||
coord_y < cam_norm.film_border_top ||
|
||||
coord_x >= cam_norm.film_border_left + cam_norm.film_width ||
|
||||
coord_y >= cam_norm.film_border_top + cam_norm.film_height) {
|
||||
RETURN_PARALLEL();
|
||||
}
|
||||
if (mode == 1u) {
|
||||
// The subtractions, for example coord_y - cam_norm.film_border_left, are
|
||||
// safe even though both components are uints. We checked their relation
|
||||
// just above.
|
||||
result_d
|
||||
[(coord_y - cam_norm.film_border_top) * cam_norm.film_width *
|
||||
cam_norm.n_channels +
|
||||
(coord_x - cam_norm.film_border_left) * cam_norm.n_channels] =
|
||||
static_cast<float>(tracker.get_n_hits());
|
||||
} else {
|
||||
float sm_d_normfac = FRCP(FMAX(sm_d, FEPS));
|
||||
for (uint c_id = 0; c_id < cam_norm.n_channels; ++c_id)
|
||||
result[c_id] *= sm_d_normfac;
|
||||
int write_loc = (coord_y - cam_norm.film_border_top) * cam_norm.film_width *
|
||||
(3 + 2 * n_track) +
|
||||
(coord_x - cam_norm.film_border_left) * (3 + 2 * n_track);
|
||||
forw_info_d[write_loc] = sm_m;
|
||||
forw_info_d[write_loc + 1] = sm_d;
|
||||
forw_info_d[write_loc + 2] = max_closest_possible_intersection_hit;
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_RENDER_PIX,
|
||||
"render|writing the %d most important ball infos.\n",
|
||||
IMIN(n_track, tracker.get_n_hits()));
|
||||
for (int i = 0; i < n_track; ++i) {
|
||||
int sphere_id = tracker.get_closest_sphere_id(i);
|
||||
IASF(sphere_id, forw_info_d[write_loc + 3 + i * 2]);
|
||||
forw_info_d[write_loc + 3 + i * 2 + 1] =
|
||||
tracker.get_closest_sphere_depth(i) == MAX_FLOAT
|
||||
? -1.f
|
||||
: tracker.get_closest_sphere_depth(i);
|
||||
PULSAR_LOG_DEV_PIX(
|
||||
PULSAR_LOG_RENDER_PIX,
|
||||
"render|writing %d most important: id: %d, normalized depth: %f.\n",
|
||||
i,
|
||||
tracker.get_closest_sphere_id(i),
|
||||
tracker.get_closest_sphere_depth(i));
|
||||
}
|
||||
}
|
||||
END_PARALLEL_2D();
|
||||
}
|
||||
|
||||
} // namespace Renderer
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
39
pytorch3d/csrc/pulsar/include/renderer.render.instantiate.h
Normal file
39
pytorch3d/csrc/pulsar/include/renderer.render.instantiate.h
Normal file
@ -0,0 +1,39 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_RENDER_INSTANTIATE_H_
|
||||
#define PULSAR_NATIVE_INCLUDE_RENDERER_RENDER_INSTANTIATE_H_
|
||||
|
||||
#include "./renderer.render.device.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace Renderer {
|
||||
template GLOBAL void render<ISONDEVICE>(
|
||||
size_t const* const RESTRICT
|
||||
num_balls, /** Number of balls relevant for this pass. */
|
||||
IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */
|
||||
DrawInfo const* const RESTRICT di_d, /** Draw information. */
|
||||
float const* const RESTRICT min_depth_d, /** Minimum depth per sphere. */
|
||||
int const* const RESTRICT id_d, /** IDs. */
|
||||
float const* const RESTRICT op_d, /** Opacity. */
|
||||
const CamInfo cam_norm, /** Camera normalized with all vectors to be in the
|
||||
* camera coordinate system.
|
||||
*/
|
||||
const float gamma, /** Transparency parameter. **/
|
||||
const float percent_allowed_difference, /** Maximum allowed
|
||||
error in color. */
|
||||
const uint max_n_hits,
|
||||
const float* bg_col_d,
|
||||
const uint mode,
|
||||
const int x_min,
|
||||
const int y_min,
|
||||
const int x_step,
|
||||
const int y_step,
|
||||
// Out variables.
|
||||
float* const RESTRICT result_d, /** The result image. */
|
||||
float* const RESTRICT forw_info_d, /** Additional information needed for the
|
||||
grad computation. */
|
||||
const int n_track /** The number of spheres to track for backprop. */
|
||||
);
|
||||
}
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
108
pytorch3d/csrc/pulsar/logging.h
Normal file
108
pytorch3d/csrc/pulsar/logging.h
Normal file
@ -0,0 +1,108 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_LOGGING_H_
|
||||
#define PULSAR_LOGGING_H_
|
||||
|
||||
// #define PULSAR_LOGGING_ENABLED
|
||||
/**
|
||||
* Enable detailed per-operation timings.
|
||||
*
|
||||
* This timing scheme is not appropriate to measure batched calculations.
|
||||
* Use `PULSAR_TIMINGS_BATCHED_ENABLED` for that.
|
||||
*/
|
||||
// #define PULSAR_TIMINGS_ENABLED
|
||||
/**
|
||||
* Time batched operations.
|
||||
*/
|
||||
// #define PULSAR_TIMINGS_BATCHED_ENABLED
|
||||
#if defined(PULSAR_TIMINGS_BATCHED_ENABLED) && defined(PULSAR_TIMINGS_ENABLED)
|
||||
#pragma message("Pulsar|batched and unbatched timings enabled. This will not")
|
||||
#pragma message("Pulsar|create meaningful results.")
|
||||
#endif
|
||||
|
||||
#ifdef PULSAR_LOGGING_ENABLED
|
||||
|
||||
// Control logging.
|
||||
// 0: INFO, 1: WARNING, 2: ERROR, 3: FATAL (Abort after logging).
|
||||
#define CAFFE2_LOG_THRESHOLD 0
|
||||
#define PULSAR_LOG_INIT false
|
||||
#define PULSAR_LOG_FORWARD false
|
||||
#define PULSAR_LOG_CALC_SIGNATURE false
|
||||
#define PULSAR_LOG_RENDER false
|
||||
#define PULSAR_LOG_RENDER_PIX false
|
||||
#define PULSAR_LOG_RENDER_PIX_X 428
|
||||
#define PULSAR_LOG_RENDER_PIX_Y 669
|
||||
#define PULSAR_LOG_RENDER_PIX_ALL false
|
||||
#define PULSAR_LOG_TRACKER_PIX false
|
||||
#define PULSAR_LOG_TRACKER_PIX_X 428
|
||||
#define PULSAR_LOG_TRACKER_PIX_Y 669
|
||||
#define PULSAR_LOG_TRACKER_PIX_ALL false
|
||||
#define PULSAR_LOG_DRAW_PIX false
|
||||
#define PULSAR_LOG_DRAW_PIX_X 428
|
||||
#define PULSAR_LOG_DRAW_PIX_Y 669
|
||||
#define PULSAR_LOG_DRAW_PIX_ALL false
|
||||
#define PULSAR_LOG_BACKWARD false
|
||||
#define PULSAR_LOG_GRAD false
|
||||
#define PULSAR_LOG_GRAD_X 509
|
||||
#define PULSAR_LOG_GRAD_Y 489
|
||||
#define PULSAR_LOG_GRAD_ALL false
|
||||
#define PULSAR_LOG_NORMALIZE false
|
||||
#define PULSAR_LOG_NORMALIZE_X 0
|
||||
#define PULSAR_LOG_NORMALIZE_ALL false
|
||||
|
||||
#define PULSAR_LOG_DEV(ID, ...) \
|
||||
if ((ID)) { \
|
||||
printf(__VA_ARGS__); \
|
||||
}
|
||||
#define PULSAR_LOG_DEV_APIX(ID, MSG, ...) \
|
||||
if ((ID) && (film_coord_x == (ID##_X) && film_coord_y == (ID##_Y)) || \
|
||||
ID##_ALL) { \
|
||||
printf( \
|
||||
"%u %u (ap %u %u)|" MSG, \
|
||||
film_coord_x, \
|
||||
film_coord_y, \
|
||||
ap_coord_x, \
|
||||
ap_coord_y, \
|
||||
__VA_ARGS__); \
|
||||
}
|
||||
#define PULSAR_LOG_DEV_PIX(ID, MSG, ...) \
|
||||
if ((ID) && (coord_x == (ID##_X) && coord_y == (ID##_Y)) || ID##_ALL) { \
|
||||
printf("%u %u|" MSG, coord_x, coord_y, __VA_ARGS__); \
|
||||
}
|
||||
#ifdef __CUDACC__
|
||||
#define PULSAR_LOG_DEV_PIXB(ID, MSG, ...) \
|
||||
if ((ID) && static_cast<int>(block_area.min.x) <= (ID##_X) && \
|
||||
static_cast<int>(block_area.max.x) > (ID##_X) && \
|
||||
static_cast<int>(block_area.min.y) <= (ID##_Y) && \
|
||||
static_cast<int>(block_area.max.y) > (ID##_Y)) { \
|
||||
printf("%u %u|" MSG, coord_x, coord_y, __VA_ARGS__); \
|
||||
}
|
||||
#else
|
||||
#define PULSAR_LOG_DEV_PIXB(ID, MSG, ...) \
|
||||
if ((ID) && coord_x == (ID##_X) && coord_y == (ID##_Y)) { \
|
||||
printf("%u %u|" MSG, coord_x, coord_y, __VA_ARGS__); \
|
||||
}
|
||||
#endif
|
||||
#define PULSAR_LOG_DEV_NODE(ID, MSG, ...) \
|
||||
if ((ID) && idx == (ID##_X) || (ID##_ALL)) { \
|
||||
printf("%u|" MSG, idx, __VA_ARGS__); \
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#define CAFFE2_LOG_THRESHOLD 2
|
||||
|
||||
#define PULSAR_LOG_RENDER false
|
||||
#define PULSAR_LOG_INIT false
|
||||
#define PULSAR_LOG_FORWARD false
|
||||
#define PULSAR_LOG_BACKWARD false
|
||||
#define PULSAR_LOG_TRACKER_PIX false
|
||||
|
||||
#define PULSAR_LOG_DEV(...)
|
||||
#define PULSAR_LOG_DEV_APIX(...)
|
||||
#define PULSAR_LOG_DEV_PIX(...)
|
||||
#define PULSAR_LOG_DEV_PIXB(...)
|
||||
#define PULSAR_LOG_DEV_NODE(...)
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
63
pytorch3d/csrc/pulsar/pytorch/camera.cpp
Normal file
63
pytorch3d/csrc/pulsar/pytorch/camera.cpp
Normal file
@ -0,0 +1,63 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "./camera.h"
|
||||
#include "../include/math.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace pytorch {
|
||||
|
||||
CamInfo cam_info_from_params(
|
||||
const torch::Tensor& cam_pos,
|
||||
const torch::Tensor& pixel_0_0_center,
|
||||
const torch::Tensor& pixel_vec_x,
|
||||
const torch::Tensor& pixel_vec_y,
|
||||
const torch::Tensor& principal_point_offset,
|
||||
const float& focal_length,
|
||||
const uint& width,
|
||||
const uint& height,
|
||||
const float& min_dist,
|
||||
const float& max_dist,
|
||||
const bool& right_handed) {
|
||||
CamInfo res;
|
||||
fill_cam_vecs(
|
||||
cam_pos.detach().cpu(),
|
||||
pixel_0_0_center.detach().cpu(),
|
||||
pixel_vec_x.detach().cpu(),
|
||||
pixel_vec_y.detach().cpu(),
|
||||
principal_point_offset.detach().cpu(),
|
||||
right_handed,
|
||||
&res);
|
||||
res.half_pixel_size = 0.5f * length(res.pixel_dir_x);
|
||||
if (length(res.pixel_dir_y) * 0.5f - res.half_pixel_size > EPS) {
|
||||
throw std::runtime_error("Pixel sizes must agree in x and y direction!");
|
||||
}
|
||||
res.focal_length = focal_length;
|
||||
res.aperture_width =
|
||||
width + 2u * static_cast<uint>(abs(res.principal_point_offset_x));
|
||||
res.aperture_height =
|
||||
height + 2u * static_cast<uint>(abs(res.principal_point_offset_y));
|
||||
res.pixel_0_0_center -=
|
||||
res.pixel_dir_x * static_cast<float>(abs(res.principal_point_offset_x));
|
||||
res.pixel_0_0_center -=
|
||||
res.pixel_dir_y * static_cast<float>(abs(res.principal_point_offset_y));
|
||||
res.film_width = width;
|
||||
res.film_height = height;
|
||||
res.film_border_left =
|
||||
static_cast<uint>(std::max(0, 2 * res.principal_point_offset_x));
|
||||
res.film_border_top =
|
||||
static_cast<uint>(std::max(0, 2 * res.principal_point_offset_y));
|
||||
LOG_IF(INFO, PULSAR_LOG_INIT)
|
||||
<< "Aperture width, height: " << res.aperture_width << ", "
|
||||
<< res.aperture_height;
|
||||
LOG_IF(INFO, PULSAR_LOG_INIT)
|
||||
<< "Film width, height: " << res.film_width << ", " << res.film_height;
|
||||
LOG_IF(INFO, PULSAR_LOG_INIT)
|
||||
<< "Film border left, top: " << res.film_border_left << ", "
|
||||
<< res.film_border_top;
|
||||
res.min_dist = min_dist;
|
||||
res.max_dist = max_dist;
|
||||
res.norm_fac = 1.f / (max_dist - min_dist);
|
||||
return res;
|
||||
};
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace pulsar
|
61
pytorch3d/csrc/pulsar/pytorch/camera.h
Normal file
61
pytorch3d/csrc/pulsar/pytorch/camera.h
Normal file
@ -0,0 +1,61 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_CAMERA_H_
|
||||
#define PULSAR_NATIVE_CAMERA_H_
|
||||
|
||||
#include <tuple>
|
||||
#include "../global.h"
|
||||
|
||||
#include "../include/camera.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace pytorch {
|
||||
|
||||
inline void fill_cam_vecs(
|
||||
const torch::Tensor& pos_vec,
|
||||
const torch::Tensor& pixel_0_0_center,
|
||||
const torch::Tensor& pixel_dir_x,
|
||||
const torch::Tensor& pixel_dir_y,
|
||||
const torch::Tensor& principal_point_offset,
|
||||
const bool& right_handed,
|
||||
CamInfo* res) {
|
||||
res->eye.x = pos_vec.data_ptr<float>()[0];
|
||||
res->eye.y = pos_vec.data_ptr<float>()[1];
|
||||
res->eye.z = pos_vec.data_ptr<float>()[2];
|
||||
res->pixel_0_0_center.x = pixel_0_0_center.data_ptr<float>()[0];
|
||||
res->pixel_0_0_center.y = pixel_0_0_center.data_ptr<float>()[1];
|
||||
res->pixel_0_0_center.z = pixel_0_0_center.data_ptr<float>()[2];
|
||||
res->pixel_dir_x.x = pixel_dir_x.data_ptr<float>()[0];
|
||||
res->pixel_dir_x.y = pixel_dir_x.data_ptr<float>()[1];
|
||||
res->pixel_dir_x.z = pixel_dir_x.data_ptr<float>()[2];
|
||||
res->pixel_dir_y.x = pixel_dir_y.data_ptr<float>()[0];
|
||||
res->pixel_dir_y.y = pixel_dir_y.data_ptr<float>()[1];
|
||||
res->pixel_dir_y.z = pixel_dir_y.data_ptr<float>()[2];
|
||||
auto sensor_dir_z = pixel_dir_y.cross(pixel_dir_x);
|
||||
sensor_dir_z /= sensor_dir_z.norm();
|
||||
if (right_handed) {
|
||||
sensor_dir_z *= -1.f;
|
||||
}
|
||||
res->sensor_dir_z.x = sensor_dir_z.data_ptr<float>()[0];
|
||||
res->sensor_dir_z.y = sensor_dir_z.data_ptr<float>()[1];
|
||||
res->sensor_dir_z.z = sensor_dir_z.data_ptr<float>()[2];
|
||||
res->principal_point_offset_x = principal_point_offset.data_ptr<int32_t>()[0];
|
||||
res->principal_point_offset_y = principal_point_offset.data_ptr<int32_t>()[1];
|
||||
}
|
||||
|
||||
CamInfo cam_info_from_params(
|
||||
const torch::Tensor& cam_pos,
|
||||
const torch::Tensor& pixel_0_0_center,
|
||||
const torch::Tensor& pixel_vec_x,
|
||||
const torch::Tensor& pixel_vec_y,
|
||||
const torch::Tensor& principal_point_offset,
|
||||
const float& focal_length,
|
||||
const uint& width,
|
||||
const uint& height,
|
||||
const float& min_dist,
|
||||
const float& max_dist,
|
||||
const bool& right_handed);
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
1481
pytorch3d/csrc/pulsar/pytorch/renderer.cpp
Normal file
1481
pytorch3d/csrc/pulsar/pytorch/renderer.cpp
Normal file
File diff suppressed because it is too large
Load Diff
167
pytorch3d/csrc/pulsar/pytorch/renderer.h
Normal file
167
pytorch3d/csrc/pulsar/pytorch/renderer.h
Normal file
@ -0,0 +1,167 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_PYTORCH_RENDERER_H_
|
||||
#define PULSAR_NATIVE_PYTORCH_RENDERER_H_
|
||||
|
||||
#include "../global.h"
|
||||
#include "../include/renderer.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace pytorch {
|
||||
|
||||
struct Renderer {
|
||||
public:
|
||||
/**
|
||||
* Pytorch Pulsar differentiable rendering module.
|
||||
*/
|
||||
explicit Renderer(
|
||||
const unsigned int& width,
|
||||
const unsigned int& height,
|
||||
const uint& max_n_balls,
|
||||
const bool& orthogonal_projection,
|
||||
const bool& right_handed_system,
|
||||
const float& background_normalization_depth,
|
||||
const uint& n_channels,
|
||||
const uint& n_track);
|
||||
~Renderer();
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor> forward(
|
||||
const torch::Tensor& vert_pos,
|
||||
const torch::Tensor& vert_col,
|
||||
const torch::Tensor& vert_radii,
|
||||
const torch::Tensor& cam_pos,
|
||||
const torch::Tensor& pixel_0_0_center,
|
||||
const torch::Tensor& pixel_vec_x,
|
||||
const torch::Tensor& pixel_vec_y,
|
||||
const torch::Tensor& focal_length,
|
||||
const torch::Tensor& principal_point_offsets,
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode);
|
||||
|
||||
std::tuple<
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>,
|
||||
at::optional<torch::Tensor>>
|
||||
backward(
|
||||
const torch::Tensor& grad_im,
|
||||
const torch::Tensor& image,
|
||||
const torch::Tensor& forw_info,
|
||||
const torch::Tensor& vert_pos,
|
||||
const torch::Tensor& vert_col,
|
||||
const torch::Tensor& vert_radii,
|
||||
const torch::Tensor& cam_pos,
|
||||
const torch::Tensor& pixel_0_0_center,
|
||||
const torch::Tensor& pixel_vec_x,
|
||||
const torch::Tensor& pixel_vec_y,
|
||||
const torch::Tensor& focal_length,
|
||||
const torch::Tensor& principal_point_offsets,
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode,
|
||||
const bool& dif_pos,
|
||||
const bool& dif_col,
|
||||
const bool& dif_rad,
|
||||
const bool& dif_cam,
|
||||
const bool& dif_opy,
|
||||
const at::optional<std::pair<uint, uint>>& dbg_pos);
|
||||
|
||||
// Infrastructure.
|
||||
/**
|
||||
* Ensure that the renderer is placed on this device.
|
||||
* Is nearly a no-op if the device is correct.
|
||||
*/
|
||||
void ensure_on_device(torch::Device device, bool non_blocking = false);
|
||||
|
||||
/**
|
||||
* Ensure that at least n renderers are available.
|
||||
*/
|
||||
void ensure_n_renderers_gte(const size_t& batch_size);
|
||||
|
||||
/**
|
||||
* Check the parameters.
|
||||
*/
|
||||
std::tuple<size_t, size_t, bool, torch::Tensor> arg_check(
|
||||
const torch::Tensor& vert_pos,
|
||||
const torch::Tensor& vert_col,
|
||||
const torch::Tensor& vert_radii,
|
||||
const torch::Tensor& cam_pos,
|
||||
const torch::Tensor& pixel_0_0_center,
|
||||
const torch::Tensor& pixel_vec_x,
|
||||
const torch::Tensor& pixel_vec_y,
|
||||
const torch::Tensor& focal_length,
|
||||
const torch::Tensor& principal_point_offsets,
|
||||
const float& gamma,
|
||||
const float& max_depth,
|
||||
float& min_depth,
|
||||
const c10::optional<torch::Tensor>& bg_col,
|
||||
const c10::optional<torch::Tensor>& opacity,
|
||||
const float& percent_allowed_difference,
|
||||
const uint& max_n_hits,
|
||||
const uint& mode);
|
||||
|
||||
bool operator==(const Renderer& rhs) const;
|
||||
inline friend std::ostream& operator<<(
|
||||
std::ostream& stream,
|
||||
const Renderer& self) {
|
||||
stream << "pulsar::Renderer[";
|
||||
// Device info.
|
||||
stream << self.device_type;
|
||||
if (self.device_index != -1)
|
||||
stream << ", ID " << self.device_index;
|
||||
stream << "]";
|
||||
return stream;
|
||||
}
|
||||
|
||||
inline uint width() const {
|
||||
return this->renderer_vec[0].cam.film_width;
|
||||
}
|
||||
inline uint height() const {
|
||||
return this->renderer_vec[0].cam.film_height;
|
||||
}
|
||||
inline int max_num_balls() const {
|
||||
return this->renderer_vec[0].max_num_balls;
|
||||
}
|
||||
inline bool orthogonal() const {
|
||||
return this->renderer_vec[0].cam.orthogonal_projection;
|
||||
}
|
||||
inline bool right_handed() const {
|
||||
return this->renderer_vec[0].cam.right_handed;
|
||||
}
|
||||
inline uint n_track() const {
|
||||
return static_cast<uint>(this->renderer_vec[0].n_track);
|
||||
}
|
||||
|
||||
/** A tensor that is registered as a buffer with this Module to track its
|
||||
* device placement. Unfortunately, pytorch doesn't offer tracking Module
|
||||
* device placement in a better way as of now.
|
||||
*/
|
||||
torch::Tensor device_tracker;
|
||||
|
||||
protected:
|
||||
/** The device type for this renderer. */
|
||||
c10::DeviceType device_type;
|
||||
/** The device index for this renderer. */
|
||||
c10::DeviceIndex device_index;
|
||||
/** Pointer to the underlying pulsar renderers. */
|
||||
std::vector<pulsar::Renderer::Renderer> renderer_vec;
|
||||
};
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
48
pytorch3d/csrc/pulsar/pytorch/tensor_util.cpp
Normal file
48
pytorch3d/csrc/pulsar/pytorch/tensor_util.cpp
Normal file
@ -0,0 +1,48 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "./tensor_util.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace pytorch {
|
||||
|
||||
torch::Tensor sphere_ids_from_result_info_nograd(
|
||||
const torch::Tensor& forw_info) {
|
||||
torch::Tensor result = torch::zeros(
|
||||
{forw_info.size(0),
|
||||
forw_info.size(1),
|
||||
forw_info.size(2),
|
||||
(forw_info.size(3) - 3) / 2},
|
||||
torch::TensorOptions().device(forw_info.device()).dtype(torch::kInt32));
|
||||
// Get the relevant slice, contiguous.
|
||||
torch::Tensor tmp =
|
||||
forw_info
|
||||
.slice(
|
||||
/*dim=*/3, /*start=*/3, /*end=*/forw_info.size(3), /*step=*/2)
|
||||
.contiguous();
|
||||
if (forw_info.device().type() == c10::DeviceType::CUDA) {
|
||||
cudaMemcpyAsync(
|
||||
result.data_ptr(),
|
||||
tmp.data_ptr(),
|
||||
sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
|
||||
tmp.size(3),
|
||||
cudaMemcpyDeviceToDevice,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
} else {
|
||||
memcpy(
|
||||
result.data_ptr(),
|
||||
tmp.data_ptr(),
|
||||
sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
|
||||
tmp.size(3));
|
||||
}
|
||||
// `tmp` is freed after this, the memory might get reallocated. However,
|
||||
// only kernels in the same stream should ever be able to write to this
|
||||
// memory, which are executed only after the memcpy is complete. That's
|
||||
// why we can just continue.
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace pulsar
|
16
pytorch3d/csrc/pulsar/pytorch/tensor_util.h
Normal file
16
pytorch3d/csrc/pulsar/pytorch/tensor_util.h
Normal file
@ -0,0 +1,16 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_PYTORCH_TENSOR_UTIL_H_
|
||||
#define PULSAR_NATIVE_PYTORCH_TENSOR_UTIL_H_
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
namespace pulsar {
|
||||
namespace pytorch {
|
||||
|
||||
torch::Tensor sphere_ids_from_result_info_nograd(
|
||||
const torch::Tensor& forw_info);
|
||||
|
||||
}
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
24
pytorch3d/csrc/pulsar/pytorch/util.cpp
Normal file
24
pytorch3d/csrc/pulsar/pytorch/util.cpp
Normal file
@ -0,0 +1,24 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace pulsar {
|
||||
namespace pytorch {
|
||||
|
||||
void cudaDevToDev(
|
||||
void* trg,
|
||||
const void* src,
|
||||
const int& size,
|
||||
const cudaStream_t& stream) {
|
||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream);
|
||||
}
|
||||
|
||||
void cudaDevToHost(
|
||||
void* trg,
|
||||
const void* src,
|
||||
const int& size,
|
||||
const cudaStream_t& stream) {
|
||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream);
|
||||
}
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace pulsar
|
59
pytorch3d/csrc/pulsar/pytorch/util.h
Normal file
59
pytorch3d/csrc/pulsar/pytorch/util.h
Normal file
@ -0,0 +1,59 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#ifndef PULSAR_NATIVE_PYTORCH_UTIL_H_
|
||||
#define PULSAR_NATIVE_PYTORCH_UTIL_H_
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include "../global.h"
|
||||
|
||||
namespace pulsar {
|
||||
namespace pytorch {
|
||||
|
||||
void cudaDevToDev(
|
||||
void* trg,
|
||||
const void* src,
|
||||
const int& size,
|
||||
const cudaStream_t& stream);
|
||||
void cudaDevToHost(
|
||||
void* trg,
|
||||
const void* src,
|
||||
const int& size,
|
||||
const cudaStream_t& stream);
|
||||
|
||||
/**
|
||||
* This method takes a memory pointer and wraps it into a pytorch tensor.
|
||||
*
|
||||
* This is preferred over `torch::from_blob`, since that requires a CUDA
|
||||
* managed pointer. However, working with these for high performance
|
||||
* operations is slower. Most of the rendering operations should stay
|
||||
* local to the respective GPU anyways, so unmanaged pointers are
|
||||
* preferred.
|
||||
*/
|
||||
template <typename T>
|
||||
torch::Tensor from_blob(
|
||||
const T* ptr,
|
||||
const torch::IntArrayRef& shape,
|
||||
const c10::DeviceType& device_type,
|
||||
const c10::DeviceIndex& device_index,
|
||||
const torch::Dtype& dtype,
|
||||
const cudaStream_t& stream) {
|
||||
torch::Tensor ret = torch::zeros(
|
||||
shape, torch::device({device_type, device_index}).dtype(dtype));
|
||||
const int num_elements =
|
||||
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>{});
|
||||
if (device_type == c10::DeviceType::CUDA) {
|
||||
cudaDevToDev(
|
||||
ret.data_ptr(),
|
||||
static_cast<const void*>(ptr),
|
||||
sizeof(T) * num_elements,
|
||||
stream);
|
||||
// TODO: check for synchronization.
|
||||
} else {
|
||||
memcpy(ret.data_ptr(), ptr, sizeof(T) * num_elements);
|
||||
}
|
||||
return ret;
|
||||
};
|
||||
|
||||
} // namespace pytorch
|
||||
} // namespace pulsar
|
||||
|
||||
#endif
|
14
pytorch3d/csrc/pulsar/warnings.cpp
Normal file
14
pytorch3d/csrc/pulsar/warnings.cpp
Normal file
@ -0,0 +1,14 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#include "./global.h"
|
||||
#include "./logging.h"
|
||||
|
||||
/**
|
||||
* A compilation unit to provide warnings about the code and avoid
|
||||
* repeated messages.
|
||||
*/
|
||||
#ifdef PULSAR_ASSERTIONS
|
||||
#pragma message("WARNING: assertions are enabled in Pulsar.")
|
||||
#endif
|
||||
#ifdef PULSAR_LOGGING_ENABLED
|
||||
#pragma message("WARNING: logging is enabled in Pulsar.")
|
||||
#endif
|
2
pytorch3d/renderer/points/pulsar/__init__.py
Normal file
2
pytorch3d/renderer/points/pulsar/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
from .renderer import Renderer # noqa: F401
|
692
pytorch3d/renderer/points/pulsar/renderer.py
Normal file
692
pytorch3d/renderer/points/pulsar/renderer.py
Normal file
@ -0,0 +1,692 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
"""pulsar renderer PyTorch integration.
|
||||
|
||||
Proper Python support for pytorch requires creating a torch.autograd.function
|
||||
(independent of whether this is being done within the C++ module). This is done
|
||||
here and a torch.nn.Module is exposed for the use in more complex models.
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
# pyre-fixme[21]: Could not find a name `_C` defined in module `pytorch3d`.
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.transforms import axis_angle_to_matrix, rotation_6d_to_matrix
|
||||
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
GAMMA_WARNING_EMITTED = False
|
||||
AXANGLE_WARNING_EMITTED = False
|
||||
|
||||
|
||||
class _Render(torch.autograd.Function):
|
||||
"""
|
||||
Differentiable rendering function for the Pulsar renderer.
|
||||
|
||||
Usually this will be used through the `Renderer` module, which takes care of
|
||||
setting up the buffers and putting them on the correct device. If you use
|
||||
the function directly, you will have to do this manually.
|
||||
|
||||
The steps for this are two-fold: first, you need to create a native Renderer
|
||||
object to provide the required buffers. This is the `native_renderer` parameter
|
||||
for this function. You can create it by creating a `pytorch3d._C.PulsarRenderer`
|
||||
object (with parameters for width, height and maximum number of balls it should
|
||||
be able to render). This object by default resides on the CPU. If you want to
|
||||
shift the buffers to a different device, just assign an empty tensor on the target
|
||||
device to its property `device_tracker`.
|
||||
|
||||
To convert camera parameters from a more convenient representation to the
|
||||
required vectors as in this function, you can use the static
|
||||
function `pytorch3d.renderer.points.pulsar.Renderer._transform_cam_params`.
|
||||
|
||||
Args:
|
||||
* ctx: Pytorch context.
|
||||
* vert_pos: vertex positions. [Bx]Nx3 tensor of positions in 3D space.
|
||||
* vert_col: vertex colors. [Bx]NxK tensor of channels.
|
||||
* vert_rad: vertex radii. [Bx]N tensor of radiuses, >0.
|
||||
* cam_pos: camera position(s). [Bx]3 tensor in 3D coordinates.
|
||||
* pixel_0_0_center: [Bx]3 tensor center(s) of the upper left pixel(s) in
|
||||
world coordinates.
|
||||
* pixel_vec_x: [Bx]3 tensor from one pixel center to the next in image x
|
||||
direction in world coordinates.
|
||||
* pixel_vec_y: [Bx]3 tensor from one pixel center to the next in image y
|
||||
direction in world coordinates.
|
||||
* focal_length: [Bx]1 tensor of focal lengths in world coordinates.
|
||||
* principal_point_offsets: [Bx]2 tensor of principal point offsets in pixels.
|
||||
* gamma: sphere transparency in [1.,1E-5], with 1 being mostly transparent.
|
||||
[Bx]1.
|
||||
* max_depth: maximum depth for spheres to render. Set this as tighly
|
||||
as possible to have good numerical accuracy for gradients.
|
||||
* native_renderer: a `pytorch3d._C.PulsarRenderer` object.
|
||||
* min_depth: a float with the minimum depth a sphere must have to be renderer.
|
||||
Must be 0. or > max(focal_length).
|
||||
* bg_col: K tensor with a background color to use or None (uses all ones).
|
||||
* opacity: [Bx]N tensor of opacity values in [0., 1.] or None (uses all ones).
|
||||
* percent_allowed_difference: a float in [0., 1.[ with the maximum allowed
|
||||
difference in color space. This is used to speed up the
|
||||
computation. Default: 0.01.
|
||||
* max_n_hits: a hard limit on the number of hits per ray. Default: max int.
|
||||
* mode: render mode in {0, 1}. 0: render an image; 1: render the hit map.
|
||||
* return_forward_info: whether to return a second map. This second map contains
|
||||
13 channels: first channel contains sm_m (the maximum exponent factor
|
||||
observed), the second sm_d (the normalization denominator, the sum of all
|
||||
coefficients), the third the maximum closest possible intersection for a
|
||||
hit. The following channels alternate with the float encoded integer index
|
||||
of a sphere and its weight. They are the five spheres with the highest
|
||||
color contribution to this pixel color, ordered descending.
|
||||
|
||||
Returns:
|
||||
* image: [Bx]HxWxK float tensor with the resulting image.
|
||||
* forw_info: [Bx]HxWx13 float forward information as described above,
|
||||
if enabled.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
vert_pos,
|
||||
vert_col,
|
||||
vert_rad,
|
||||
cam_pos,
|
||||
pixel_0_0_center,
|
||||
pixel_vec_x,
|
||||
pixel_vec_y,
|
||||
focal_length,
|
||||
principal_point_offsets,
|
||||
gamma,
|
||||
max_depth,
|
||||
native_renderer,
|
||||
min_depth=0.0,
|
||||
bg_col=None,
|
||||
opacity=None,
|
||||
percent_allowed_difference=0.01,
|
||||
max_n_hits=_C.MAX_UINT,
|
||||
mode=0,
|
||||
return_forward_info=False,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if mode != 0:
|
||||
assert not return_forward_info, (
|
||||
"You are using a non-standard rendering mode. This does "
|
||||
"not provide gradients, and also no `forward_info`. Please "
|
||||
"set `return_forward_info` to `False`."
|
||||
)
|
||||
ctx.gamma = gamma
|
||||
ctx.max_depth = max_depth
|
||||
ctx.min_depth = min_depth
|
||||
ctx.percent_allowed_difference = percent_allowed_difference
|
||||
ctx.max_n_hits = max_n_hits
|
||||
ctx.mode = mode
|
||||
ctx.native_renderer = native_renderer
|
||||
image, info = ctx.native_renderer.forward(
|
||||
vert_pos,
|
||||
vert_col,
|
||||
vert_rad,
|
||||
cam_pos,
|
||||
pixel_0_0_center,
|
||||
pixel_vec_x,
|
||||
pixel_vec_y,
|
||||
focal_length,
|
||||
principal_point_offsets,
|
||||
gamma,
|
||||
max_depth,
|
||||
min_depth,
|
||||
bg_col,
|
||||
opacity,
|
||||
percent_allowed_difference,
|
||||
max_n_hits,
|
||||
mode,
|
||||
)
|
||||
if mode != 0:
|
||||
# Backprop not possible!
|
||||
info = None
|
||||
# Prepare for backprop.
|
||||
ctx.save_for_backward(
|
||||
vert_pos,
|
||||
vert_col,
|
||||
vert_rad,
|
||||
cam_pos,
|
||||
pixel_0_0_center,
|
||||
pixel_vec_x,
|
||||
pixel_vec_y,
|
||||
focal_length,
|
||||
principal_point_offsets,
|
||||
bg_col,
|
||||
opacity,
|
||||
image,
|
||||
info,
|
||||
)
|
||||
if return_forward_info:
|
||||
return image, info
|
||||
else:
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_im, *args):
|
||||
global GAMMA_WARNING_EMITTED
|
||||
(
|
||||
vert_pos,
|
||||
vert_col,
|
||||
vert_rad,
|
||||
cam_pos,
|
||||
pixel_0_0_center,
|
||||
pixel_vec_x,
|
||||
pixel_vec_y,
|
||||
focal_length,
|
||||
principal_point_offsets,
|
||||
bg_col,
|
||||
opacity,
|
||||
image,
|
||||
info,
|
||||
) = ctx.saved_tensors
|
||||
if (
|
||||
(
|
||||
ctx.needs_input_grad[0]
|
||||
or ctx.needs_input_grad[2]
|
||||
or ctx.needs_input_grad[3]
|
||||
or ctx.needs_input_grad[4]
|
||||
or ctx.needs_input_grad[5]
|
||||
or ctx.needs_input_grad[6]
|
||||
or ctx.needs_input_grad[7]
|
||||
)
|
||||
and ctx.gamma < 1e-3
|
||||
and not GAMMA_WARNING_EMITTED
|
||||
):
|
||||
warnings.warn(
|
||||
"Optimizing for non-color parameters and having a gamma value < 1E-3! "
|
||||
"This is probably not going to produce usable gradients."
|
||||
)
|
||||
GAMMA_WARNING_EMITTED = True
|
||||
if ctx.mode == 0:
|
||||
(
|
||||
grad_pos,
|
||||
grad_col,
|
||||
grad_rad,
|
||||
grad_cam_pos,
|
||||
grad_pixel_0_0_center,
|
||||
grad_pixel_vec_x,
|
||||
grad_pixel_vec_y,
|
||||
grad_opacity,
|
||||
) = ctx.native_renderer.backward(
|
||||
grad_im,
|
||||
image,
|
||||
info,
|
||||
vert_pos,
|
||||
vert_col,
|
||||
vert_rad,
|
||||
cam_pos,
|
||||
pixel_0_0_center,
|
||||
pixel_vec_x,
|
||||
pixel_vec_y,
|
||||
focal_length,
|
||||
principal_point_offsets,
|
||||
ctx.gamma,
|
||||
ctx.max_depth,
|
||||
ctx.min_depth,
|
||||
bg_col,
|
||||
opacity,
|
||||
ctx.percent_allowed_difference,
|
||||
ctx.max_n_hits,
|
||||
ctx.mode,
|
||||
ctx.needs_input_grad[0],
|
||||
ctx.needs_input_grad[1],
|
||||
ctx.needs_input_grad[2],
|
||||
ctx.needs_input_grad[3]
|
||||
or ctx.needs_input_grad[4]
|
||||
or ctx.needs_input_grad[5]
|
||||
or ctx.needs_input_grad[6],
|
||||
ctx.needs_input_grad[13],
|
||||
None, # No debug information provided.
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Performing a backward pass for a "
|
||||
"rendering with `mode != 0`! This is not possible."
|
||||
)
|
||||
return (
|
||||
grad_pos,
|
||||
grad_col,
|
||||
grad_rad,
|
||||
grad_cam_pos,
|
||||
grad_pixel_0_0_center,
|
||||
grad_pixel_vec_x,
|
||||
grad_pixel_vec_y,
|
||||
None, # focal_length
|
||||
None, # principal_point_offsets
|
||||
None, # gamma
|
||||
None, # max_depth
|
||||
None, # native_renderer
|
||||
None, # min_depth
|
||||
None, # bg_col
|
||||
grad_opacity,
|
||||
None, # percent_allowed_difference
|
||||
None, # max_n_hits
|
||||
None, # mode
|
||||
None, # return_forward_info
|
||||
)
|
||||
|
||||
|
||||
class Renderer(torch.nn.Module):
|
||||
"""
|
||||
Differentiable rendering module for the Pulsar renderer.
|
||||
|
||||
Set the maximum number of balls to a reasonable value. It is used to determine
|
||||
several buffer sizes. It is no problem to render less balls than this number,
|
||||
but never more.
|
||||
|
||||
When optimizing for sphere positions, sphere radiuses or camera parameters you
|
||||
have to use higher gamma values (closer to one) and larger sphere sizes: spheres
|
||||
can only 'move' to areas that they cover, and only with higher gamma values exists
|
||||
a gradient w.r.t. their color depending on their position.
|
||||
|
||||
Args:
|
||||
* width: result image width in pixels.
|
||||
* height: result image height in pixels.
|
||||
* max_num_balls: the maximum number of balls this renderer will handle.
|
||||
* orthogonal_projection: use an orthogonal instead of perspective projection.
|
||||
Default: False.
|
||||
* right_handed_system: use a right-handed instead of a left-handed coordinate
|
||||
system. This is relevant for compatibility with other drawing or scanning
|
||||
systems. Pulsar by default assumes a left-handed world and camera coordinate
|
||||
system as known from mathematics with x-axis to the right, y axis up and z
|
||||
axis for increasing depth along the optical axis. In the image coordinate
|
||||
system, only the y axis is pointing down, leading still to a left-handed
|
||||
system. If you set this to True, it is assuming a right-handed world and
|
||||
camera coordinate system with x axis to the right, y axis to the top and
|
||||
z axis decreasing along the optical axis. Again, the image coordinate
|
||||
system has a flipped y axis, remaining a right-handed system.
|
||||
Default: False.
|
||||
* background_normalized_depth: the normalized depth the background is placed
|
||||
at.
|
||||
This is on a scale from 0. to 1. between the specified min and max depth
|
||||
(see the forward function). The value 0. is the most furthest depth whereas
|
||||
1. is the closest. Be careful when setting the background too far front - it
|
||||
may hide elements in your scene. Default: EPS.
|
||||
* n_channels: the number of image content channels to use. This is usually three
|
||||
for regular color representations, but can be a higher or lower number.
|
||||
Default: 3.
|
||||
* n_track: the number of spheres to track for gradient calculation per pixel.
|
||||
Only the closest n_track spheres will receive gradients. Default: 5.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
width: int,
|
||||
height: int,
|
||||
max_num_balls: int,
|
||||
orthogonal_projection: bool = False,
|
||||
right_handed_system: bool = False,
|
||||
background_normalized_depth: float = _C.EPS,
|
||||
n_channels: int = 3,
|
||||
n_track: int = 5,
|
||||
):
|
||||
super(Renderer, self).__init__()
|
||||
# pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
|
||||
self._renderer = _C.PulsarRenderer(
|
||||
width,
|
||||
height,
|
||||
max_num_balls,
|
||||
orthogonal_projection,
|
||||
right_handed_system,
|
||||
background_normalized_depth,
|
||||
n_channels,
|
||||
n_track,
|
||||
)
|
||||
self.register_buffer("device_tracker", torch.zeros(1))
|
||||
|
||||
@staticmethod
|
||||
def sphere_ids_from_result_info_nograd(result_info: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get the sphere IDs from a result info tensor.
|
||||
"""
|
||||
if result_info.ndim == 3:
|
||||
return Renderer.sphere_ids_from_result_info_nograd(result_info[None, ...])
|
||||
# pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
|
||||
return _C.pulsar_sphere_ids_from_result_info_nograd(result_info)
|
||||
|
||||
@staticmethod
|
||||
def depth_map_from_result_info_nograd(result_info: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Get the depth map from a result info tensor.
|
||||
|
||||
This returns a map of the same size as the image with just one channel
|
||||
containing the closest intersection value at that position. Gradients
|
||||
are not available for this tensor, but do note that you can use
|
||||
`sphere_ids_from_result_info_nograd` to get the IDs of the spheres at
|
||||
each position and directly create a loss on their depth if required.
|
||||
|
||||
The depth map contains -1. at positions where no intersection has
|
||||
been detected.
|
||||
"""
|
||||
return result_info[..., 4]
|
||||
|
||||
@staticmethod
|
||||
def _transform_cam_params(
|
||||
cam_params: torch.Tensor,
|
||||
width: int,
|
||||
height: int,
|
||||
orthogonal: bool,
|
||||
right_handed: bool,
|
||||
) -> Tuple[
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
]:
|
||||
"""
|
||||
Transform 8 component camera parameter vector(s) to the internal camera
|
||||
representation.
|
||||
|
||||
The input vectors consists of:
|
||||
* 3 components for camera position,
|
||||
* 3 components for camera rotation (three rotation angles) or
|
||||
6 components as described in "On the Continuity of Rotation
|
||||
Representations in Neural Networks" (Zhou et al.),
|
||||
* focal length,
|
||||
* the sensor width in world coordinates,
|
||||
* [optional] the principal point offset in x and y.
|
||||
|
||||
The sensor height is inferred by pixel size and sensor width to obtain
|
||||
quadratic pixels.
|
||||
|
||||
Args:
|
||||
* cam_params: [Bx]{8, 10, 11, 13}, input tensors as described above.
|
||||
* width: number of pixels in x direction.
|
||||
* height: number of pixels in y direction.
|
||||
* orthogonal: bool, whether an orthogonal projection is used
|
||||
(does not use focal length).
|
||||
* right_handed: bool, whether to use a right handed system
|
||||
(negative z in camera direction).
|
||||
|
||||
Returns:
|
||||
* pos_vec: the position vector in 3D,
|
||||
* pixel_0_0_center: the center of the upper left pixel in world coordinates,
|
||||
* pixel_vec_x: the step to move one pixel on the image x axis
|
||||
in world coordinates,
|
||||
* pixel_vec_y: the step to move one pixel on the image y axis
|
||||
in world coordinates,
|
||||
* focal_length: the focal lengths,
|
||||
* principal_point_offsets: the principal point offsets in x, y.
|
||||
"""
|
||||
global AXANGLE_WARNING_EMITTED
|
||||
# Set up all direction vectors, i.e., the sensor direction of all axes.
|
||||
assert width > 0
|
||||
assert height > 0
|
||||
batch_processing = True
|
||||
if cam_params.ndimension() == 1:
|
||||
batch_processing = False
|
||||
cam_params = cam_params[None, :]
|
||||
batch_size = cam_params.size(0)
|
||||
continuous_rep = True
|
||||
if cam_params.shape[1] in [8, 10]:
|
||||
if cam_params.requires_grad and not AXANGLE_WARNING_EMITTED:
|
||||
warnings.warn(
|
||||
"Using an axis angle representation for camera rotations. "
|
||||
"This has discontinuities and should not be used for optimization. "
|
||||
"Alternatively, use a six-component representation as described in "
|
||||
"'On the Continuity of Rotation Representations in Neural Networks'"
|
||||
" (Zhou et al.). "
|
||||
"The `pytorch3d.transforms` module provides "
|
||||
"facilities for using this representation."
|
||||
)
|
||||
AXANGLE_WARNING_EMITTED = True
|
||||
continuous_rep = False
|
||||
else:
|
||||
assert cam_params.shape[1] in [11, 13]
|
||||
pos_vec: torch.Tensor = cam_params[:, :3]
|
||||
principal_point_offsets: torch.Tensor = torch.zeros(
|
||||
(cam_params.shape[0], 2), dtype=torch.int32, device=cam_params.device
|
||||
)
|
||||
if continuous_rep:
|
||||
rot_vec = cam_params[:, 3:9]
|
||||
focal_length: torch.Tensor = cam_params[:, 9:10]
|
||||
sensor_size_x = cam_params[:, 10:11]
|
||||
if cam_params.shape[1] == 13:
|
||||
principal_point_offsets: torch.Tensor = cam_params[:, 11:13].to(
|
||||
torch.int32
|
||||
)
|
||||
else:
|
||||
rot_vec = cam_params[:, 3:6]
|
||||
focal_length: torch.Tensor = cam_params[:, 6:7]
|
||||
sensor_size_x = cam_params[:, 7:8]
|
||||
if cam_params.shape[1] == 10:
|
||||
principal_point_offsets: torch.Tensor = cam_params[:, 8:10].to(
|
||||
torch.int32
|
||||
)
|
||||
# Always get quadratic pixels.
|
||||
pixel_size_x = sensor_size_x / float(width)
|
||||
sensor_size_y = height * pixel_size_x
|
||||
LOGGER.debug(
|
||||
"Camera position: %s, rotation: %s. Focal length: %s.",
|
||||
str(pos_vec),
|
||||
str(rot_vec),
|
||||
str(focal_length),
|
||||
)
|
||||
if continuous_rep:
|
||||
rot_mat = rotation_6d_to_matrix(rot_vec)
|
||||
else:
|
||||
rot_mat = axis_angle_to_matrix(rot_vec)
|
||||
sensor_dir_x = torch.matmul(
|
||||
rot_mat,
|
||||
torch.tensor(
|
||||
[1.0, 0.0, 0.0], dtype=torch.float32, device=rot_mat.device
|
||||
).repeat(batch_size, 1)[:, :, None],
|
||||
)[:, :, 0]
|
||||
sensor_dir_y = torch.matmul(
|
||||
rot_mat,
|
||||
torch.tensor(
|
||||
[0.0, -1.0, 0.0], dtype=torch.float32, device=rot_mat.device
|
||||
).repeat(batch_size, 1)[:, :, None],
|
||||
)[:, :, 0]
|
||||
sensor_dir_z = torch.matmul(
|
||||
rot_mat,
|
||||
torch.tensor(
|
||||
[0.0, 0.0, 1.0], dtype=torch.float32, device=rot_mat.device
|
||||
).repeat(batch_size, 1)[:, :, None],
|
||||
)[:, :, 0]
|
||||
if right_handed:
|
||||
sensor_dir_z *= -1
|
||||
LOGGER.debug(
|
||||
"Sensor direction vectors: %s, %s, %s.",
|
||||
str(sensor_dir_x),
|
||||
str(sensor_dir_y),
|
||||
str(sensor_dir_z),
|
||||
)
|
||||
if orthogonal:
|
||||
sensor_center = pos_vec
|
||||
else:
|
||||
sensor_center = pos_vec + focal_length * sensor_dir_z
|
||||
LOGGER.debug("Sensor center: %s.", str(sensor_center))
|
||||
sensor_luc = ( # Sensor left upper corner.
|
||||
sensor_center
|
||||
- sensor_dir_x * (sensor_size_x / 2.0)
|
||||
- sensor_dir_y * (sensor_size_y / 2.0)
|
||||
)
|
||||
LOGGER.debug("Sensor luc: %s.", str(sensor_luc))
|
||||
pixel_size_x = sensor_size_x / float(width)
|
||||
pixel_size_y = sensor_size_y / float(height)
|
||||
LOGGER.debug(
|
||||
"Pixel sizes (x): %s, (y) %s.", str(pixel_size_x), str(pixel_size_y)
|
||||
)
|
||||
pixel_vec_x: torch.Tensor = sensor_dir_x * pixel_size_x
|
||||
pixel_vec_y: torch.Tensor = sensor_dir_y * pixel_size_y
|
||||
pixel_0_0_center = sensor_luc + 0.5 * pixel_vec_x + 0.5 * pixel_vec_y
|
||||
LOGGER.debug(
|
||||
"Pixel 0 centers: %s, vec x: %s, vec y: %s.",
|
||||
str(pixel_0_0_center),
|
||||
str(pixel_vec_x),
|
||||
str(pixel_vec_y),
|
||||
)
|
||||
if not orthogonal:
|
||||
LOGGER.debug(
|
||||
"Camera horizontal fovs: %s deg.",
|
||||
str(
|
||||
2.0
|
||||
* torch.atan(0.5 * sensor_size_x / focal_length)
|
||||
/ math.pi
|
||||
* 180.0
|
||||
),
|
||||
)
|
||||
LOGGER.debug(
|
||||
"Camera vertical fovs: %s deg.",
|
||||
str(
|
||||
2.0
|
||||
* torch.atan(0.5 * sensor_size_y / focal_length)
|
||||
/ math.pi
|
||||
* 180.0
|
||||
),
|
||||
)
|
||||
# Reduce dimension.
|
||||
focal_length: torch.Tensor = focal_length[:, 0]
|
||||
if batch_processing:
|
||||
return (
|
||||
pos_vec,
|
||||
pixel_0_0_center,
|
||||
pixel_vec_x,
|
||||
pixel_vec_y,
|
||||
focal_length,
|
||||
principal_point_offsets,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
pos_vec[0],
|
||||
pixel_0_0_center[0],
|
||||
pixel_vec_x[0],
|
||||
pixel_vec_y[0],
|
||||
focal_length[0],
|
||||
principal_point_offsets[0],
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
vert_pos: torch.Tensor,
|
||||
vert_col: torch.Tensor,
|
||||
vert_rad: torch.Tensor,
|
||||
cam_params: torch.Tensor,
|
||||
gamma: float,
|
||||
max_depth: float,
|
||||
min_depth: float = 0.0,
|
||||
bg_col: Optional[torch.Tensor] = None,
|
||||
opacity: Optional[torch.Tensor] = None,
|
||||
percent_allowed_difference: float = 0.01,
|
||||
max_n_hits: int = _C.MAX_UINT,
|
||||
mode: int = 0,
|
||||
return_forward_info: bool = False,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
|
||||
"""
|
||||
Rendering pass to create an image from the provided spheres and camera
|
||||
parameters.
|
||||
|
||||
Args:
|
||||
* vert_pos: vertex positions. [Bx]Nx3 tensor of positions in 3D space.
|
||||
* vert_col: vertex colors. [Bx]NxK tensor of channels.
|
||||
* vert_rad: vertex radii. [Bx]N tensor of radiuses, >0.
|
||||
* cam_params: camera parameter(s). [Bx]8 tensor, consisting of:
|
||||
- 3 components for camera position,
|
||||
- 3 components for camera rotation (axis angle representation) or
|
||||
6 components as described in "On the Continuity of Rotation
|
||||
Representations in Neural Networks" (Zhou et al.),
|
||||
- focal length,
|
||||
- the sensor width in world coordinates,
|
||||
- [optional] an offset for the principal point in x, y (no gradients).
|
||||
* gamma: sphere transparency in [1.,1E-5], with 1 being mostly transparent.
|
||||
[Bx]1.
|
||||
* max_depth: maximum depth for spheres to render. Set this as tightly
|
||||
as possible to have good numerical accuracy for gradients.
|
||||
float > min_depth + eps.
|
||||
* min_depth: a float with the minimum depth a sphere must have to be
|
||||
rendered. Must be 0. or > max(focal_length) + eps.
|
||||
* bg_col: K tensor with a background color to use or None (uses all ones).
|
||||
* opacity: [Bx]N tensor of opacity values in [0., 1.] or None (uses all
|
||||
ones).
|
||||
* percent_allowed_difference: a float in [0., 1.[ with the maximum allowed
|
||||
difference in color space. This is used to speed up the
|
||||
computation. Default: 0.01.
|
||||
* max_n_hits: a hard limit on the number of hits per ray. Default: max int.
|
||||
* mode: render mode in {0, 1}. 0: render an image; 1: render the hit map.
|
||||
* return_forward_info: whether to return a second map. This second map
|
||||
contains 13 channels: first channel contains sm_m (the maximum
|
||||
exponent factor observed), the second sm_d (the normalization
|
||||
denominator, the sum of all coefficients), the third the maximum closest
|
||||
possible intersection for a hit. The following channels alternate with
|
||||
the float encoded integer index of a sphere and its weight. They are the
|
||||
five spheres with the highest color contribution to this pixel color,
|
||||
ordered descending. Default: False.
|
||||
|
||||
Returns:
|
||||
* image: [Bx]HxWx3 float tensor with the resulting image.
|
||||
* forw_info: [Bx]HxWx13 float forward information as described above, if
|
||||
enabled.
|
||||
"""
|
||||
# The device tracker is registered as buffer.
|
||||
# pyre-fixme[16]: `Renderer` has no attribute `device_tracker`.
|
||||
self._renderer.device_tracker = self.device_tracker
|
||||
(
|
||||
pos_vec,
|
||||
pixel_0_0_center,
|
||||
pixel_vec_x,
|
||||
pixel_vec_y,
|
||||
focal_lengths,
|
||||
principal_point_offsets,
|
||||
) = Renderer._transform_cam_params(
|
||||
cam_params,
|
||||
self._renderer.width,
|
||||
self._renderer.height,
|
||||
self._renderer.orthogonal,
|
||||
self._renderer.right_handed,
|
||||
)
|
||||
if (
|
||||
focal_lengths.min().item() > 0.0
|
||||
and max_depth > 10_000.0 * focal_lengths.min().item()
|
||||
):
|
||||
warnings.warn(
|
||||
(
|
||||
"Extreme ratio of `max_depth` vs. focal length detected "
|
||||
"(%f vs. %f, ratio: %f). This will likely lead to "
|
||||
"artifacts due to numerical instabilities."
|
||||
)
|
||||
% (
|
||||
max_depth,
|
||||
focal_lengths.min().item(),
|
||||
max_depth / focal_lengths.min().item(),
|
||||
)
|
||||
)
|
||||
# pyre-fixme[16]: `_Render` has no attribute `apply`.
|
||||
ret_res = _Render.apply(
|
||||
vert_pos,
|
||||
vert_col,
|
||||
vert_rad,
|
||||
pos_vec,
|
||||
pixel_0_0_center,
|
||||
pixel_vec_x,
|
||||
pixel_vec_y,
|
||||
# Focal length and sensor size don't need gradients other than through
|
||||
# `pixel_vec_x` and `pixel_vec_y`. The focal length is only used in the
|
||||
# renderer to determine the projection areas of the balls.
|
||||
focal_lengths,
|
||||
# principal_point_offsets does not receive gradients.
|
||||
principal_point_offsets,
|
||||
gamma,
|
||||
max_depth,
|
||||
self._renderer,
|
||||
min_depth,
|
||||
bg_col,
|
||||
opacity,
|
||||
percent_allowed_difference,
|
||||
max_n_hits,
|
||||
mode,
|
||||
(mode == 0) and return_forward_info,
|
||||
)
|
||||
if return_forward_info and mode != 0:
|
||||
return ret_res, None
|
||||
return ret_res
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
"""Extra information to print in pytorch graphs."""
|
||||
return "width={}, height={}, max_num_balls={}".format(
|
||||
self._renderer.width, self._renderer.height, self._renderer.max_num_balls
|
||||
)
|
@ -1,9 +1,13 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .external.kornia_angle_axis_to_rotation_matrix import (
|
||||
angle_axis_to_rotation_matrix as axis_angle_to_matrix,
|
||||
)
|
||||
from .rotation_conversions import (
|
||||
euler_angles_to_matrix,
|
||||
matrix_to_euler_angles,
|
||||
matrix_to_quaternion,
|
||||
matrix_to_rotation_6d,
|
||||
quaternion_apply,
|
||||
quaternion_invert,
|
||||
quaternion_multiply,
|
||||
@ -12,6 +16,7 @@ from .rotation_conversions import (
|
||||
random_quaternions,
|
||||
random_rotation,
|
||||
random_rotations,
|
||||
rotation_6d_to_matrix,
|
||||
standardize_quaternion,
|
||||
)
|
||||
from .so3 import (
|
||||
|
1
pytorch3d/transforms/external/__init__.py
vendored
Normal file
1
pytorch3d/transforms/external/__init__.py
vendored
Normal file
@ -0,0 +1 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
94
pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py
vendored
Normal file
94
pytorch3d/transforms/external/kornia_angle_axis_to_rotation_matrix.py
vendored
Normal file
@ -0,0 +1,94 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
This file contains the great angle axis to rotation matrix conversion
|
||||
from kornia (https://github.com/arraiyopensource/kornia). The license
|
||||
can be found in kornia_license.txt.
|
||||
|
||||
The method is used unchanged; the documentation has been adjusted
|
||||
to match our doc format.
|
||||
"""
|
||||
import torch
|
||||
|
||||
|
||||
def angle_axis_to_rotation_matrix(angle_axis):
|
||||
"""Convert 3d vector of axis-angle rotation to 4x4 rotation matrix
|
||||
|
||||
Args:
|
||||
angle_axis (Tensor): tensor of 3d vector of axis-angle rotations.
|
||||
|
||||
Returns:
|
||||
Tensor: tensor of 3x3 rotation matrix.
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, 3)`
|
||||
- Output: :math:`(N, 3, 3)`
|
||||
|
||||
Example:
|
||||
|
||||
..code-block::python
|
||||
|
||||
>>> input = torch.rand(1, 3) # Nx3
|
||||
>>> output = tgm.angle_axis_to_rotation_matrix(input) # Nx3x3
|
||||
>>> output = tgm.angle_axis_to_rotation_matrix(input) # Nx3x3
|
||||
"""
|
||||
|
||||
def _compute_rotation_matrix(angle_axis, theta2, eps=1e-6):
|
||||
# We want to be careful to only evaluate the square root if the
|
||||
# norm of the angle_axis vector is greater than zero. Otherwise
|
||||
# we get a division by zero.
|
||||
k_one = 1.0
|
||||
theta = torch.sqrt(theta2)
|
||||
wxyz = angle_axis / (theta + eps)
|
||||
wx, wy, wz = torch.chunk(wxyz, 3, dim=1)
|
||||
cos_theta = torch.cos(theta)
|
||||
sin_theta = torch.sin(theta)
|
||||
|
||||
r00 = cos_theta + wx * wx * (k_one - cos_theta)
|
||||
r10 = wz * sin_theta + wx * wy * (k_one - cos_theta)
|
||||
r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta)
|
||||
r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta
|
||||
r11 = cos_theta + wy * wy * (k_one - cos_theta)
|
||||
r21 = wx * sin_theta + wy * wz * (k_one - cos_theta)
|
||||
r02 = wy * sin_theta + wx * wz * (k_one - cos_theta)
|
||||
r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta)
|
||||
r22 = cos_theta + wz * wz * (k_one - cos_theta)
|
||||
rotation_matrix = torch.cat(
|
||||
[r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1
|
||||
)
|
||||
return rotation_matrix.view(-1, 3, 3)
|
||||
|
||||
def _compute_rotation_matrix_taylor(angle_axis):
|
||||
rx, ry, rz = torch.chunk(angle_axis, 3, dim=1)
|
||||
k_one = torch.ones_like(rx)
|
||||
rotation_matrix = torch.cat(
|
||||
[k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1
|
||||
)
|
||||
return rotation_matrix.view(-1, 3, 3)
|
||||
|
||||
# stolen from ceres/rotation.h
|
||||
|
||||
_angle_axis = torch.unsqueeze(angle_axis + 1e-6, dim=1)
|
||||
# _angle_axis.register_hook(lambda grad: pdb.set_trace())
|
||||
# _angle_axis = 1e-6
|
||||
theta2 = torch.matmul(_angle_axis, _angle_axis.transpose(1, 2))
|
||||
theta2 = torch.squeeze(theta2, dim=1)
|
||||
|
||||
# compute rotation matrices
|
||||
rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2)
|
||||
rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis)
|
||||
|
||||
# create mask to handle both cases
|
||||
eps = 1e-6
|
||||
mask = (theta2 > eps).view(-1, 1, 1).to(theta2.device)
|
||||
mask_pos = (mask).type_as(theta2)
|
||||
mask_neg = (mask == False).type_as(theta2) # noqa
|
||||
|
||||
# create output pose matrix
|
||||
batch_size = angle_axis.shape[0]
|
||||
rotation_matrix = torch.eye(3).to(angle_axis.device).type_as(angle_axis)
|
||||
rotation_matrix = rotation_matrix.view(1, 3, 3).repeat(batch_size, 1, 1)
|
||||
# fill output matrix with masked values
|
||||
rotation_matrix[..., :3, :3] = (
|
||||
mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor
|
||||
)
|
||||
return rotation_matrix.to(angle_axis.device).type_as(angle_axis) # Nx4x4
|
201
pytorch3d/transforms/external/kornia_license.txt
vendored
Normal file
201
pytorch3d/transforms/external/kornia_license.txt
vendored
Normal file
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
@ -110,3 +110,7 @@ def bm_barycentric_clip() -> None:
|
||||
|
||||
benchmark(baryclip_cuda, "BARY_CLIP_CUDA", kwargs_list, warmup_iters=1)
|
||||
benchmark(baryclip_pytorch, "BARY_CLIP_PYTORCH", kwargs_list, warmup_iters=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_barycentric_clip()
|
||||
|
@ -42,3 +42,7 @@ def bm_blending() -> None:
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_blending()
|
||||
|
@ -22,3 +22,7 @@ def bm_cameras_alignment() -> None:
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_cameras_alignment()
|
||||
|
@ -8,6 +8,8 @@ from test_chamfer import TestChamfer
|
||||
|
||||
|
||||
def bm_chamfer() -> None:
|
||||
# Currently disabled.
|
||||
return
|
||||
devices = ["cpu"]
|
||||
if torch.cuda.is_available():
|
||||
devices.append("cuda:0")
|
||||
@ -53,3 +55,7 @@ def bm_chamfer() -> None:
|
||||
}
|
||||
)
|
||||
benchmark(TestChamfer.chamfer_with_init, "CHAMFER", kwargs_list, warmup_iters=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_chamfer()
|
||||
|
@ -11,3 +11,7 @@ def bm_cubify() -> None:
|
||||
{"batch_size": 16, "V": 32},
|
||||
]
|
||||
benchmark(TestCubify.cubify_with_init, "CUBIFY", kwargs_list, warmup_iters=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_cubify()
|
||||
|
@ -37,3 +37,7 @@ def bm_face_areas_normals() -> None:
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_face_areas_normals()
|
||||
|
@ -40,3 +40,7 @@ def bm_graph_conv() -> None:
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_graph_conv()
|
||||
|
@ -74,3 +74,7 @@ def bm_interpolate_face_attribues() -> None:
|
||||
kwargs_list.append({"N": N, "S": S, "K": K, "F": F, "D": D, "impl": impl})
|
||||
benchmark(_bm_forward, "FORWARD", kwargs_list, warmup_iters=3)
|
||||
benchmark(_bm_forward_backward, "FORWARD+BACKWARD", kwargs_list, warmup_iters=3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_interpolate_face_attribues()
|
||||
|
@ -24,3 +24,7 @@ def bm_knn() -> None:
|
||||
benchmark(TestKNN.knn_square, "KNN_SQUARE", kwargs_list, warmup_iters=1)
|
||||
|
||||
benchmark(TestKNN.knn_ragged, "KNN_RAGGED", kwargs_list, warmup_iters=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_knn()
|
||||
|
@ -45,3 +45,7 @@ def bm_lighting() -> None:
|
||||
kwargs_list.append({"N": N, "S": S, "K": K})
|
||||
benchmark(_bm_diffuse_cuda_with_init, "DIFFUSE", kwargs_list, warmup_iters=3)
|
||||
benchmark(_bm_specular_cuda_with_init, "SPECULAR", kwargs_list, warmup_iters=3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_lighting()
|
||||
|
@ -2,8 +2,10 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import glob
|
||||
import importlib
|
||||
from os.path import basename, dirname, isfile, join, sys
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from os.path import dirname, isfile, join
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -11,20 +13,22 @@ if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
# Parse from flags.
|
||||
# pyre-ignore[16]
|
||||
module_names = [n for n in sys.argv if n.startswith("bm_")]
|
||||
file_names = [
|
||||
join(dirname(__file__), n) for n in sys.argv if n.startswith("bm_")
|
||||
]
|
||||
else:
|
||||
# Get all the benchmark files (starting with "bm_").
|
||||
bm_files = glob.glob(join(dirname(__file__), "bm_*.py"))
|
||||
module_names = [
|
||||
basename(f)[:-3]
|
||||
for f in bm_files
|
||||
if isfile(f) and not f.endswith("bm_main.py")
|
||||
]
|
||||
file_names = sorted(
|
||||
f for f in bm_files if isfile(f) and not f.endswith("bm_main.py")
|
||||
)
|
||||
|
||||
for module_name in module_names:
|
||||
module = importlib.import_module(module_name)
|
||||
for attr in dir(module):
|
||||
# Run all the functions with names "bm_*" in the module.
|
||||
if attr.startswith("bm_"):
|
||||
print("Running benchmarks for " + module_name + "/" + attr + "...")
|
||||
getattr(module, attr)()
|
||||
# Forward all important path information to the subprocesses through the
|
||||
# environment.
|
||||
os.environ["PATH"] = sys.path[0] + ":" + os.environ.get("PATH", "")
|
||||
os.environ["LD_LIBRARY_PATH"] = (
|
||||
sys.path[0] + ":" + os.environ.get("LD_LIBRARY_PATH", "")
|
||||
)
|
||||
os.environ["PYTHONPATH"] = ":".join(sys.path)
|
||||
for file_name in file_names:
|
||||
subprocess.check_call([sys.executable, file_name])
|
||||
|
@ -19,3 +19,7 @@ def bm_mesh_edge_loss() -> None:
|
||||
benchmark(
|
||||
TestMeshEdgeLoss.mesh_edge_loss, "MESH_EDGE_LOSS", kwargs_list, warmup_iters=1
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_mesh_edge_loss()
|
||||
|
@ -95,3 +95,7 @@ def bm_save_load() -> None:
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_save_load()
|
||||
|
@ -30,3 +30,7 @@ def bm_mesh_laplacian_smoothing() -> None:
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_mesh_laplacian_smoothing()
|
||||
|
@ -27,3 +27,7 @@ def bm_mesh_normal_consistency() -> None:
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_mesh_normal_consistency()
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user