pytorch3d/docs/examples/pulsar_multiview.py
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
This example demonstrates multiview 3D reconstruction using the plain
pulsar interface. For this, reference images have been pre-generated
(you can find them at `../../tests/pulsar/reference/examples_TestRenderer_test_multiview_%d.png`).
The camera parameters are assumed to be known. The scene is initialized
with random spheres. Gradient-based optimization is used to optimize the
sphere parameters and to prune spheres so that they converge to a 3D
representation of the scene. This example is not yet available through
the 'unified' interface, because opacity support for general data
structures has not yet landed in PyTorch3D.
"""
import math
from os import path

import cv2
import imageio
import numpy as np
import torch
from pytorch3d.renderer.points.pulsar import Renderer
from torch import nn, optim


n_points = 400_000
width = 1_000
height = 1_000
visualize_ids = [0, 1]
device = torch.device("cuda")
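
# With 400k spheres rendered at 1000x1000 pixels, this example assumes that a
# CUDA-capable GPU is available (see `device` above).
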
class SceneModel(nn.Module):
    """
    A simple scene model to demonstrate use of pulsar in PyTorch modules.

    The scene is parameterized with sphere locations (vert_pos), channel
    content (vert_col), radii (vert_rad) and opacities (vert_opy); the eight
    fixed camera poses are stored in the buffer cam_params. The forward
    method of the model renders this scene description. Any of these
    parameters could instead be passed as inputs to the forward method and
    come from a different model. Optionally, camera parameters can be
    provided to the forward method, in which case the scene is rendered
    using those parameters instead of the stored ones.
    """

    def __init__(self):
        super(SceneModel, self).__init__()
        self.gamma = 1.0  # Blending parameter for the pulsar renderer.
        # Points.
        torch.manual_seed(1)
        vert_pos = torch.rand((1, n_points, 3), dtype=torch.float32) * 10.0
        vert_pos[:, :, 2] += 25.0
        vert_pos[:, :, :2] -= 5.0
        self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True))
        self.register_parameter(
            "vert_col",
            nn.Parameter(
                torch.ones(1, n_points, 3, dtype=torch.float32) * 0.5,
                requires_grad=True,
            ),
        )
        self.register_parameter(
            "vert_rad",
            nn.Parameter(
                torch.ones(1, n_points, dtype=torch.float32) * 0.05, requires_grad=True
            ),
        )
        self.register_parameter(
            "vert_opy",
            nn.Parameter(
                torch.ones(1, n_points, dtype=torch.float32), requires_grad=True
            ),
        )
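        # The cameras are stored as a fixed buffer of eight parameter vectors;
        # in the plain pulsar parameterization these are assumed to hold the
        # camera position (x, y, z), an axis-angle rotation, the focal length
        # (5.0) and the sensor width (2.0). The positions trace an arc around
        # the scene center.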
        self.register_buffer(
            "cam_params",
            torch.tensor(
                [
                    [
                        np.sin(angle) * 35.0,
                        0.0,
                        30.0 - np.cos(angle) * 35.0,
                        0.0,
                        -angle + math.pi,
                        0.0,
                        5.0,
                        2.0,
                    ]
                    for angle in [-1.5, -0.8, -0.4, -0.1, 0.1, 0.4, 0.8, 1.5]
                ],
                dtype=torch.float32,
            ),
        )
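        # The renderer is constructed once for the fixed image size and the
        # maximum number of spheres; right_handed_system=True selects a
        # right-handed camera coordinate system.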
        self.renderer = Renderer(width, height, n_points, right_handed_system=True)

    def forward(self, cam=None):
        if cam is None:
            cam = self.cam_params
            n_views = 8
        else:
            n_views = 1
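        # Render all views in one batched call; the point parameters are
        # expanded to the number of views. The last two arguments are the
        # blending parameter gamma and, presumably, the maximum depth (45.0)
        # up to which spheres are drawn.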
        return self.renderer.forward(
            self.vert_pos.expand(n_views, -1, -1),
            self.vert_col.expand(n_views, -1, -1),
            self.vert_rad.expand(n_views, -1),
            cam,
            self.gamma,
            45.0,
        )
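
# The remainder of the script loads the pre-rendered reference views, sets up
# the model and optimizer, and runs a short reconstruction loop.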
# Load reference.
ref = torch.stack(
    [
        torch.from_numpy(
            imageio.imread(
                "../../tests/pulsar/reference/examples_TestRenderer_test_multiview_%d.png"
                % idx
            )
        ).to(torch.float32)
        / 255.0
        for idx in range(8)
    ]
).to(device)
# Set up model.
model = SceneModel().to(device)
# Optimizer.
optimizer = optim.SGD(
    [
        {"params": [model.vert_col], "lr": 1e-1},
        {"params": [model.vert_rad], "lr": 1e-3},
        {"params": [model.vert_pos], "lr": 1e-3},
    ]
)
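# Only the sphere colors, radii and positions are optimized, each with its own
# learning rate; the camera parameters remain fixed in the cam_params buffer.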
# For visualization.
angle = 0.0
print("Writing video to `%s`." % (path.abspath("multiview.gif")))
writer = imageio.get_writer("multiview.gif", format="gif", fps=25)
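# In addition to the interactive OpenCV windows, a gif of the current
# reconstruction seen from a slowly rotating camera is written to disk.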
# Optimize.
for i in range(300):
    optimizer.zero_grad()
    result = model()
    # Visualize.
    result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
    cv2.imshow("opt", result_im[0, :, :, ::-1])
    overlay_img = np.ascontiguousarray(
        ((result * 0.5 + ref * 0.5).cpu().detach().numpy() * 255).astype(np.uint8)[
            0, :, :, ::-1
        ]
    )
    overlay_img = cv2.putText(
        overlay_img,
        "Step %d" % (i),
        (10, 40),
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        (0, 0, 0),
        2,
        cv2.LINE_AA,
        False,
    )
    cv2.imshow("overlay", overlay_img)
    cv2.waitKey(1)
    # Update.
    loss = ((result - ref) ** 2).sum()
    print("loss {}: {}".format(i, loss.item()))
    loss.backward()
    optimizer.step()
    # Cleanup.
    with torch.no_grad():
        model.vert_col.data = torch.clamp(model.vert_col.data, 0.0, 1.0)
        # Remove points.
        model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0
        model.vert_rad.data[model.vert_rad < 0.001] = 0.0001
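        # Also prune spheres whose color has converged to (near) white: vd is
        # the per-sphere L1 color distance from white, and spheres with
        # vd <= 0.2 are moved out of the scene.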
        vd = (
            (model.vert_col - torch.ones(1, 1, 3, dtype=torch.float32).to(device))
            .abs()
            .sum(dim=2)
        )
        model.vert_pos.data[vd <= 0.2] = -1000.0
    # Rotating visualization.
    cam_control = torch.tensor(
        [
            [
                np.sin(angle) * 35.0,
                0.0,
                30.0 - np.cos(angle) * 35.0,
                0.0,
                -angle + math.pi,
                0.0,
                5.0,
                2.0,
            ]
        ],
        dtype=torch.float32,
    ).to(device)
    with torch.no_grad():
        result = model.forward(cam=cam_control)[0]
    result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
    cv2.imshow("vis", result_im[:, :, ::-1])
    writer.append_data(result_im)
    angle += 0.05
writer.close()