pulsar integration.

Summary:
This diff integrates the pulsar renderer source code into PyTorch3D as an alternative backend for the PyTorch3D point renderer. This diff is the first of a series of three diffs to complete that migration and focuses on the packaging and integration of the source code.

For more information about the pulsar backend, see the release notes and the paper (https://arxiv.org/abs/2004.07484). For information on how to use the backend, see the point cloud rendering notebook and the examples in the folder `docs/examples`.

Tasks addressed in the following diffs:
* Add the PyTorch3D interface,
* Add notebook examples and documentation (or adapt the existing ones to feature both interfaces).

Reviewed By: nikhilaravi

Differential Revision: D23947736

fbshipit-source-id: a5e77b53e6750334db22aefa89b4c079cda1b443
This commit is contained in:
Christoph Lassner
2020-11-03 13:05:02 -08:00
committed by Facebook GitHub Bot
parent d565032399
commit b19fe1de2f
137 changed files with 10055 additions and 37 deletions

50
docs/examples/pulsar_basic.py Executable file
View File

@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
This example demonstrates the most trivial, direct interface of the pulsar
sphere renderer. It renders and saves an image with 10 random spheres.
Output: basic.png.
"""
from os import path
import imageio
import torch
from pytorch3d.renderer.points.pulsar import Renderer
n_points = 10
width = 1_000
height = 1_000
device = torch.device("cuda")
renderer = Renderer(width, height, n_points).to(device)
# Generate sample data.
vert_pos = torch.rand(n_points, 3, dtype=torch.float32, device=device) * 10.0
vert_pos[:, 2] += 25.0
vert_pos[:, :2] -= 5.0
vert_col = torch.rand(n_points, 3, dtype=torch.float32, device=device)
vert_rad = torch.rand(n_points, dtype=torch.float32, device=device)
cam_params = torch.tensor(
[
0.0,
0.0,
0.0, # Position 0, 0, 0 (x, y, z).
0.0,
0.0,
0.0, # Rotation 0, 0, 0 (in axis-angle format).
5.0, # Focal length in world size.
2.0, # Sensor size in world size.
],
dtype=torch.float32,
device=device,
)
# Render.
image = renderer(
vert_pos,
vert_col,
vert_rad,
cam_params,
1.0e-1, # Renderer blending parameter gamma, in [1., 1e-5].
45.0, # Maximum depth.
)
print("Writing image to `%s`." % (path.abspath("basic.png")))
imageio.imsave("basic.png", (image.cpu().detach() * 255.0).to(torch.uint8).numpy())

158
docs/examples/pulsar_cam.py Executable file
View File

@@ -0,0 +1,158 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
This example demonstrates camera parameter optimization with the plain
pulsar interface. For this, a reference image has been pre-generated
(you can find it at `../../tests/pulsar/reference/examples_TestRenderer_test_cam.png`).
The same scene parameterization is loaded and the camera parameters
distorted. Gradient-based optimization is used to converge towards the
original camera parameters.
"""
from os import path
import cv2
import imageio
import numpy as np
import torch
from pytorch3d.renderer.points.pulsar import Renderer
from torch import nn, optim
n_points = 20
width = 1_000
height = 1_000
device = torch.device("cuda")
class SceneModel(nn.Module):
"""
A simple scene model to demonstrate use of pulsar in PyTorch modules.
The scene model is parameterized with sphere locations (vert_pos),
channel content (vert_col), radiuses (vert_rad), camera position (cam_pos),
camera rotation (cam_rot) and sensor focal length and width (cam_sensor).
The forward method of the model renders this scene description. Any
of these parameters could instead be passed as inputs to the forward
method and come from a different model.
"""
def __init__(self):
super(SceneModel, self).__init__()
self.gamma = 0.1
# Points.
torch.manual_seed(1)
vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0
vert_pos[:, 2] += 25.0
vert_pos[:, :2] -= 5.0
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=False))
self.register_parameter(
"vert_col",
nn.Parameter(
torch.rand(n_points, 3, dtype=torch.float32), requires_grad=False
),
)
self.register_parameter(
"vert_rad",
nn.Parameter(
torch.rand(n_points, dtype=torch.float32), requires_grad=False
),
)
self.register_parameter(
"cam_pos",
nn.Parameter(
torch.tensor([0.1, 0.1, 0.0], dtype=torch.float32), requires_grad=True
),
)
self.register_parameter(
"cam_rot",
nn.Parameter(
torch.tensor(
[
# We're using the 6D rot. representation for better gradients.
0.9995,
0.0300445,
-0.0098482,
-0.0299445,
0.9995,
0.0101482,
],
dtype=torch.float32,
),
requires_grad=True,
),
)
self.register_parameter(
"cam_sensor",
nn.Parameter(
torch.tensor([4.8, 1.8], dtype=torch.float32), requires_grad=True
),
)
self.renderer = Renderer(width, height, n_points)
def forward(self):
return self.renderer.forward(
self.vert_pos,
self.vert_col,
self.vert_rad,
torch.cat([self.cam_pos, self.cam_rot, self.cam_sensor]),
self.gamma,
45.0,
)
# Load reference.
ref = (
torch.from_numpy(
imageio.imread(
"../../tests/pulsar/reference/examples_TestRenderer_test_cam.png"
)
).to(torch.float32)
/ 255.0
).to(device)
# Set up model.
model = SceneModel().to(device)
# Optimizer.
optimizer = optim.SGD(
[
{"params": [model.cam_pos], "lr": 1e-4}, # 1e-3
{"params": [model.cam_rot], "lr": 5e-6},
{"params": [model.cam_sensor], "lr": 1e-4},
]
)
print("Writing video to `%s`." % (path.abspath("cam.gif")))
writer = imageio.get_writer("cam.gif", format="gif", fps=25)
# Optimize.
for i in range(300):
optimizer.zero_grad()
result = model()
# Visualize.
result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
cv2.imshow("opt", result_im[:, :, ::-1])
writer.append_data(result_im)
overlay_img = np.ascontiguousarray(
((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[
:, :, ::-1
]
)
overlay_img = cv2.putText(
overlay_img,
"Step %d" % (i),
(10, 40),
cv2.FONT_HERSHEY_SIMPLEX,
1,
(0, 0, 0),
2,
cv2.LINE_AA,
False,
)
cv2.imshow("overlay", overlay_img)
cv2.waitKey(1)
# Update.
loss = ((result - ref) ** 2).sum()
print("loss {}: {}".format(i, loss.item()))
loss.backward()
optimizer.step()
writer.close()

201
docs/examples/pulsar_multiview.py Executable file
View File

@@ -0,0 +1,201 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
This example demonstrates multiview 3D reconstruction using the plain
pulsar interface. For this, reference images have been pre-generated
(you can find them at `../../tests/pulsar/reference/examples_TestRenderer_test_multiview_%d.png`).
The camera parameters are assumed given. The scene is initialized with
random spheres. Gradient-based optimization is used to optimize sphere
parameters and prune spheres to converge to a 3D representation.
"""
from os import path
import cv2
import imageio
import numpy as np
import torch
from pytorch3d.renderer.points.pulsar import Renderer
from torch import nn, optim
n_points = 400_000
width = 1_000
height = 1_000
visualize_ids = [0, 1]
device = torch.device("cuda")
class SceneModel(nn.Module):
"""
A simple scene model to demonstrate use of pulsar in PyTorch modules.
The scene model is parameterized with sphere locations (vert_pos),
channel content (vert_col), radiuses (vert_rad), camera position (cam_pos),
camera rotation (cam_rot) and sensor focal length and width (cam_sensor).
The forward method of the model renders this scene description. Any
of these parameters could instead be passed as inputs to the forward
method and come from a different model. Optionally, camera parameters can
be provided to the forward method in which case the scene is rendered
using those parameters.
"""
def __init__(self):
super(SceneModel, self).__init__()
self.gamma = 1.0
# Points.
torch.manual_seed(1)
vert_pos = torch.rand((1, n_points, 3), dtype=torch.float32) * 10.0
vert_pos[:, :, 2] += 25.0
vert_pos[:, :, :2] -= 5.0
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True))
self.register_parameter(
"vert_col",
nn.Parameter(
torch.ones(1, n_points, 3, dtype=torch.float32) * 0.5,
requires_grad=True,
),
)
self.register_parameter(
"vert_rad",
nn.Parameter(
torch.ones(1, n_points, dtype=torch.float32) * 0.05, requires_grad=True
),
)
self.register_parameter(
"vert_opy",
nn.Parameter(
torch.ones(1, n_points, dtype=torch.float32), requires_grad=True
),
)
self.register_buffer(
"cam_params",
torch.tensor(
[
[
np.sin(angle) * 35.0,
0.0,
30.0 - np.cos(angle) * 35.0,
0.0,
-angle,
0.0,
5.0,
2.0,
]
for angle in [-1.5, -0.8, -0.4, -0.1, 0.1, 0.4, 0.8, 1.5]
],
dtype=torch.float32,
),
)
self.renderer = Renderer(width, height, n_points)
def forward(self, cam=None):
if cam is None:
cam = self.cam_params
n_views = 8
else:
n_views = 1
return self.renderer.forward(
self.vert_pos.expand(n_views, -1, -1),
self.vert_col.expand(n_views, -1, -1),
self.vert_rad.expand(n_views, -1),
cam,
self.gamma,
45.0,
)
# Load reference.
ref = torch.stack(
[
torch.from_numpy(
imageio.imread(
"../../tests/pulsar/reference/examples_TestRenderer_test_multiview_%d.png"
% idx
)
).to(torch.float32)
/ 255.0
for idx in range(8)
]
).to(device)
# Set up model.
model = SceneModel().to(device)
# Optimizer.
optimizer = optim.SGD(
[
{"params": [model.vert_col], "lr": 1e-1},
{"params": [model.vert_rad], "lr": 1e-3},
{"params": [model.vert_pos], "lr": 1e-3},
]
)
# For visualization.
angle = 0.0
print("Writing video to `%s`." % (path.abspath("multiview.avi")))
writer = imageio.get_writer("multiview.gif", format="gif", fps=25)
# Optimize.
for i in range(300):
optimizer.zero_grad()
result = model()
# Visualize.
result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
cv2.imshow("opt", result_im[0, :, :, ::-1])
overlay_img = np.ascontiguousarray(
((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[
0, :, :, ::-1
]
)
overlay_img = cv2.putText(
overlay_img,
"Step %d" % (i),
(10, 40),
cv2.FONT_HERSHEY_SIMPLEX,
1,
(0, 0, 0),
2,
cv2.LINE_AA,
False,
)
cv2.imshow("overlay", overlay_img)
cv2.waitKey(1)
# Update.
loss = ((result - ref) ** 2).sum()
print("loss {}: {}".format(i, loss.item()))
loss.backward()
optimizer.step()
# Cleanup.
with torch.no_grad():
model.vert_col.data = torch.clamp(model.vert_col.data, 0.0, 1.0)
# Remove points.
model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0
model.vert_rad.data[model.vert_rad < 0.001] = 0.0001
vd = (
(model.vert_col - torch.ones(1, 1, 3, dtype=torch.float32).to(device))
.abs()
.sum(dim=2)
)
model.vert_pos.data[vd <= 0.2] = -1000.0
# Rotating visualization.
cam_control = torch.tensor(
[
[
np.sin(angle) * 35.0,
0.0,
30.0 - np.cos(angle) * 35.0,
0.0,
-angle,
0.0,
5.0,
2.0,
]
],
dtype=torch.float32,
).to(device)
with torch.no_grad():
result = model.forward(cam=cam_control)[0]
result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
cv2.imshow("vis", result_im[:, :, ::-1])
writer.append_data(result_im)
angle += 0.05
writer.close()

View File

@@ -0,0 +1,140 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
This example demonstrates scene optimization with the plain
pulsar interface. For this, a reference image has been pre-generated
(you can find it at `../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png`).
The scene is initialized with random spheres. Gradient-based
optimization is used to converge towards a faithful
scene representation.
"""
import cv2
import imageio
import numpy as np
import torch
from pytorch3d.renderer.points.pulsar import Renderer
from torch import nn, optim
n_points = 10_000
width = 1_000
height = 1_000
device = torch.device("cuda")
class SceneModel(nn.Module):
"""
A simple scene model to demonstrate use of pulsar in PyTorch modules.
The scene model is parameterized with sphere locations (vert_pos),
channel content (vert_col), radiuses (vert_rad), camera position (cam_pos),
camera rotation (cam_rot) and sensor focal length and width (cam_sensor).
The forward method of the model renders this scene description. Any
of these parameters could instead be passed as inputs to the forward
method and come from a different model.
"""
def __init__(self):
super(SceneModel, self).__init__()
self.gamma = 1.0
# Points.
torch.manual_seed(1)
vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0
vert_pos[:, 2] += 25.0
vert_pos[:, :2] -= 5.0
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True))
self.register_parameter(
"vert_col",
nn.Parameter(
torch.ones(n_points, 3, dtype=torch.float32) * 0.5, requires_grad=True
),
)
self.register_parameter(
"vert_rad",
nn.Parameter(
torch.ones(n_points, dtype=torch.float32) * 0.3, requires_grad=True
),
)
self.register_buffer(
"cam_params",
torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32),
)
# The volumetric optimization works better with a higher number of tracked
# intersections per ray.
self.renderer = Renderer(width, height, n_points, n_track=32)
def forward(self):
return self.renderer.forward(
self.vert_pos,
self.vert_col,
self.vert_rad,
self.cam_params,
self.gamma,
45.0,
return_forward_info=True,
)
# Load reference.
ref = (
torch.from_numpy(
imageio.imread(
"../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png"
)
).to(torch.float32)
/ 255.0
).to(device)
# Set up model.
model = SceneModel().to(device)
# Optimizer.
optimizer = optim.SGD(
[
{"params": [model.vert_col], "lr": 1e0},
{"params": [model.vert_rad], "lr": 5e-3},
{"params": [model.vert_pos], "lr": 1e-2},
]
)
# Optimize.
for i in range(500):
optimizer.zero_grad()
result, result_info = model()
# Visualize.
result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
cv2.imshow("opt", result_im[:, :, ::-1])
overlay_img = np.ascontiguousarray(
((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[
:, :, ::-1
]
)
overlay_img = cv2.putText(
overlay_img,
"Step %d" % (i),
(10, 40),
cv2.FONT_HERSHEY_SIMPLEX,
1,
(0, 0, 0),
2,
cv2.LINE_AA,
False,
)
cv2.imshow("overlay", overlay_img)
cv2.waitKey(1)
# Update.
loss = ((result - ref) ** 2).sum()
print("loss {}: {}".format(i, loss.item()))
loss.backward()
optimizer.step()
# Cleanup.
with torch.no_grad():
model.vert_col.data = torch.clamp(model.vert_col.data, 0.0, 1.0)
# Remove points.
model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0
model.vert_rad.data[model.vert_rad < 0.001] = 0.0001
vd = (
(model.vert_col - torch.ones(3, dtype=torch.float32).to(device))
.abs()
.sum(dim=1)
)
model.vert_pos.data[vd <= 0.2] = -1000.0