Example and test updates.

Summary: This commit refines the pulsar examples and tests. The examples are fully adjusted to adhere to the PEP 8 style guide, and additional comments are added.

Reviewed By: nikhilaravi

Differential Revision: D24723391

fbshipit-source-id: 6d289006f080140159731e7f3a8c98b582164f1a
This commit is contained in:
Christoph Lassner 2020-11-04 09:53:19 -08:00 committed by Facebook GitHub Bot
parent e9a26f263a
commit b6be3b95fb
9 changed files with 569 additions and 448 deletions

View File

@ -7,14 +7,24 @@ Output: basic.png.
""" """
import math import math
from os import path from os import path
import logging
import imageio import imageio
import torch import torch
from pytorch3d.renderer.points.pulsar import Renderer from pytorch3d.renderer.points.pulsar import Renderer
torch.manual_seed(1) LOGGER = logging.getLogger(__name__)
def cli():
"""
Basic example for the pulsar sphere renderer.
Writes to `basic.png`.
"""
LOGGER.info("Rendering on GPU...")
torch.manual_seed(1)
n_points = 10 n_points = 10
width = 1_000 width = 1_000
height = 1_000 height = 1_000
@ -51,5 +61,11 @@ image = renderer(
1.0e-1, # Renderer blending parameter gamma, in [1., 1e-5]. 1.0e-1, # Renderer blending parameter gamma, in [1., 1e-5].
45.0, # Maximum depth. 45.0, # Maximum depth.
) )
print("Writing image to `%s`." % (path.abspath("basic.png"))) LOGGER.info("Writing image to `%s`.", path.abspath("basic.png"))
imageio.imsave("basic.png", (image.cpu().detach() * 255.0).to(torch.uint8).numpy()) imageio.imsave("basic.png", (image.cpu().detach() * 255.0).to(torch.uint8).numpy())
LOGGER.info("Done.")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
cli()

View File

