mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Summary: Changes to CI and some minor fixes now that pulsar is part of pytorch3d. Most significantly, add CUB to CI builds. Make CUB_HOME override the CUB already in cudatoolkit (important for cuda11.0 which uses cub 1.9.9 which pulsar doesn't work well with. Make imageio available for testing. Lint fixes. Fix some test verbosity. Avoid use of atomicAdd_block on older GPUs. Reviewed By: nikhilaravi, classner Differential Revision: D24773716 fbshipit-source-id: 2428356bb2e62735f2bc0c15cbe4cff35b1b24b8
226 lines
7.0 KiB
Python
Executable File
226 lines
7.0 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
"""
|
|
This example demonstrates multiview 3D reconstruction using the plain
|
|
pulsar interface. For this, reference images have been pre-generated
|
|
(you can find them at
`../../tests/pulsar/reference/examples_TestRenderer_test_multiview_%d.png`,
relative to the directory containing this script).
|
|
The camera parameters are assumed to be known. The scene is initialized with
|
|
random spheres. Gradient-based optimization is used to optimize sphere
|
|
parameters and prune spheres to converge to a 3D representation.
|
|
|
|
This example is not available yet through the 'unified' interface,
|
|
because opacity support has not landed in PyTorch3D for general data
|
|
structures yet.
|
|
"""
|
|
import logging
|
|
import math
|
|
from os import path
|
|
|
|
import cv2
|
|
import imageio
|
|
import numpy as np
|
|
import torch
|
|
from pytorch3d.renderer.points.pulsar import Renderer
|
|
from torch import nn, optim
|
|
|
|
|
|
# Module-level logger; configured by the `__main__` guard when run as a script.
LOGGER: logging.Logger = logging.getLogger(__name__)

# Number of spheres in the optimized scene.
N_POINTS: int = 400_000
# Rendered image resolution in pixels (width x height).
WIDTH: int = 1_000
HEIGHT: int = 1_000
# View indices; currently not referenced elsewhere in this script.
VISUALIZE_IDS: list = [0, 1]
# Device on which the model, references and renders are placed.
DEVICE: torch.device = torch.device("cuda")
|
|
|
|
|
|
class SceneModel(nn.Module):
    """
    A simple scene model to demonstrate use of pulsar in PyTorch modules.

    The scene is parameterized with sphere locations (``vert_pos``), channel
    content (``vert_col``), radii (``vert_rad``) and opacities (``vert_opy``).
    All of them are trainable parameters with a leading batch dimension of 1.
    The cameras live in the non-trainable buffer ``cam_params``, with one
    8-element row per view. Each row holds the camera position (3 values),
    the camera rotation (3 values), the focal length and the sensor width.

    Note: ``vert_opy`` is registered but is not passed to the renderer in
    ``forward``, so it currently has no effect on the rendered images.

    The forward method renders this scene description. Any of these
    parameters could instead be passed as inputs to the forward method and
    come from a different model. Optionally, camera parameters can be
    provided to the forward method. In that case the scene is rendered with
    those parameters instead of ``cam_params``.
    """

    def __init__(self):
        super().__init__()
        # Forwarded unchanged to the renderer as its `gamma` argument.
        self.gamma = 1.0
        # Points. Seed the global RNG so that the initialization is reproducible.
        torch.manual_seed(1)
        # The positions are uniform in [-5, 5) x [-5, 5) x [25, 35).
        vert_pos = torch.rand((1, N_POINTS, 3), dtype=torch.float32) * 10.0
        vert_pos[:, :, 2] += 25.0
        vert_pos[:, :, :2] -= 5.0
        self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True))
        # Every sphere starts mid-gray.
        self.register_parameter(
            "vert_col",
            nn.Parameter(
                torch.ones(1, N_POINTS, 3, dtype=torch.float32) * 0.5,
                requires_grad=True,
            ),
        )
        self.register_parameter(
            "vert_rad",
            nn.Parameter(
                torch.ones(1, N_POINTS, dtype=torch.float32) * 0.05, requires_grad=True
            ),
        )
        self.register_parameter(
            "vert_opy",
            nn.Parameter(
                torch.ones(1, N_POINTS, dtype=torch.float32), requires_grad=True
            ),
        )
        # Eight cameras on a circle of radius 35 around (0, 0, 30) in the x-z
        # plane. (0, 0, 30) is the center of the initial point cloud. The second
        # rotation entry (pi - angle) presumably turns each camera about the
        # y axis to face that center. TODO confirm against the pulsar docs.
        self.register_buffer(
            "cam_params",
            torch.tensor(
                [
                    [
                        np.sin(angle) * 35.0,
                        0.0,
                        30.0 - np.cos(angle) * 35.0,
                        0.0,
                        -angle + math.pi,
                        0.0,
                        5.0,
                        2.0,
                    ]
                    for angle in [-1.5, -0.8, -0.4, -0.1, 0.1, 0.4, 0.8, 1.5]
                ],
                dtype=torch.float32,
            ),
        )
        self.renderer = Renderer(WIDTH, HEIGHT, N_POINTS, right_handed_system=True)

    def forward(self, cam=None):
        """
        Render the scene from one or more cameras.

        Args:
            cam: optional camera parameters. If omitted, every view stored in
                ``cam_params`` is rendered. A 2-D tensor with ``n`` rows
                renders ``n`` views. Any other shape is rendered as a single
                view, matching the previous behavior.

        Returns:
            The renderer output for the batch of views. Callers in this
            example index it as (n_views, H, W, channels). TODO confirm
            against the pulsar ``Renderer.forward`` docs.
        """
        if cam is None:
            cam = self.cam_params
        # Expand the shared scene once per camera row. The original code
        # hard-coded 8 (default cameras) or 1 (custom cameras), which broke for
        # any custom multi-row camera tensor.
        n_views = cam.shape[0] if cam.ndim == 2 else 1
        return self.renderer(
            self.vert_pos.expand(n_views, -1, -1),
            self.vert_col.expand(n_views, -1, -1),
            self.vert_rad.expand(n_views, -1),
            cam,
            self.gamma,
            # Presumably the renderer's `max_depth` argument. TODO confirm.
            45.0,
        )
