Initial commit

fbshipit-source-id: ad58e416e3ceeca85fae0583308968d04e78fe0d
This commit is contained in:
facebook-github-bot
2020-01-23 11:53:41 -08:00
commit dbf06b504b
211 changed files with 47362 additions and 0 deletions

View File

@@ -0,0 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .camera_visualisation import (
get_camera_wireframe,
plot_camera_scene,
plot_cameras,
)
from .plot_image_grid import image_grid

View File

@@ -0,0 +1,71 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import matplotlib.pyplot as plt
import torch
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import
def get_camera_wireframe(scale: float = 0.3):
"""
Returns a wireframe of a 3D line-plot of a camera symbol.
"""
a = 0.5 * torch.tensor([-2, 1.5, 4])
b = 0.5 * torch.tensor([2, 1.5, 4])
c = 0.5 * torch.tensor([-2, -1.5, 4])
d = 0.5 * torch.tensor([2, -1.5, 4])
C = torch.zeros(3)
F = torch.tensor([0, 0, 3])
camera_points = [a, b, d, c, a, C, b, d, C, c, C, F]
lines = torch.stack([x.float() for x in camera_points]) * scale
return lines
def plot_cameras(ax, cameras, color: str = "blue"):
"""
Plots a set of `cameras` objects into the maplotlib axis `ax` with
color `color`.
"""
cam_wires_canonical = get_camera_wireframe().cuda()[None]
cam_trans = cameras.get_world_to_view_transform().inverse()
cam_wires_trans = cam_trans.transform_points(cam_wires_canonical)
plot_handles = []
for wire in cam_wires_trans:
# the Z and Y axes are flipped intentionally here!
x_, z_, y_ = wire.detach().cpu().numpy().T.astype(float)
(h,) = ax.plot(x_, y_, z_, color=color, linewidth=0.3)
plot_handles.append(h)
return plot_handles
def plot_camera_scene(cameras, cameras_gt, status: str):
"""
Plots a set of predicted cameras `cameras` and their corresponding
ground truth locations `cameras_gt`. The plot is named with
a string passed inside the `status` argument.
"""
fig = plt.figure()
ax = fig.gca(projection="3d")
ax.clear()
ax.set_title(status)
handle_cam = plot_cameras(ax, cameras, color="#FF7D1E")
handle_cam_gt = plot_cameras(ax, cameras_gt, color="#812CE5")
plot_radius = 3
ax.set_xlim3d([-plot_radius, plot_radius])
ax.set_ylim3d([3 - plot_radius, 3 + plot_radius])
ax.set_zlim3d([-plot_radius, plot_radius])
ax.set_xlabel("x")
ax.set_ylabel("z")
ax.set_zlabel("y")
labels_handles = {
"Estimated cameras": handle_cam[0],
"GT cameras": handle_cam_gt[0],
}
ax.legend(
labels_handles.values(),
labels_handles.keys(),
loc="upper center",
bbox_to_anchor=(0.5, 0),
)
plt.show()
return fig

View File

@@ -0,0 +1,54 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import matplotlib.pyplot as plt
def image_grid(
images,
rows=None,
cols=None,
fill: bool = True,
show_axes: bool = False,
rgb: bool = True,
):
"""
A util function for plotting a grid of images.
Args:
images: (N, H, W, 4) array of RGBA images
rows: number of rows in the grid
cols: number of columns in the grid
fill: boolean indicating if the space between images should be filled
show_axes: boolean indicating if the axes of the plots should be visible
rgb: boolean, If True, only RGB channels are plotted.
If False, only the alpha channel is plotted.
Returns:
None
"""
if (rows is None) != (cols is None):
raise ValueError("Specify either both rows and cols or neither.")
if rows is None:
rows = len(images)
cols = 1
gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
fig, axarr = plt.subplots(
rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9)
)
bleed = 0
fig.subplots_adjust(
left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed)
)
for ax, im in zip(axarr.ravel(), images):
if rgb:
# only render RGB channels
ax.imshow(im[..., :3])
else:
# only render Alpha channel
ax.imshow(im[..., 3])
if not show_axes:
ax.set_axis_off()