@ -6,10 +6,14 @@ interface for sphere renderering. It renders and saves an image with
10 random spheres. 10 random spheres.
Output: basic-pt3d.png. Output: basic-pt3d.png.
""" """
import logging
from os import path from os import path
import imageio import imageio
import torch import torch
# Import `look_at_view_transform` as needed in the suggestion later in the
# example.
from pytorch3d.renderer import PerspectiveCameras # , look_at_view_transform from pytorch3d.renderer import PerspectiveCameras # , look_at_view_transform
from pytorch3d.renderer import ( from pytorch3d.renderer import (
PointsRasterizationSettings, PointsRasterizationSettings,
@ -19,13 +23,21 @@ from pytorch3d.renderer import (
from pytorch3d.structures import Pointclouds from pytorch3d.structures import Pointclouds
torch.manual_seed(1) LOGGER = logging.getLogger(__name__)
def cli():
"""
Basic example for the pulsar sphere renderer using the PyTorch3D interface.
Writes to `basic-pt3d.png`.
"""
LOGGER.info("Rendering on GPU...")
torch.manual_seed(1)
n_points = 10 n_points = 10
width = 1_000 width = 1_000
height = 1_000 height = 1_000
device = torch.device("cuda") device = torch.device("cuda")
# Generate sample data. # Generate sample data.
vert_pos = torch.rand(n_points, 3, dtype=torch.float32, device=device) * 10.0 vert_pos = torch.rand(n_points, 3, dtype=torch.float32, device=device) * 10.0
vert_pos[:, 2] += 25.0 vert_pos[:, 2] += 25.0
@ -63,5 +75,13 @@ image = renderer(
radius_world=True, radius_world=True,
bg_col=torch.ones((3,), dtype=torch.float32, device=device), bg_col=torch.ones((3,), dtype=torch.float32, device=device),
)[0] )[0]
print("Writing image to `%s`." % (path.abspath("basic-pt3d.png"))) LOGGER.info("Writing image to `%s`.", path.abspath("basic-pt3d.png"))
imageio.imsave("basic-pt3d.png", (image.cpu().detach() * 255.0).to(torch.uint8).numpy()) imageio.imsave(
"basic-pt3d.png", (image.cpu().detach() * 255.0).to(torch.uint8).numpy()
)
LOGGER.info("Done.")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
cli()

View File

@ -9,6 +9,7 @@ distorted. Gradient-based optimization is used to converge towards the
original camera parameters. original camera parameters.
Output: cam.gif. Output: cam.gif.
""" """
import logging
import math import math
from os import path from os import path
@ -21,10 +22,11 @@ from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_rotation_6d
from torch import nn, optim from torch import nn, optim
n_points = 20 LOGGER = logging.getLogger(__name__)
width = 1_000 N_POINTS = 20
height = 1_000 WIDTH = 1_000
device = torch.device("cuda") HEIGHT = 1_000
DEVICE = torch.device("cuda")
class SceneModel(nn.Module): class SceneModel(nn.Module):
@ -45,20 +47,20 @@ class SceneModel(nn.Module):
self.gamma = 0.1 self.gamma = 0.1
# Points. # Points.
torch.manual_seed(1) torch.manual_seed(1)
vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0 vert_pos = torch.rand(N_POINTS, 3, dtype=torch.float32) * 10.0
vert_pos[:, 2] += 25.0 vert_pos[:, 2] += 25.0
vert_pos[:, :2] -= 5.0 vert_pos[:, :2] -= 5.0
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=False)) self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=False))
self.register_parameter( self.register_parameter(
"vert_col", "vert_col",
nn.Parameter( nn.Parameter(
torch.rand(n_points, 3, dtype=torch.float32), requires_grad=False torch.rand(N_POINTS, 3, dtype=torch.float32), requires_grad=False
), ),
) )
self.register_parameter( self.register_parameter(
"vert_rad", "vert_rad",
nn.Parameter( nn.Parameter(
torch.rand(n_points, dtype=torch.float32), requires_grad=False torch.rand(N_POINTS, dtype=torch.float32), requires_grad=False
), ),
) )
self.register_parameter( self.register_parameter(
@ -90,7 +92,7 @@ class SceneModel(nn.Module):
torch.tensor([4.8, 1.8], dtype=torch.float32), requires_grad=True torch.tensor([4.8, 1.8], dtype=torch.float32), requires_grad=True
), ),
) )
self.renderer = Renderer(width, height, n_points, right_handed_system=True) self.renderer = Renderer(WIDTH, HEIGHT, N_POINTS, right_handed_system=True)
def forward(self): def forward(self):
return self.renderer.forward( return self.renderer.forward(
@ -103,6 +105,13 @@ class SceneModel(nn.Module):
) )
def cli():
"""
Camera optimization example using pulsar.
Writes to `cam.gif`.
"""
LOGGER.info("Loading reference...")
# Load reference. # Load reference.
ref = ( ref = (
torch.from_numpy( torch.from_numpy(
@ -111,9 +120,9 @@ ref = (
)[:, ::-1, :].copy() )[:, ::-1, :].copy()
).to(torch.float32) ).to(torch.float32)
/ 255.0 / 255.0
).to(device) ).to(DEVICE)
# Set up model. # Set up model.
model = SceneModel().to(device) model = SceneModel().to(DEVICE)
# Optimizer. # Optimizer.
optimizer = optim.SGD( optimizer = optim.SGD(
[ [
@ -123,7 +132,7 @@ optimizer = optim.SGD(
] ]
) )
print("Writing video to `%s`." % (path.abspath("cam.gif"))) LOGGER.info("Writing video to `%s`.", path.abspath("cam.gif"))
writer = imageio.get_writer("cam.gif", format="gif", fps=25) writer = imageio.get_writer("cam.gif", format="gif", fps=25)
# Optimize. # Optimize.
@ -154,7 +163,13 @@ for i in range(300):
cv2.waitKey(1) cv2.waitKey(1)
# Update. # Update.
loss = ((result - ref) ** 2).sum() loss = ((result - ref) ** 2).sum()
print("loss {}: {}".format(i, loss.item())) LOGGER.info("loss %d: %f", i, loss.item())
loss.backward() loss.backward()
optimizer.step() optimizer.step()
writer.close() writer.close()
LOGGER.info("Done.")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
cli()

View File

@ -10,11 +10,15 @@ original camera parameters.
Output: cam-pt3d.gif Output: cam-pt3d.gif
""" """
from os import path from os import path
import logging
import cv2 import cv2
import imageio import imageio
import numpy as np import numpy as np
import torch import torch
# Import `look_at_view_transform` as needed in the suggestion later in the
# example.
from pytorch3d.renderer.cameras import PerspectiveCameras # , look_at_view_transform from pytorch3d.renderer.cameras import PerspectiveCameras # , look_at_view_transform
from pytorch3d.renderer.points import ( from pytorch3d.renderer.points import (
PointsRasterizationSettings, PointsRasterizationSettings,
@ -26,10 +30,11 @@ from pytorch3d.transforms import axis_angle_to_matrix
from torch import nn, optim from torch import nn, optim
n_points = 20 LOGGER = logging.getLogger(__name__)
width = 1_000 N_POINTS = 20
height = 1_000 WIDTH = 1_000
device = torch.device("cuda") HEIGHT = 1_000
DEVICE = torch.device("cuda")
class SceneModel(nn.Module): class SceneModel(nn.Module):
@ -50,21 +55,21 @@ class SceneModel(nn.Module):
self.gamma = 0.1 self.gamma = 0.1
# Points. # Points.
torch.manual_seed(1) torch.manual_seed(1)
vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0 vert_pos = torch.rand(N_POINTS, 3, dtype=torch.float32) * 10.0
vert_pos[:, 2] += 25.0 vert_pos[:, 2] += 25.0
vert_pos[:, :2] -= 5.0 vert_pos[:, :2] -= 5.0
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=False)) self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=False))
self.register_parameter( self.register_parameter(
"vert_col", "vert_col",
nn.Parameter( nn.Parameter(
torch.rand(n_points, 3, dtype=torch.float32), torch.rand(N_POINTS, 3, dtype=torch.float32),
requires_grad=False, requires_grad=False,
), ),
) )
self.register_parameter( self.register_parameter(
"vert_rad", "vert_rad",
nn.Parameter( nn.Parameter(
torch.rand(n_points, dtype=torch.float32), torch.rand(N_POINTS, dtype=torch.float32),
requires_grad=False, requires_grad=False,
), ),
) )
@ -118,11 +123,11 @@ class SceneModel(nn.Module):
focal_length=self.focal_length, focal_length=self.focal_length,
R=self.cam_rot[None, ...], R=self.cam_rot[None, ...],
T=self.cam_pos[None, ...], T=self.cam_pos[None, ...],
image_size=((width, height),), image_size=((WIDTH, HEIGHT),),
device=device, device=DEVICE,
) )
raster_settings = PointsRasterizationSettings( raster_settings = PointsRasterizationSettings(
image_size=(width, height), image_size=(WIDTH, HEIGHT),
radius=self.vert_rad, radius=self.vert_rad,
) )
rasterizer = PointsRasterizer( rasterizer = PointsRasterizer(
@ -142,7 +147,7 @@ class SceneModel(nn.Module):
zfar=(45.0,), zfar=(45.0,),
znear=(1.0,), znear=(1.0,),
radius_world=True, radius_world=True,
bg_col=torch.ones((3,), dtype=torch.float32, device=device), bg_col=torch.ones((3,), dtype=torch.float32, device=DEVICE),
# As mentioned above: workaround for device placement of gradients for # As mentioned above: workaround for device placement of gradients for
# camera parameters. # camera parameters.
focal_length=self.focal_length, focal_length=self.focal_length,
@ -151,6 +156,13 @@ class SceneModel(nn.Module):
)[0] )[0]
def cli():
"""
Camera optimization example using pulsar.
Writes to `cam.gif`.
"""
LOGGER.info("Loading reference...")
# Load reference. # Load reference.
ref = ( ref = (
torch.from_numpy( torch.from_numpy(
@ -159,9 +171,9 @@ ref = (
)[:, ::-1, :].copy() )[:, ::-1, :].copy()
).to(torch.float32) ).to(torch.float32)
/ 255.0 / 255.0
).to(device) ).to(DEVICE)
# Set up model. # Set up model.
model = SceneModel().to(device) model = SceneModel().to(DEVICE)
# Optimizer. # Optimizer.
optimizer = optim.SGD( optimizer = optim.SGD(
[ [
@ -173,7 +185,7 @@ optimizer = optim.SGD(
] ]
) )
print("Writing video to `%s`." % (path.abspath("cam-pt3d.gif"))) LOGGER.info("Writing video to `%s`.", path.abspath("cam-pt3d.gif"))
writer = imageio.get_writer("cam-pt3d.gif", format="gif", fps=25) writer = imageio.get_writer("cam-pt3d.gif", format="gif", fps=25)
# Optimize. # Optimize.
@ -204,7 +216,13 @@ for i in range(300):
cv2.waitKey(1) cv2.waitKey(1)
# Update. # Update.
loss = ((result - ref) ** 2).sum() loss = ((result - ref) ** 2).sum()
print("loss {}: {}".format(i, loss.item())) LOGGER.info("loss %d: %f", i, loss.item())
loss.backward() loss.backward()
optimizer.step() optimizer.step()
writer.close() writer.close()
LOGGER.info("Done.")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
cli()

View File

@ -3,7 +3,8 @@
""" """
This example demonstrates multiview 3D reconstruction using the plain This example demonstrates multiview 3D reconstruction using the plain
pulsar interface. For this, reference images have been pre-generated pulsar interface. For this, reference images have been pre-generated
(you can find them at `../../tests/pulsar/reference/examples_TestRenderer_test_multiview_%d.png`). (you can find them at
`../../tests/pulsar/reference/examples_TestRenderer_test_multiview_%d.png`).
The camera parameters are assumed given. The scene is initialized with The camera parameters are assumed given. The scene is initialized with
random spheres. Gradient-based optimization is used to optimize sphere random spheres. Gradient-based optimization is used to optimize sphere
parameters and prune spheres to converge to a 3D representation. parameters and prune spheres to converge to a 3D representation.
@ -14,6 +15,7 @@ structures yet.
""" """
import math import math
from os import path from os import path
import logging
import cv2 import cv2
import imageio import imageio
@ -23,11 +25,12 @@ from pytorch3d.renderer.points.pulsar import Renderer
from torch import nn, optim from torch import nn, optim
n_points = 400_000 LOGGER = logging.getLogger(__name__)
width = 1_000 N_POINTS = 400_000
height = 1_000 WIDTH = 1_000
visualize_ids = [0, 1] HEIGHT = 1_000
device = torch.device("cuda") VISUALIZE_IDS = [0, 1]
DEVICE = torch.device("cuda")
class SceneModel(nn.Module): class SceneModel(nn.Module):
@ -50,27 +53,27 @@ class SceneModel(nn.Module):
self.gamma = 1.0 self.gamma = 1.0
# Points. # Points.
torch.manual_seed(1) torch.manual_seed(1)
vert_pos = torch.rand((1, n_points, 3), dtype=torch.float32) * 10.0 vert_pos = torch.rand((1, N_POINTS, 3), dtype=torch.float32) * 10.0
vert_pos[:, :, 2] += 25.0 vert_pos[:, :, 2] += 25.0
vert_pos[:, :, :2] -= 5.0 vert_pos[:, :, :2] -= 5.0
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True)) self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True))
self.register_parameter( self.register_parameter(
"vert_col", "vert_col",
nn.Parameter( nn.Parameter(
torch.ones(1, n_points, 3, dtype=torch.float32) * 0.5, torch.ones(1, N_POINTS, 3, dtype=torch.float32) * 0.5,
requires_grad=True, requires_grad=True,
), ),
) )
self.register_parameter( self.register_parameter(
"vert_rad", "vert_rad",
nn.Parameter( nn.Parameter(
torch.ones(1, n_points, dtype=torch.float32) * 0.05, requires_grad=True torch.ones(1, N_POINTS, dtype=torch.float32) * 0.05, requires_grad=True
), ),
) )
self.register_parameter( self.register_parameter(
"vert_opy", "vert_opy",
nn.Parameter( nn.Parameter(
torch.ones(1, n_points, dtype=torch.float32), requires_grad=True torch.ones(1, N_POINTS, dtype=torch.float32), requires_grad=True
), ),
) )
self.register_buffer( self.register_buffer(
@ -92,7 +95,7 @@ class SceneModel(nn.Module):
dtype=torch.float32, dtype=torch.float32,
), ),
) )
self.renderer = Renderer(width, height, n_points, right_handed_system=True) self.renderer = Renderer(WIDTH, HEIGHT, N_POINTS, right_handed_system=True)
def forward(self, cam=None): def forward(self, cam=None):
if cam is None: if cam is None:
@ -110,6 +113,16 @@ class SceneModel(nn.Module):
) )
def cli():
"""
Simple demonstration for a multi-view 3D reconstruction using pulsar.
This example makes use of opacity, which is not yet supported through
the unified PyTorch3D interface.
Writes to `multiview.gif`.
"""
LOGGER.info("Loading reference...")
# Load reference. # Load reference.
ref = torch.stack( ref = torch.stack(
[ [
@ -122,9 +135,9 @@ ref = torch.stack(
/ 255.0 / 255.0
for idx in range(8) for idx in range(8)
] ]
).to(device) ).to(DEVICE)
# Set up model. # Set up model.
model = SceneModel().to(device) model = SceneModel().to(DEVICE)
# Optimizer. # Optimizer.
optimizer = optim.SGD( optimizer = optim.SGD(
[ [
@ -136,7 +149,7 @@ optimizer = optim.SGD(
# For visualization. # For visualization.
angle = 0.0 angle = 0.0
print("Writing video to `%s`." % (path.abspath("multiview.avi"))) LOGGER.info("Writing video to `%s`.", path.abspath("multiview.avi"))
writer = imageio.get_writer("multiview.gif", format="gif", fps=25) writer = imageio.get_writer("multiview.gif", format="gif", fps=25)
# Optimize. # Optimize.
@ -166,7 +179,7 @@ for i in range(300):
cv2.waitKey(1) cv2.waitKey(1)
# Update. # Update.
loss = ((result - ref) ** 2).sum() loss = ((result - ref) ** 2).sum()
print("loss {}: {}".format(i, loss.item())) LOGGER.info("loss %d: %f", i, loss.item())
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# Cleanup. # Cleanup.
@ -176,7 +189,7 @@ for i in range(300):
model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0 model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0
model.vert_rad.data[model.vert_rad < 0.001] = 0.0001 model.vert_rad.data[model.vert_rad < 0.001] = 0.0001
vd = ( vd = (
(model.vert_col - torch.ones(1, 1, 3, dtype=torch.float32).to(device)) (model.vert_col - torch.ones(1, 1, 3, dtype=torch.float32).to(DEVICE))
.abs() .abs()
.sum(dim=2) .sum(dim=2)
) )
@ -196,7 +209,7 @@ for i in range(300):
] ]
], ],
dtype=torch.float32, dtype=torch.float32,
).to(device) ).to(DEVICE)
with torch.no_grad(): with torch.no_grad():
result = model.forward(cam=cam_control)[0] result = model.forward(cam=cam_control)[0]
result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8) result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
@ -204,3 +217,9 @@ for i in range(300):
writer.append_data(result_im) writer.append_data(result_im)
angle += 0.05 angle += 0.05
writer.close() writer.close()
LOGGER.info("Done.")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
cli()

