pytorch3d/pytorch3d/implicitron/tools/eval_video_trajectory.py
Jeremy Reizenstein cdd2142dd5
implicitron v0 (#1133)
Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
2022-03-21 13:20:10 -07:00

227 lines
7.5 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional, Tuple
import torch
from pytorch3d.common.compat import eigh
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform
from pytorch3d.transforms import Scale
def generate_eval_video_cameras(
train_cameras,
n_eval_cams: int = 100,
trajectory_type: str = "figure_eight",
trajectory_scale: float = 0.2,
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
focal_length: Optional[torch.FloatTensor] = None,
principal_point: Optional[torch.FloatTensor] = None,
time: Optional[torch.FloatTensor] = None,
infer_up_as_plane_normal: bool = True,
traj_offset: Optional[Tuple[float, float, float]] = None,
traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
) -> PerspectiveCameras:
"""
Generate a camera trajectory rendering a scene from multiple viewpoints.
Args:
train_dataset: The training dataset object.
n_eval_cams: Number of cameras in the trajectory.
trajectory_type: The type of the camera trajectory. Can be one of:
circular_lsq_fit: Camera centers follow a trajectory obtained
by fitting a 3D circle to train_cameras centers.
All cameras are looking towards scene_center.
figure_eight: Figure-of-8 trajectory around the center of the
central camera of the training dataset.
trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
of a figure-eight knot
(https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
trajectory_scale: The extent of the trajectory.
up: The "up" vector of the scene (=the normal of the scene floor).
Active for the `trajectory_type="circular"`.
scene_center: The center of the scene in world coordinates which all
the cameras from the generated trajectory look at.
Returns:
Dictionary of camera instances which can be used as the test dataset
"""
if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
cam_centers = train_cameras.get_camera_center()
# get the nearest camera center to the mean of centers
mean_camera_idx = (
((cam_centers - cam_centers.mean(dim=0)[None]) ** 2)
.sum(dim=1)
.min(dim=0)
.indices
)
# generate the knot trajectory in canonical coords
if time is None:
time = torch.linspace(0, 2 * math.pi, n_eval_cams + 1)[:n_eval_cams]
else:
assert time.numel() == n_eval_cams
if trajectory_type == "trefoil_knot":
traj = _trefoil_knot(time)
elif trajectory_type == "figure_eight_knot":
traj = _figure_eight_knot(time)
elif trajectory_type == "figure_eight":
traj = _figure_eight(time)
else:
raise ValueError(f"bad trajectory type: {trajectory_type}")
traj[:, 2] -= traj[:, 2].max()
# transform the canonical knot to the coord frame of the mean camera
mean_camera = PerspectiveCameras(
**{
k: getattr(train_cameras, k)[[int(mean_camera_idx)]]
for k in ("focal_length", "principal_point", "R", "T")
}
)
traj_trans = Scale(cam_centers.std(dim=0).mean() * trajectory_scale).compose(
mean_camera.get_world_to_view_transform().inverse()
)
if traj_offset_canonical is not None:
traj_trans = traj_trans.translate(
torch.FloatTensor(traj_offset_canonical)[None].to(traj)
)
traj = traj_trans.transform_points(traj)
plane_normal = _fit_plane(cam_centers)[:, 0]
if infer_up_as_plane_normal:
up = _disambiguate_normal(plane_normal, up)
elif trajectory_type == "circular_lsq_fit":
### fit plane to the camera centers
# get the center of the plane as the median of the camera centers
cam_centers = train_cameras.get_camera_center()
if time is not None:
angle = time
else:
angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams).to(cam_centers)
fit = fit_circle_in_3d(
cam_centers,
angles=angle,
offset=angle.new_tensor(traj_offset_canonical)
if traj_offset_canonical is not None
else None,
up=angle.new_tensor(up),
)
traj = fit.generated_points
# scalethe trajectory
_t_mu = traj.mean(dim=0, keepdim=True)
traj = (traj - _t_mu) * trajectory_scale + _t_mu
plane_normal = fit.normal
if infer_up_as_plane_normal:
up = _disambiguate_normal(plane_normal, up)
else:
raise ValueError(f"Uknown trajectory_type {trajectory_type}.")
if traj_offset is not None:
traj = traj + torch.FloatTensor(traj_offset)[None].to(traj)
# point all cameras towards the center of the scene
R, T = look_at_view_transform(
eye=traj,
at=(scene_center,), # (1, 3)
up=(up,), # (1, 3)
device=traj.device,
)
# get the average focal length and principal point
if focal_length is None:
focal_length = train_cameras.focal_length.mean(dim=0).repeat(n_eval_cams, 1)
if principal_point is None:
principal_point = train_cameras.principal_point.mean(dim=0).repeat(
n_eval_cams, 1
)
test_cameras = PerspectiveCameras(
focal_length=focal_length,
principal_point=principal_point,
R=R,
T=T,
device=focal_length.device,
)
# _visdom_plot_scene(
# train_cameras,
# test_cameras,
# )
return test_cameras
def _disambiguate_normal(normal, up):
up_t = torch.tensor(up).to(normal)
flip = (up_t * normal).sum().sign()
up = normal * flip
up = up.tolist()
return up
def _fit_plane(x):
x = x - x.mean(dim=0)[None]
cov = (x.t() @ x) / x.shape[0]
_, e_vec = eigh(cov)
return e_vec
def _visdom_plot_scene(
train_cameras,
test_cameras,
) -> None:
from pytorch3d.vis.plotly_vis import plot_scene
p = plot_scene(
{
"scene": {
"train_cams": train_cameras,
"test_cams": test_cameras,
}
}
)
from visdom import Visdom
viz = Visdom()
viz.plotlyplot(p, env="cam_traj_dbg", win="cam_trajs")
import pdb
pdb.set_trace()
def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):
x = (2 + (2 * t).cos()) * (3 * t).cos()
y = (2 + (2 * t).cos()) * (3 * t).sin()
z = (4 * t).sin() * z_scale
return torch.stack((x, y, z), dim=-1)
def _trefoil_knot(t: torch.Tensor, z_scale: float = 0.5):
x = t.sin() + 2 * (2 * t).sin()
y = t.cos() - 2 * (2 * t).cos()
z = -(3 * t).sin() * z_scale
return torch.stack((x, y, z), dim=-1)
def _figure_eight(t: torch.Tensor, z_scale: float = 0.5):
x = t.cos()
y = (2 * t).sin() / 2
z = t.sin() * z_scale
return torch.stack((x, y, z), dim=-1)