|
|
|
|
|
|
def _camera_for_angle(angle):
    """
    Return the 8-element camera row for a camera orbiting at ``angle``.

    The camera sits on a circle of radius 35 around (0, 0, 30) in the x-z
    plane. The entries are position (3), rotation (3), focal length (5.0) and
    sensor width (2.0).
    """
    return [
        np.sin(angle) * 35.0,
        0.0,
        30.0 - np.cos(angle) * 35.0,
        0.0,
        -angle + math.pi,
        0.0,
        5.0,
        2.0,
    ]


def _load_reference(n_views=8):
    """
    Load the pre-generated reference views as floats in [0, 1] on DEVICE.

    The images are located relative to this script's directory rather than
    the current working directory, so the example can be launched from
    anywhere inside the repository checkout.
    """
    ref_dir = path.normpath(
        path.join(
            path.dirname(path.abspath(__file__)),
            "..",
            "..",
            "tests",
            "pulsar",
            "reference",
        )
    )
    return torch.stack(
        [
            torch.from_numpy(
                imageio.imread(
                    path.join(
                        ref_dir, "examples_TestRenderer_test_multiview_%d.png" % idx
                    )
                )
            ).to(torch.float32)
            / 255.0
            for idx in range(n_views)
        ]
    ).to(DEVICE)