View File

@ -9,6 +9,7 @@ optimization is used to converge towards a faithful
scene representation. scene representation.
""" """
import math import math
import logging
import cv2 import cv2
import imageio import imageio
@ -18,10 +19,11 @@ from pytorch3d.renderer.points.pulsar import Renderer
from torch import nn, optim from torch import nn, optim
n_points = 10_000 LOGGER = logging.getLogger(__name__)
width = 1_000 N_POINTS = 10_000
height = 1_000 WIDTH = 1_000
device = torch.device("cuda") HEIGHT = 1_000
DEVICE = torch.device("cuda")
class SceneModel(nn.Module): class SceneModel(nn.Module):
@ -42,20 +44,20 @@ class SceneModel(nn.Module):
self.gamma = 1.0 self.gamma = 1.0
# Points. # Points.
torch.manual_seed(1) torch.manual_seed(1)
vert_pos = torch.rand(n_points, 3, dtype=torch.float32) * 10.0 vert_pos = torch.rand(N_POINTS, 3, dtype=torch.float32) * 10.0
vert_pos[:, 2] += 25.0 vert_pos[:, 2] += 25.0
vert_pos[:, :2] -= 5.0 vert_pos[:, :2] -= 5.0
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True)) self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True))
self.register_parameter( self.register_parameter(
"vert_col", "vert_col",
nn.Parameter( nn.Parameter(
torch.ones(n_points, 3, dtype=torch.float32) * 0.5, requires_grad=True torch.ones(N_POINTS, 3, dtype=torch.float32) * 0.5, requires_grad=True
), ),
) )
self.register_parameter( self.register_parameter(
"vert_rad", "vert_rad",
nn.Parameter( nn.Parameter(
torch.ones(n_points, dtype=torch.float32) * 0.3, requires_grad=True torch.ones(N_POINTS, dtype=torch.float32) * 0.3, requires_grad=True
), ),
) )
self.register_buffer( self.register_buffer(
@ -67,7 +69,7 @@ class SceneModel(nn.Module):
# The volumetric optimization works better with a higher number of tracked # The volumetric optimization works better with a higher number of tracked
# intersections per ray. # intersections per ray.
self.renderer = Renderer( self.renderer = Renderer(
width, height, n_points, n_track=32, right_handed_system=True WIDTH, HEIGHT, N_POINTS, n_track=32, right_handed_system=True
) )
def forward(self): def forward(self):
@ -82,6 +84,11 @@ class SceneModel(nn.Module):
) )
def cli():
"""
Scene optimization example using pulsar.
"""
LOGGER.info("Loading reference...")
# Load reference. # Load reference.
ref = ( ref = (
torch.from_numpy( torch.from_numpy(
@ -90,9 +97,9 @@ ref = (
)[:, ::-1, :].copy() )[:, ::-1, :].copy()
).to(torch.float32) ).to(torch.float32)
/ 255.0 / 255.0
).to(device) ).to(DEVICE)
# Set up model. # Set up model.
model = SceneModel().to(device) model = SceneModel().to(DEVICE)
# Optimizer. # Optimizer.
optimizer = optim.SGD( optimizer = optim.SGD(
[ [
@ -101,7 +108,7 @@ optimizer = optim.SGD(
{"params": [model.vert_pos], "lr": 1e-2}, {"params": [model.vert_pos], "lr": 1e-2},
] ]
) )
LOGGER.info("Optimizing...")
# Optimize. # Optimize.
for i in range(500): for i in range(500):
optimizer.zero_grad() optimizer.zero_grad()
@ -129,7 +136,7 @@ for i in range(500):
cv2.waitKey(1) cv2.waitKey(1)
# Update. # Update.
loss = ((result - ref) ** 2).sum() loss = ((result - ref) ** 2).sum()
print("loss {}: {}".format(i, loss.item())) LOGGER.info("loss %d: %f", i, loss.item())
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# Cleanup. # Cleanup.
@ -139,8 +146,14 @@ for i in range(500):
model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0 model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0
model.vert_rad.data[model.vert_rad < 0.001] = 0.0001 model.vert_rad.data[model.vert_rad < 0.001] = 0.0001
vd = ( vd = (
(model.vert_col - torch.ones(3, dtype=torch.float32).to(device)) (model.vert_col - torch.ones(3, dtype=torch.float32).to(DEVICE))
.abs() .abs()
.sum(dim=1) .sum(dim=1)
) )
model.vert_pos.data[vd <= 0.2] = -1000.0 model.vert_pos.data[vd <= 0.2] = -1000.0
LOGGER.info("Done.")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
cli()

View File

@ -9,11 +9,15 @@ optimization is used to converge towards a faithful
scene representation. scene representation.
""" """
import math import math
import logging
import cv2 import cv2
import imageio import imageio
import numpy as np import numpy as np
import torch import torch
# Import `look_at_view_transform` as needed in the suggestion later in the
# example.
from pytorch3d.renderer.cameras import PerspectiveCameras # , look_at_view_transform from pytorch3d.renderer.cameras import PerspectiveCameras # , look_at_view_transform
from pytorch3d.renderer.points import ( from pytorch3d.renderer.points import (
PointsRasterizationSettings, PointsRasterizationSettings,
@ -24,10 +28,11 @@ from pytorch3d.structures.pointclouds import Pointclouds
from torch import nn, optim from torch import nn, optim
n_points = 10_000 LOGGER = logging.getLogger(__name__)
width = 1_000 N_POINTS = 10_000
height = 1_000 WIDTH = 1_000
device = torch.device("cuda") HEIGHT = 1_000
DEVICE = torch.device("cuda")
class SceneModel(nn.Module): class SceneModel(nn.Module):
@ -48,21 +53,21 @@ class SceneModel(nn.Module):
self.gamma = 1.0 self.gamma = 1.0
# Points. # Points.
torch.manual_seed(1) torch.manual_seed(1)
vert_pos = torch.rand(n_points, 3, dtype=torch.float32, device=device) * 10.0 vert_pos = torch.rand(N_POINTS, 3, dtype=torch.float32, device=DEVICE) * 10.0
vert_pos[:, 2] += 25.0 vert_pos[:, 2] += 25.0
vert_pos[:, :2] -= 5.0 vert_pos[:, :2] -= 5.0
self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True)) self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True))
self.register_parameter( self.register_parameter(
"vert_col", "vert_col",
nn.Parameter( nn.Parameter(
torch.ones(n_points, 3, dtype=torch.float32, device=device) * 0.5, torch.ones(N_POINTS, 3, dtype=torch.float32, device=DEVICE) * 0.5,
requires_grad=True, requires_grad=True,
), ),
) )
self.register_parameter( self.register_parameter(
"vert_rad", "vert_rad",
nn.Parameter( nn.Parameter(
torch.ones(n_points, dtype=torch.float32) * 0.3, requires_grad=True torch.ones(N_POINTS, dtype=torch.float32) * 0.3, requires_grad=True
), ),
) )
self.register_buffer( self.register_buffer(
@ -77,13 +82,13 @@ class SceneModel(nn.Module):
# sensor width (see the pulsar example). This means we need here # sensor width (see the pulsar example). This means we need here
# 5.0 * 2.0 / 2.0 to get the equivalent results as in pulsar. # 5.0 * 2.0 / 2.0 to get the equivalent results as in pulsar.
focal_length=5.0, focal_length=5.0,
R=torch.eye(3, dtype=torch.float32, device=device)[None, ...], R=torch.eye(3, dtype=torch.float32, device=DEVICE)[None, ...],
T=torch.zeros((1, 3), dtype=torch.float32, device=device), T=torch.zeros((1, 3), dtype=torch.float32, device=DEVICE),
image_size=((width, height),), image_size=((WIDTH, HEIGHT),),
device=device, device=DEVICE,
) )
raster_settings = PointsRasterizationSettings( raster_settings = PointsRasterizationSettings(
image_size=(width, height), image_size=(WIDTH, HEIGHT),
radius=self.vert_rad, radius=self.vert_rad,
) )
rasterizer = PointsRasterizer( rasterizer = PointsRasterizer(
@ -103,10 +108,15 @@ class SceneModel(nn.Module):
zfar=(45.0,), zfar=(45.0,),
znear=(1.0,), znear=(1.0,),
radius_world=True, radius_world=True,
bg_col=torch.ones((3,), dtype=torch.float32, device=device), bg_col=torch.ones((3,), dtype=torch.float32, device=DEVICE),
)[0] )[0]
def cli():
"""
Scene optimization example using pulsar and the unified PyTorch3D interface.
"""
LOGGER.info("Loading reference...")
# Load reference. # Load reference.
ref = ( ref = (
torch.from_numpy( torch.from_numpy(
@ -115,9 +125,9 @@ ref = (
)[:, ::-1, :].copy() )[:, ::-1, :].copy()
).to(torch.float32) ).to(torch.float32)
/ 255.0 / 255.0
).to(device) ).to(DEVICE)
# Set up model. # Set up model.
model = SceneModel().to(device) model = SceneModel().to(DEVICE)
# Optimizer. # Optimizer.
optimizer = optim.SGD( optimizer = optim.SGD(
[ [
@ -126,7 +136,7 @@ optimizer = optim.SGD(
{"params": [model.vert_pos], "lr": 1e-2}, {"params": [model.vert_pos], "lr": 1e-2},
] ]
) )
LOGGER.info("Optimizing...")
# Optimize. # Optimize.
for i in range(500): for i in range(500):
optimizer.zero_grad() optimizer.zero_grad()
@ -154,7 +164,7 @@ for i in range(500):
cv2.waitKey(1) cv2.waitKey(1)
# Update. # Update.
loss = ((result - ref) ** 2).sum() loss = ((result - ref) ** 2).sum()
print("loss {}: {}".format(i, loss.item())) LOGGER.info("loss %d: %f", i, loss.item())
loss.backward() loss.backward()
optimizer.step() optimizer.step()
# Cleanup. # Cleanup.
@ -164,8 +174,14 @@ for i in range(500):
model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0 model.vert_pos.data[model.vert_rad < 0.001, :] = -1000.0
model.vert_rad.data[model.vert_rad < 0.001] = 0.0001 model.vert_rad.data[model.vert_rad < 0.001] = 0.0001
vd = ( vd = (
(model.vert_col - torch.ones(3, dtype=torch.float32).to(device)) (model.vert_col - torch.ones(3, dtype=torch.float32).to(DEVICE))
.abs() .abs()
.sum(dim=1) .sum(dim=1)
) )
model.vert_pos.data[vd <= 0.2] = -1000.0 model.vert_pos.data[vd <= 0.2] = -1000.0
LOGGER.info("Done.")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
cli()

View File

@ -44,6 +44,8 @@ class TestDepth(TestCaseMixin, unittest.TestCase):
n_channels=1, n_channels=1,
).to(device) ).to(device)
data = torch.load(IN_REF_FP, map_location="cpu") data = torch.load(IN_REF_FP, map_location="cpu")
# For creating the reference files.
# Use in case of updates.
# data["pos"] = torch.rand_like(data["pos"]) # data["pos"] = torch.rand_like(data["pos"])
# data["pos"][:, 0] = data["pos"][:, 0] * 2. - 1. # data["pos"][:, 0] = data["pos"][:, 0] * 2. - 1.
# data["pos"][:, 1] = data["pos"][:, 1] * 2. - 1. # data["pos"][:, 1] = data["pos"][:, 1] * 2. - 1.
@ -74,6 +76,8 @@ class TestDepth(TestCaseMixin, unittest.TestCase):
), ),
depth_vis.cpu().numpy().astype(np.uint8), depth_vis.cpu().numpy().astype(np.uint8),
) )
# For creating the reference files.
# Use in case of updates.
# torch.save( # torch.save(
# data, path.join(path.dirname(__file__), "reference", "nr0000-in.pth") # data, path.join(path.dirname(__file__), "reference", "nr0000-in.pth")
# ) # )

View File

@ -123,7 +123,7 @@ class TestSmallSpheres(unittest.TestCase):
self.assertTrue( self.assertTrue(
(sphere_ids == idx).sum() > 0, "Sphere ID %d missing!" % (idx) (sphere_ids == idx).sum() > 0, "Sphere ID %d missing!" % (idx)
) )
# Visualize. # Visualization code. Activate for debugging.
# result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8) # result_im = (result.cpu().detach().numpy() * 255).astype(np.uint8)
# cv2.imshow("res", result_im[0, :, :, ::-1]) # cv2.imshow("res", result_im[0, :, :, ::-1])
# cv2.waitKey(0) # cv2.waitKey(0)