mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-14 19:36:23 +08:00
pulsar integration.
Summary: This diff integrates the pulsar renderer source code into PyTorch3D as an alternative backend for the PyTorch3D point renderer. This diff is the first of a series of three diffs to complete that migration and focuses on the packaging and integration of the source code. For more information about the pulsar backend, see the release notes and the paper (https://arxiv.org/abs/2004.07484). For information on how to use the backend, see the point cloud rendering notebook and the examples in the folder `docs/examples`. Tasks addressed in the following diffs: * Add the PyTorch3D interface, * Add notebook examples and documentation (or adapt the existing ones to feature both interfaces). Reviewed By: nikhilaravi Differential Revision: D23947736 fbshipit-source-id: a5e77b53e6750334db22aefa89b4c079cda1b443
This commit is contained in:
committed by
Facebook GitHub Bot
parent
d565032399
commit
b19fe1de2f
50
docs/examples/pulsar_basic.py
Executable file
50
docs/examples/pulsar_basic.py
Executable file
@@ -0,0 +1,50 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
"""
|
||||
This example demonstrates the most trivial, direct interface of the pulsar
|
||||
sphere renderer. It renders and saves an image with 10 random spheres.
|
||||
Output: basic.png.
|
||||
"""
|
||||
from os import path
|
||||
|
||||
import imageio
|
||||
import torch
|
||||
from pytorch3d.renderer.points.pulsar import Renderer
|
||||
|
||||
|
||||
n_points = 10
|
||||
width = 1_000
|
||||
height = 1_000
|
||||
device = torch.device("cuda")
|
||||
renderer = Renderer(width, height, n_points).to(device)
|
||||
# Generate sample data.
|
||||
vert_pos = torch.rand(n_points, 3, dtype=torch.float32, device=device) * 10.0
|
||||
vert_pos[:, 2] += 25.0
|
||||
vert_pos[:, :2] -= 5.0
|
||||
vert_col = torch.rand(n_points, 3, dtype=torch.float32, device=device)
|
||||
vert_rad = torch.rand(n_points, dtype=torch.float32, device=device)
|
||||
cam_params = torch.tensor(
|
||||
[
|
||||
0.0,
|
||||
0.0,
|
||||
0.0, # Position 0, 0, 0 (x, y, z).
|
||||
0.0,
|
||||
0.0,
|
||||
0.0, # Rotation 0, 0, 0 (in axis-angle format).
|
||||
5.0, # Focal length in world size.
|
||||
2.0, # Sensor size in world size.
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
# Render.
|
||||
image = renderer(
|
||||
vert_pos,
|
||||
vert_col,
|
||||
vert_rad,
|
||||
cam_params,
|
||||
1.0e-1, # Renderer blending parameter gamma, in [1., 1e-5].
|
||||
45.0, # Maximum depth.
|
||||
)
|
||||
print("Writing image to `%s`." % (path.abspath("basic.png")))
|
||||
imageio.imsave("basic.png", (image.cpu().detach() * 255.0).to(torch.uint8).numpy())
|
||||
158
docs/examples/pulsar_cam.py
Executable file
158
docs/examples/pulsar_cam.py
Executable file
@@ -0,0 +1,158 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
"""
|
||||
This example demonstrates camera parameter optimization with the plain
|
||||
pulsar interface. For this, a reference image has been pre-generated
|
||||
(you can find it at `../../tests/pulsar/reference/examples_TestRenderer_test_cam.png`).
|
||||
The same scene parameterization is loaded and the camera parameters
|
||||
distorted. Gradient-based optimization is used to converge towards the
|
||||
original camera parameters.
|
||||
"""
|
||||
from os import path
|
||||
|
||||
import cv2
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.renderer.points.pulsar import Renderer
|
||||
from torch import nn, optim
|
||||
|
||||
|
||||
n_points = 20
|
||||
width = 1_000
|
||||
height = 1_000
|
||||
device = torch.device("cuda")
|
||||
|
||||
|
||||
class SceneModel(nn.Module):
|
||||
"""
|
||||
A simple scene model to demonstrate use of pulsar in PyTorch modules.
|
||||
|
||||
The scene model is parameterized with sphere locations (vert_pos),
|
||||
channel content (vert_col), radiuses (vert_rad), camera position (cam_pos),
|
||||
camera rotation (cam_rot) and sensor focal length and width (cam_sensor).
|
||||
|
||||
The forward method of the model renders this scene description. Any
|
||||
of these parameters could instead be passed as inputs to the forward
|
||||
method and come from a different model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(SceneModel, self).__init__()
|
||||
self.gamma = 0.1
|
||||
# Points.
|
||||
torch.manual_seed(1)
|
||||
vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0
|
||||
vert_pos[:, 2] += 25.0
|
||||
vert_pos[:, :2] -= 5.0
|
||||
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=False))
|
||||
self.register_parameter(
|
||||
"vert_col",
|
||||
nn.Parameter(
|
||||
torch.rand(n_points, 3, dtype=torch.float32), requires_grad=False
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"vert_rad",
|
||||
nn.Parameter(
|
||||
torch.rand(n_points, dtype=torch.float32), requires_grad=False
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"cam_pos",
|
||||
nn.Parameter(
|
||||
torch.tensor([0.1, 0.1, 0.0], dtype=torch.float32), requires_grad=True
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"cam_rot",
|
||||
nn.Parameter(
|
||||
torch.tensor(
|
||||
[
|
||||
# We're using the 6D rot. representation for better gradients.
|
||||
0.9995,
|
||||
0.0300445,
|
||||
-0.0098482,
|
||||
-0.0299445,
|
||||
0.9995,
|
||||
0.0101482,
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
requires_grad=True,
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"cam_sensor",
|
||||
nn.Parameter(
|
||||
torch.tensor([4.8, 1.8], dtype=torch.float32), requires_grad=True
|
||||
),
|
||||
)
|
||||
self.renderer = Renderer(width, height, n_points)
|
||||
|
||||
def forward(self):
|
||||
return self.renderer.forward(
|
||||
self.vert_pos,
|
||||
self.vert_col,
|
||||
self.vert_rad,
|
||||
torch.cat([self.cam_pos, self.cam_rot, self.cam_sensor]),
|
||||
self.gamma,
|
||||
45.0,
|
||||
)
|
||||
|
||||
|
||||
# Load reference.
|
||||
ref = (
|
||||
torch.from_numpy(
|
||||
imageio.imread(
|
||||
"../../tests/pulsar/reference/examples_TestRenderer_test_cam.png"
|
||||
)
|
||||
).to(torch.float32)
|
||||
/ 255.0
|
||||
).to(device)
|
||||
# Set up model.
|
||||
model = SceneModel().to(device)
|
||||
# Optimizer.
|
||||
optimizer = optim.SGD(
|
||||
[
|
||||
{"params": [model.cam_pos], "lr": 1e-4}, # 1e-3
|
||||
{"params": [model.cam_rot], "lr": 5e-6},
|
||||
{"params": [model.cam_sensor], "lr": 1e-4},
|
||||
]
|
||||
)
|
||||
|
||||
print("Writing video to `%s`." % (path.abspath("cam.gif")))
|
||||
writer = imageio.get_writer("cam.gif", format="gif", fps=25)
|
||||
|
||||
# Optimize.
|
||||
for i in range(300):
|
||||
optimizer.zero_grad()
|
||||
result = model()
|
||||
# Visualize.
|
||||
result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
|
||||
cv2.imshow("opt", result_im[:, :, ::-1])
|
||||
writer.append_data(result_im)
|
||||
overlay_img = np.ascontiguousarray(
|
||||
((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[
|
||||
:, :, ::-1
|
||||
]
|
||||
)
|
||||
overlay_img = cv2.putText(
|
||||
overlay_img,
|
||||
"Step %d" % (i),
|
||||
(10, 40),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1,
|
||||
(0, 0, 0),
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
False,
|
||||
)
|
||||
cv2.imshow("overlay", overlay_img)
|
||||
cv2.waitKey(1)
|
||||
# Update.
|
||||
loss = ((result - ref) ** 2).sum()
|
||||
print("loss {}: {}".format(i, loss.item()))
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
writer.close()
|
||||
201
docs/examples/pulsar_multiview.py
Executable file
201
docs/examples/pulsar_multiview.py
Executable file
@@ -0,0 +1,201 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
"""
|
||||
This example demonstrates multiview 3D reconstruction using the plain
|
||||
pulsar interface. For this, reference images have been pre-generated
|
||||
(you can find them at `../../tests/pulsar/reference/examples_TestRenderer_test_multiview_%d.png`).
|
||||
The camera parameters are assumed given. The scene is initialized with
|
||||
random spheres. Gradient-based optimization is used to optimize sphere
|
||||
parameters and prune spheres to converge to a 3D representation.
|
||||
"""
|
||||
from os import path
|
||||
|
||||
import cv2
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.renderer.points.pulsar import Renderer
|
||||
from torch import nn, optim
|
||||
|
||||
|
||||
n_points = 400_000
|
||||
width = 1_000
|
||||
height = 1_000
|
||||
visualize_ids = [0, 1]
|
||||
device = torch.device("cuda")
|
||||
|
||||
|
||||
class SceneModel(nn.Module):
|
||||
"""
|
||||
A simple scene model to demonstrate use of pulsar in PyTorch modules.
|
||||
|
||||
The scene model is parameterized with sphere locations (vert_pos),
|
||||
channel content (vert_col), radiuses (vert_rad), camera position (cam_pos),
|
||||
camera rotation (cam_rot) and sensor focal length and width (cam_sensor).
|
||||
|
||||
The forward method of the model renders this scene description. Any
|
||||
of these parameters could instead be passed as inputs to the forward
|
||||
method and come from a different model. Optionally, camera parameters can
|
||||
be provided to the forward method in which case the scene is rendered
|
||||
using those parameters.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(SceneModel, self).__init__()
|
||||
self.gamma = 1.0
|
||||
# Points.
|
||||
torch.manual_seed(1)
|
||||
vert_pos = torch.rand((1, n_points, 3), dtype=torch.float32) * 10.0
|
||||
vert_pos[:, :, 2] += 25.0
|
||||
vert_pos[:, :, :2] -= 5.0
|
||||
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True))
|
||||
self.register_parameter(
|
||||
"vert_col",
|
||||
nn.Parameter(
|
||||
torch.ones(1, n_points, 3, dtype=torch.float32) * 0.5,
|
||||
requires_grad=True,
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"vert_rad",
|
||||
nn.Parameter(
|
||||
torch.ones(1, n_points, dtype=torch.float32) * 0.05, requires_grad=True
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"vert_opy",
|
||||
nn.Parameter(
|
||||
torch.ones(1, n_points, dtype=torch.float32), requires_grad=True
|
||||
),
|
||||
)
|
||||
self.register_buffer(
|
||||
"cam_params",
|
||||
torch.tensor(
|
||||
[
|
||||
[
|
||||
np.sin(angle) * 35.0,
|
||||
0.0,
|
||||
30.0 - np.cos(angle) * 35.0,
|
||||
0.0,
|
||||
-angle,
|
||||
0.0,
|
||||
5.0,
|
||||
2.0,
|
||||
]
|
||||
for angle in [-1.5, -0.8, -0.4, -0.1, 0.1, 0.4, 0.8, 1.5]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
)
|
||||
self.renderer = Renderer(width, height, n_points)
|
||||
|
||||
def forward(self, cam=None):
|
||||
if cam is None:
|
||||
cam = self.cam_params
|
||||
n_views = 8
|
||||
else:
|
||||
n_views = 1
|
||||
return self.renderer.forward(
|
||||
self.vert_pos.expand(n_views, -1, -1),
|
||||
self.vert_col.expand(n_views, -1, -1),
|
||||
self.vert_rad.expand(n_views, -1),
|
||||
cam,
|
||||
self.gamma,
|
||||
45.0,
|
||||
)
|
||||
|
||||
|
||||
# Load reference.
|
||||
ref = torch.stack(
|
||||
[
|
||||
torch.from_numpy(
|
||||
imageio.imread(
|
||||
"../../tests/pulsar/reference/examples_TestRenderer_test_multiview_%d.png"
|
||||
% idx
|
||||
)
|
||||
).to(torch.float32)
|
||||
/ 255.0
|
||||
for idx in range(8)
|
||||
]
|
||||
).to(device)
|
||||
# Set up model.
|
||||
model = SceneModel().to(device)
|
||||
# Optimizer.
|
||||
optimizer = optim.SGD(
|
||||
[
|
||||
{"params": [model.vert_col], "lr": 1e-1},
|
||||
{"params": [model.vert_rad], "lr": 1e-3},
|
||||
{"params": [model.vert_pos], "lr": 1e-3},
|
||||
]
|
||||
)
|
||||
|
||||
# For visualization.
|
||||
angle = 0.0
|
||||
print("Writing video to `%s`." % (path.abspath("multiview.avi")))
|
||||
writer = imageio.get_writer("multiview.gif", format="gif", fps=25)
|
||||
|
||||
# Optimize.
|
||||
for i in range(300):
|
||||
optimizer.zero_grad()
|
||||
result = model()
|
||||
# Visualize.
|
||||
result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
|
||||
cv2.imshow("opt", result_im[0, :, :, ::-1])
|
||||
overlay_img = np.ascontiguousarray(
|
||||
((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[
|
||||
0, :, :, ::-1
|
||||
]
|
||||
)
|
||||
overlay_img = cv2.putText(
|
||||
overlay_img,
|
||||
"Step %d" % (i),
|
||||
(10, 40),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1,
|
||||
(0, 0, 0),
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
False,
|
||||
)
|
||||
cv2.imshow("overlay", overlay_img)
|
||||
cv2.waitKey(1)
|
||||
# Update.
|
||||
loss = ((result - ref) ** 2).sum()
|
||||
print("loss {}: {}".format(i, loss.item()))
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# Cleanup.
|
||||
with torch.no_grad():
|
||||
model.vert_col.data = torch.clamp(model.vert_col.data, 0.0, 1.0)
|
||||
# Remove points.
|
||||
model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0
|
||||
model.vert_rad.data[model.vert_rad < 0.001] = 0.0001
|
||||
vd = (
|
||||
(model.vert_col - torch.ones(1, 1, 3, dtype=torch.float32).to(device))
|
||||
.abs()
|
||||
.sum(dim=2)
|
||||
)
|
||||
model.vert_pos.data[vd <= 0.2] = -1000.0
|
||||
# Rotating visualization.
|
||||
cam_control = torch.tensor(
|
||||
[
|
||||
[
|
||||
np.sin(angle) * 35.0,
|
||||
0.0,
|
||||
30.0 - np.cos(angle) * 35.0,
|
||||
0.0,
|
||||
-angle,
|
||||
0.0,
|
||||
5.0,
|
||||
2.0,
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
).to(device)
|
||||
with torch.no_grad():
|
||||
result = model.forward(cam=cam_control)[0]
|
||||
result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
|
||||
cv2.imshow("vis", result_im[:, :, ::-1])
|
||||
writer.append_data(result_im)
|
||||
angle += 0.05
|
||||
writer.close()
|
||||
140
docs/examples/pulsar_optimization.py
Executable file
140
docs/examples/pulsar_optimization.py
Executable file
@@ -0,0 +1,140 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
"""
|
||||
This example demonstrates scene optimization with the plain
|
||||
pulsar interface. For this, a reference image has been pre-generated
|
||||
(you can find it at `../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png`).
|
||||
The scene is initialized with random spheres. Gradient-based
|
||||
optimization is used to converge towards a faithful
|
||||
scene representation.
|
||||
"""
|
||||
import cv2
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.renderer.points.pulsar import Renderer
|
||||
from torch import nn, optim
|
||||
|
||||
|
||||
n_points = 10_000
|
||||
width = 1_000
|
||||
height = 1_000
|
||||
device = torch.device("cuda")
|
||||
|
||||
|
||||
class SceneModel(nn.Module):
|
||||
"""
|
||||
A simple scene model to demonstrate use of pulsar in PyTorch modules.
|
||||
|
||||
The scene model is parameterized with sphere locations (vert_pos),
|
||||
channel content (vert_col), radiuses (vert_rad), camera position (cam_pos),
|
||||
camera rotation (cam_rot) and sensor focal length and width (cam_sensor).
|
||||
|
||||
The forward method of the model renders this scene description. Any
|
||||
of these parameters could instead be passed as inputs to the forward
|
||||
method and come from a different model.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(SceneModel, self).__init__()
|
||||
self.gamma = 1.0
|
||||
# Points.
|
||||
torch.manual_seed(1)
|
||||
vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0
|
||||
vert_pos[:, 2] += 25.0
|
||||
vert_pos[:, :2] -= 5.0
|
||||
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True))
|
||||
self.register_parameter(
|
||||
"vert_col",
|
||||
nn.Parameter(
|
||||
torch.ones(n_points, 3, dtype=torch.float32) * 0.5, requires_grad=True
|
||||
),
|
||||
)
|
||||
self.register_parameter(
|
||||
"vert_rad",
|
||||
nn.Parameter(
|
||||
torch.ones(n_points, dtype=torch.float32) * 0.3, requires_grad=True
|
||||
),
|
||||
)
|
||||
self.register_buffer(
|
||||
"cam_params",
|
||||
torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 2.0], dtype=torch.float32),
|
||||
)
|
||||
# The volumetric optimization works better with a higher number of tracked
|
||||
# intersections per ray.
|
||||
self.renderer = Renderer(width, height, n_points, n_track=32)
|
||||
|
||||
def forward(self):
|
||||
return self.renderer.forward(
|
||||
self.vert_pos,
|
||||
self.vert_col,
|
||||
self.vert_rad,
|
||||
self.cam_params,
|
||||
self.gamma,
|
||||
45.0,
|
||||
return_forward_info=True,
|
||||
)
|
||||
|
||||
|
||||
# Load reference.
|
||||
ref = (
|
||||
torch.from_numpy(
|
||||
imageio.imread(
|
||||
"../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png"
|
||||
)
|
||||
).to(torch.float32)
|
||||
/ 255.0
|
||||
).to(device)
|
||||
# Set up model.
|
||||
model = SceneModel().to(device)
|
||||
# Optimizer.
|
||||
optimizer = optim.SGD(
|
||||
[
|
||||
{"params": [model.vert_col], "lr": 1e0},
|
||||
{"params": [model.vert_rad], "lr": 5e-3},
|
||||
{"params": [model.vert_pos], "lr": 1e-2},
|
||||
]
|
||||
)
|
||||
|
||||
# Optimize.
|
||||
for i in range(500):
|
||||
optimizer.zero_grad()
|
||||
result, result_info = model()
|
||||
# Visualize.
|
||||
result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
|
||||
cv2.imshow("opt", result_im[:, :, ::-1])
|
||||
overlay_img = np.ascontiguousarray(
|
||||
((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[
|
||||
:, :, ::-1
|
||||
]
|
||||
)
|
||||
overlay_img = cv2.putText(
|
||||
overlay_img,
|
||||
"Step %d" % (i),
|
||||
(10, 40),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
1,
|
||||
(0, 0, 0),
|
||||
2,
|
||||
cv2.LINE_AA,
|
||||
False,
|
||||
)
|
||||
cv2.imshow("overlay", overlay_img)
|
||||
cv2.waitKey(1)
|
||||
# Update.
|
||||
loss = ((result - ref) ** 2).sum()
|
||||
print("loss {}: {}".format(i, loss.item()))
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
# Cleanup.
|
||||
with torch.no_grad():
|
||||
model.vert_col.data = torch.clamp(model.vert_col.data, 0.0, 1.0)
|
||||
# Remove points.
|
||||
model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0
|
||||
model.vert_rad.data[model.vert_rad < 0.001] = 0.0001
|
||||
vd = (
|
||||
(model.vert_col - torch.ones(3, dtype=torch.float32).to(device))
|
||||
.abs()
|
||||
.sum(dim=1)
|
||||
)
|
||||
model.vert_pos.data[vd <= 0.2] = -1000.0
|
||||
Reference in New Issue
Block a user