mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-14 19:36:23 +08:00
Initial commit
fbshipit-source-id: ad58e416e3ceeca85fae0583308968d04e78fe0d
This commit is contained in:
8
docs/tutorials/utils/__init__.py
Normal file
8
docs/tutorials/utils/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .camera_visualisation import (
|
||||
get_camera_wireframe,
|
||||
plot_camera_scene,
|
||||
plot_cameras,
|
||||
)
|
||||
from .plot_image_grid import image_grid
|
||||
71
docs/tutorials/utils/camera_visualisation.py
Normal file
71
docs/tutorials/utils/camera_visualisation.py
Normal file
@@ -0,0 +1,71 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import
|
||||
|
||||
|
||||
def get_camera_wireframe(scale: float = 0.3):
|
||||
"""
|
||||
Returns a wireframe of a 3D line-plot of a camera symbol.
|
||||
"""
|
||||
a = 0.5 * torch.tensor([-2, 1.5, 4])
|
||||
b = 0.5 * torch.tensor([2, 1.5, 4])
|
||||
c = 0.5 * torch.tensor([-2, -1.5, 4])
|
||||
d = 0.5 * torch.tensor([2, -1.5, 4])
|
||||
C = torch.zeros(3)
|
||||
F = torch.tensor([0, 0, 3])
|
||||
camera_points = [a, b, d, c, a, C, b, d, C, c, C, F]
|
||||
lines = torch.stack([x.float() for x in camera_points]) * scale
|
||||
return lines
|
||||
|
||||
|
||||
def plot_cameras(ax, cameras, color: str = "blue"):
|
||||
"""
|
||||
Plots a set of `cameras` objects into the maplotlib axis `ax` with
|
||||
color `color`.
|
||||
"""
|
||||
cam_wires_canonical = get_camera_wireframe().cuda()[None]
|
||||
cam_trans = cameras.get_world_to_view_transform().inverse()
|
||||
cam_wires_trans = cam_trans.transform_points(cam_wires_canonical)
|
||||
plot_handles = []
|
||||
for wire in cam_wires_trans:
|
||||
# the Z and Y axes are flipped intentionally here!
|
||||
x_, z_, y_ = wire.detach().cpu().numpy().T.astype(float)
|
||||
(h,) = ax.plot(x_, y_, z_, color=color, linewidth=0.3)
|
||||
plot_handles.append(h)
|
||||
return plot_handles
|
||||
|
||||
|
||||
def plot_camera_scene(cameras, cameras_gt, status: str):
|
||||
"""
|
||||
Plots a set of predicted cameras `cameras` and their corresponding
|
||||
ground truth locations `cameras_gt`. The plot is named with
|
||||
a string passed inside the `status` argument.
|
||||
"""
|
||||
fig = plt.figure()
|
||||
ax = fig.gca(projection="3d")
|
||||
ax.clear()
|
||||
ax.set_title(status)
|
||||
handle_cam = plot_cameras(ax, cameras, color="#FF7D1E")
|
||||
handle_cam_gt = plot_cameras(ax, cameras_gt, color="#812CE5")
|
||||
plot_radius = 3
|
||||
ax.set_xlim3d([-plot_radius, plot_radius])
|
||||
ax.set_ylim3d([3 - plot_radius, 3 + plot_radius])
|
||||
ax.set_zlim3d([-plot_radius, plot_radius])
|
||||
ax.set_xlabel("x")
|
||||
ax.set_ylabel("z")
|
||||
ax.set_zlabel("y")
|
||||
labels_handles = {
|
||||
"Estimated cameras": handle_cam[0],
|
||||
"GT cameras": handle_cam_gt[0],
|
||||
}
|
||||
ax.legend(
|
||||
labels_handles.values(),
|
||||
labels_handles.keys(),
|
||||
loc="upper center",
|
||||
bbox_to_anchor=(0.5, 0),
|
||||
)
|
||||
plt.show()
|
||||
return fig
|
||||
54
docs/tutorials/utils/plot_image_grid.py
Normal file
54
docs/tutorials/utils/plot_image_grid.py
Normal file
@@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def image_grid(
|
||||
images,
|
||||
rows=None,
|
||||
cols=None,
|
||||
fill: bool = True,
|
||||
show_axes: bool = False,
|
||||
rgb: bool = True,
|
||||
):
|
||||
"""
|
||||
A util function for plotting a grid of images.
|
||||
|
||||
Args:
|
||||
images: (N, H, W, 4) array of RGBA images
|
||||
rows: number of rows in the grid
|
||||
cols: number of columns in the grid
|
||||
fill: boolean indicating if the space between images should be filled
|
||||
show_axes: boolean indicating if the axes of the plots should be visible
|
||||
rgb: boolean, If True, only RGB channels are plotted.
|
||||
If False, only the alpha channel is plotted.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if (rows is None) != (cols is None):
|
||||
raise ValueError("Specify either both rows and cols or neither.")
|
||||
|
||||
if rows is None:
|
||||
rows = len(images)
|
||||
cols = 1
|
||||
|
||||
gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
|
||||
fig, axarr = plt.subplots(
|
||||
rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9)
|
||||
)
|
||||
bleed = 0
|
||||
fig.subplots_adjust(
|
||||
left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed)
|
||||
)
|
||||
|
||||
for ax, im in zip(axarr.ravel(), images):
|
||||
if rgb:
|
||||
# only render RGB channels
|
||||
ax.imshow(im[..., :3])
|
||||
else:
|
||||
# only render Alpha channel
|
||||
ax.imshow(im[..., 3])
|
||||
if not show_axes:
|
||||
ax.set_axis_off()
|
||||
Reference in New Issue
Block a user