def _show_progress(result, ref, step):
    """Show the first optimized view and its 50/50 overlay with the reference."""
    # Only view 0 is displayed, so slice and detach before copying to the CPU.
    # This avoids building autograd nodes for the overlay and moving all views.
    view = result[0].detach()
    opt_im = (view.cpu().numpy() * 255).astype(np.uint8)
    # Reverse the channel axis: OpenCV expects BGR ordering.
    cv2.imshow("opt", opt_im[:, :, ::-1])
    blend = view * 0.5 + ref[0] * 0.5
    # putText needs a contiguous buffer; the channel flip yields a strided view.
    overlay_img = np.ascontiguousarray(
        (blend.cpu().numpy() * 255).astype(np.uint8)[:, :, ::-1]
    )
    overlay_img = cv2.putText(
        overlay_img,
        "Step %d" % (step),
        (10, 40),
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        (0, 0, 0),
        2,
        cv2.LINE_AA,
        False,
    )
    cv2.imshow("overlay", overlay_img)
    cv2.waitKey(1)


def _prune_points(model):
    """Clamp colors and move degenerate spheres out of the scene, in place."""
    with torch.no_grad():
        model.vert_col.clamp_(0.0, 1.0)
        # Spheres whose radius collapsed are moved far away. Their radius is
        # reset to a small positive value, presumably because the renderer
        # rejects non-positive radii. TODO confirm.
        collapsed = model.vert_rad < 0.001
        model.vert_pos[collapsed] = -1000.0
        model.vert_rad[collapsed] = 0.0001
        # Spheres whose color is within L1 distance 0.2 of pure white are also
        # removed. Presumably they match the background. TODO confirm.
        white_dist = (model.vert_col - 1.0).abs().sum(dim=2)
        model.vert_pos[white_dist <= 0.2] = -1000.0


def cli():
    """
    Simple demonstration for a multi-view 3D reconstruction using pulsar.

    The model is optimized for 300 SGD steps against the eight reference
    views. Progress is shown in OpenCV windows. After every step, one frame
    from a slowly orbiting camera is appended to the output animation.

    This example makes use of opacity, which is not yet supported through
    the unified PyTorch3D interface.

    Writes to `multiview.gif`.
    """
    LOGGER.info("Loading reference...")
    # Load reference.
    ref = _load_reference()
    # Set up model.
    model = SceneModel().to(DEVICE)
    # Optimizer: colors move fast, while geometry (radii, positions) moves slowly.
    optimizer = optim.SGD(
        [
            {"params": [model.vert_col], "lr": 1e-1},
            {"params": [model.vert_rad], "lr": 1e-3},
            {"params": [model.vert_pos], "lr": 1e-3},
        ]
    )

    # For visualization.
    angle = 0.0
    out_path = "multiview.gif"
    # Log the file that is actually written (the old message named a .avi).
    LOGGER.info("Writing video to `%s`.", path.abspath(out_path))
    # The context manager closes the writer even if the loop raises.
    with imageio.get_writer(out_path, format="gif", fps=25) as writer:
        # Optimize.
        for i in range(300):
            optimizer.zero_grad()
            result = model()
            _show_progress(result, ref, i)
            # Update.
            loss = ((result - ref) ** 2).sum()
            LOGGER.info("loss %d: %f", i, loss.item())
            loss.backward()
            optimizer.step()
            _prune_points(model)
            # Rotating visualization.
            cam_control = torch.tensor(
                [_camera_for_angle(angle)], dtype=torch.float32, device=DEVICE
            )
            with torch.no_grad():
                frame = model(cam=cam_control)[0]
            frame_im = (frame.cpu().numpy() * 255).astype(np.uint8)
            cv2.imshow("vis", frame_im[:, :, ::-1])
            writer.append_data(frame_im)
            angle += 0.05
    cv2.destroyAllWindows()
    LOGGER.info("Done.")
|
|
|
|
|
|
if __name__ == "__main__":
    # When run as a script, surface the INFO-level progress messages
    # (per-step loss, output location) before the optimization starts.
    logging.basicConfig(level="INFO")
    cli()
|